Latent Bridge Matching: Implementation Walkthrough
Let's start with the BaseConfig class in lbm/config.py:
@dataclass
class BaseConfig:
    name: str = field(init=False)

    def __post_init__(self):
        self.name = self.__class__.__name__
Here dataclass is a convenient Python decorator that generates data classes automatically. field(init=False) marks the attribute name as excluded from the generated __init__, meaning name cannot be passed as a constructor argument. __post_init__ is a dataclass-specific hook that is called automatically right after __init__ finishes; here it sets the attribute name to the name of the class the instance belongs to (see the dataclass reference). For example, running the following prints the class name BaseConfig:
baseconfig = BaseConfig()
print(baseconfig.name)  ## prints "BaseConfig"
The from_dict method, shown below, builds a BaseConfig instance from a given dictionary:
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "BaseConfig":
    try:
        config = cls(**config_dict)
    except (ValidationError, TypeError) as e:
        raise e
    return config
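A quick, hypothetical illustration of how from_dict behaves (not from the repository): the dictionary keys must match the dataclass fields, and anything unexpected surfaces as the re-raised TypeError.

cfg = BaseConfig.from_dict({})     ## BaseConfig declares no init-able fields, so an empty dict works
print(cfg.name)                    ## prints "BaseConfig"
## BaseConfig.from_dict({"foo": 1})  ## would raise TypeError: unexpected keyword argument 'foo'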
The from_yaml method reads a configuration file from a YAML path and returns a BaseConfig instance built from it:
@classmethod
def from_yaml(cls, yaml_path: str) -> "BaseConfig":
    with open(yaml_path, "r") as f:
        ## exception handling around YAML parsing
        try:
            config_dict = safe_load(f)
        except yaml.YAMLError as e:
            raise yaml.YAMLError(
                f"File {yaml_path} not loadable. Maybe not yaml ? \n"
                f"Catch Exception {type(e)} with message: " + str(e)
            ) from e

    ## pop the value stored under the "name" key from the config dict
    config_name = config_dict.pop("name")
    if cls.__name__ != config_name:
        warnings.warn(
            f"You are trying to load a "
            f"`{cls.__name__}` while a "
            f"`{config_name}` is given."
        )
    ## build the config instance from the remaining dict
    return cls.from_dict(config_dict)
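A minimal usage sketch, assuming a hypothetical config.yaml whose only content is the line `name: BaseConfig`:

config = BaseConfig.from_yaml("config.yaml")
print(config.name)  ## prints "BaseConfig"; a mismatched name would only trigger a warning, not an error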
Next, look at lbm/models/base/model_config.py. This file defines the configuration class ModelConfig shown below, which inherits from the BaseConfig base class above, so ModelConfig's name attribute ends up with the value "ModelConfig":
@dataclass
class ModelConfig(BaseConfig):
    input_key: str = "image"
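Since ModelConfig inherits __post_init__ from BaseConfig, the name attribute is again filled in automatically, while input_key keeps its declared default. A short illustration (the overriding value below is made up):

cfg = ModelConfig()
print(cfg.name)       ## prints "ModelConfig"
print(cfg.input_key)  ## prints "image"
cfg = ModelConfig(input_key="source_image")  ## hypothetical override of the default key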
Then turn to the model file in the same base directory, which defines BaseModel, the base class for every module in LBM. It inherits from nn.Module and assigns four member attributes at initialization: config, input_key, device, and dtype, where input_key defaults to "image" as specified in ModelConfig:
class BaseModel(nn.Module):
    def __init__(self, config: ModelConfig):
        nn.Module.__init__(self)
        self.config = config
        self.input_key = config.input_key
        self.device = torch.device("cpu")
        self.dtype = torch.float32

    def on_fit_start(self, device: torch.device | None = None, *args, **kwargs):
        """Called when the training starts

        Args:
            device (Optional[torch.device], optional): The device to use. Useful to set
                relevant parameters on the model and embedder to the right device only
                once at the start of the training. Defaults to None.
        """
        if device is not None:
            self.device = device
        self.to(self.device)

    def forward(self, batch: Dict[str, Any], *args, **kwargs):
        raise NotImplementedError("forward method is not implemented")

    def freeze(self):
        """Freeze the model"""
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        self = super().to(
            device=device,
            dtype=dtype,
            non_blocking=non_blocking,
        )
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self

    def compute_metrics(self, batch: Dict[str, Any], *args, **kwargs):
        """Compute the metrics"""
        return {}

    def sample(self, batch: Dict[str, Any], *args, **kwargs):
        """Sample from the model"""
        return {}

    def log_samples(self, batch: Dict[str, Any], *args, **kwargs):
        """Log the samples"""
        return None

    def on_train_batch_end(self, batch: Dict[str, Any], *args, **kwargs):
        """Called after an optimization step is performed on a batch."""
        pass
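To make the role of these hooks concrete, here is a minimal sketch of a subclass (purely hypothetical, not a class from the repository, and assuming the same torch / torch.nn / typing imports used by the file) that implements forward and relies on the inherited on_fit_start/freeze/to behavior:

class ToyModel(BaseModel):
    """Hypothetical example: one linear layer that reads its input via config.input_key."""

    def __init__(self, config: ModelConfig):
        super().__init__(config)
        self.proj = nn.Linear(16, 16)

    def forward(self, batch: Dict[str, Any], *args, **kwargs):
        ## fetch the input tensor using the key configured in ModelConfig
        x = batch[self.input_key]
        return self.proj(x)


model = ToyModel(ModelConfig())
model.on_fit_start(device=torch.device("cpu"))  ## records the device and moves the module
out = model({"image": torch.randn(2, 16)})      ## input_key defaults to "image"
model.freeze()                                  ## eval mode + requires_grad=False on every parameter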
With the base directory covered, let's move on to the embedders directory. It contains a base directory defining the embedder base classes, a latents_concat directory handling concatenation in latent space, and a conditioners_wrapper.py file that wraps the embedders. As usual, start with the base directory, which holds two files: the conditioner config base class in base_conditioner_config.py and the conditioner base class in base_conditioner.py. The config class BaseConditionerConfig, shown below, inherits from BaseConfig and defines two attributes: input_key, which defaults to "text", and unconditional_conditioning_rate, the probability of dropping the conditioning during training. It also overrides the __post_init__ method and adds an assertion checking that unconditional_conditioning_rate is within range:
@dataclass
class BaseConditionerConfig(BaseConfig):
    """This is the ClipEmbedderConfig class which defines all the useful parameters to instantiate the model

    Args:
        input_key (str): The key for the input. Defaults to "text".
        unconditional_conditioning_rate (float): Drops the conditioning with this probability during training. Defaults to 0.0.
    """

    input_key: str = "text"
    unconditional_conditioning_rate: float = 0.0

    def __post_init__(self):
        super().__post_init__()

        assert (
            self.unconditional_conditioning_rate >= 0.0
            and self.unconditional_conditioning_rate <= 1.0
        ), "Unconditional conditioning rate should be between 0 and 1"
The base_conditioner.py file first defines DIM2CONDITIONING, a dictionary mapping tensor dimensionality to conditioning type, and then defines BaseConditioner, the conditioner base class inheriting from BaseModel. Compared to BaseModel it adds two attributes, dim2outputkey (the dimensionality-to-conditioning dictionary) and ucg_rate (the conditioning drop probability), plus a forward method that subclasses must override:
DIM2CONDITIONING = {
    2: "vector",
    3: "crossattn",
    4: "concat",
}


class BaseConditioner(BaseModel):
    """This is the base class for all the conditioners. This abstracts the conditioning process

    Args:
        config (BaseConditionerConfig): The configuration of the conditioner
    """

    def __init__(self, config: BaseConditionerConfig):
        BaseModel.__init__(self, config)
        self.config = config
        self.input_key = config.input_key
        self.dim2outputkey = DIM2CONDITIONING
        self.ucg_rate = config.unconditional_conditioning_rate

    def forward(
        self, batch: Dict[str, Any], force_zero_embedding: bool = False, *args, **kwargs
    ):
        """
        Forward pass of the embedder.

        Args:
            batch (Dict[str, Any]): A dictionary containing the input data.
            force_zero_embedding (bool): Whether to force zero embedding.
                This will return an embedding with all entries set to 0. Defaults to False.
        """
        raise NotImplementedError("Forward pass must be implemented in child class")
The lbm/models/embedders/latents_concat directory likewise contains a model file, latents_concat_embedder_model.py, and a config file, latents_concat_embedder_config.py. The config adds two attributes, image_keys and mask_keys; the model file defines LatentsConcatEmbedder, a conditioner inheriting from BaseConditioner that computes VAE embeddings from the given input images and masks. Its forward pass first validates the batch: every input image and mask must share the same spatial dimensions (height and width). The implementation is shown below:
class LatentsConcatEmbedder(BaseConditioner):
    """
    Class computing VAE embeddings from given images and resizing the masks.
    The outputs are then concatenated to the noise in the latent space.

    Args:
        config (LatentsConcatEmbedderConfig): Configs to create the embedder
    """

    def __init__(self, config: LatentsConcatEmbedderConfig):
        BaseConditioner.__init__(self, config)

    def forward(
        self, batch: Dict[str, Any], vae: AutoencoderKLDiffusers, *args, **kwargs
    ) -> dict:
        """
        Args:
            batch (dict): A batch of images to be processed by this embedder. In the batch,
                the images must range between [-1, 1] and the masks range between [0, 1].
            vae (AutoencoderKLDiffusers): VAE

        Returns:
            output (dict): outputs
        """

        # Check if images are of the same size
        ## check that all image and mask inputs share the same height and width
        dims_list = []
        for image_key in self.config.image_keys:
            dims_list.append(batch[image_key].shape[-2:])
        for mask_key in self.config.mask_keys:
            dims_list.append(batch[mask_key].shape[-2:])
        assert all(
            dims == dims_list[0] for dims in dims_list
        ), "All images and masks must have the same dimensions."

        # Find the latent dimensions
        ## compute the latent spatial dimensions from the VAE downsampling factor
        if len(self.config.image_keys) > 0:
            latent_dims = (
                batch[self.config.image_keys[0]].shape[-2] // vae.downsampling_factor,
                batch[self.config.image_keys[0]].shape[-1] // vae.downsampling_factor,
            )
        else:
            latent_dims = (
                batch[self.config.mask_keys[0]].shape[-2] // vae.downsampling_factor,
                batch[self.config.mask_keys[0]].shape[-1] // vae.downsampling_factor,
            )

        outputs = []

        # Resize the masks and concat them
        ## resize each mask to the latent dimensions found above and collect it
        ## (everything is concatenated along the channel dimension at the end)
        for mask_key in self.config.mask_keys:
            curr_latents = F.resize(
                batch[mask_key],
                size=latent_dims,
                interpolation=F.InterpolationMode.BILINEAR,
            )
            outputs.append(curr_latents)

        # Compute VAE embeddings from the images
        ## encode each input image with the VAE
        for image_key in self.config.image_keys:
            vae_embs = vae.encode(batch[image_key])
            outputs.append(vae_embs)

        # Concat all the outputs
        ## concatenate along the channel dimension
        outputs = torch.concat(outputs, dim=1)
        ## the key of the returned dict is chosen via the dimensionality-to-conditioning mapping
        outputs = {self.dim2outputkey[outputs.dim()]: outputs}

        return outputs
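A small usage sketch with made-up shapes and a stub standing in for AutoencoderKLDiffusers (the field names image_keys/mask_keys come from the config; the rest is assumed for illustration): with one RGB image, one mask, and a downsampling factor of 8, the concatenated tensor is 4-D, so it is returned under the "concat" key.

class _FakeVAE:
    """Stub exposing the two members the embedder touches: downsampling_factor and encode()."""
    downsampling_factor = 8

    def encode(self, x):
        b, _, h, w = x.shape
        return torch.randn(b, 4, h // 8, w // 8)  ## pretend 4-channel latents


config = LatentsConcatEmbedderConfig(image_keys=["source_image"], mask_keys=["mask"])
embedder = LatentsConcatEmbedder(config)
batch = {
    "source_image": torch.rand(1, 3, 512, 512) * 2 - 1,  ## images are expected in [-1, 1]
    "mask": torch.rand(1, 1, 512, 512),                  ## masks are expected in [0, 1]
}
out = embedder(batch, vae=_FakeVAE())
print(out["concat"].shape)  ## torch.Size([1, 5, 64, 64]): 1 resized mask channel + 4 latent channels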
Next, look at lbm/models/embedders/conditioners_wrapper.py. The file first defines a dictionary, KEY2CATDIM, indicating along which dimension to concatenate for each conditioning type. It then defines ConditionerWrapper, the class that wraps the conditioners; its forward method runs the batch through every conditioner and returns a wrapped dictionary wrapper_outputs with a single key, "cond", whose value is itself a dictionary keyed by conditioning type with the conditioning tensors as values. For each key in a conditioner's output conditioner_output, the wrapper looks up the concat dimension in KEY2CATDIM and concatenates the new tensor onto whatever is already stored under wrapper_outputs["cond"][key] along that dimension:
KEY2CATDIM = {
    "vector": 1,
    "crossattn": 2,
    "concat": 1,
}


def forward(
    self,
    batch: Dict[str, Any],
    ucg_keys: List[str] = None,
    set_ucg_rate_zero=False,
    *args,
    **kwargs,
):
    """
    Forward pass through all conditioners

    Args:
        batch: batch of data
        ucg_keys: keys to use for ucg. This will force zero conditioning in all the
            conditioners that have input_keys in ucg_keys
        set_ucg_rate_zero: set the ucg rate to zero for all the conditioners except the ones in ucg_keys

    Returns:
        Dict[str, Any]: The output of the conditioner. The output of the conditioner is a dictionary with the main key "cond" and value
            is a dictionary with the keys as the type of conditioning and the value as the conditioning tensor.
    """
    if ucg_keys is None:
        ucg_keys = []

    wrapper_outputs = dict(cond={})

    for conditioner in self.conditioners:
        if conditioner.input_key in ucg_keys:
            force_zero_embedding = True
        elif conditioner.ucg_rate > 0 and not set_ucg_rate_zero:
            force_zero_embedding = bool(torch.rand(1) < conditioner.ucg_rate)
        else:
            force_zero_embedding = False

        conditioner_output = conditioner.forward(
            batch, force_zero_embedding=force_zero_embedding, *args, **kwargs
        )
        logging.debug(
            f"conditioner:{conditioner.__class__.__name__}, input_key:{conditioner.input_key}, force_ucg_zero_embedding:{force_zero_embedding}"
        )

        for key in conditioner_output:
            logging.debug(
                f"conditioner_output:{key}:{conditioner_output[key].shape}"
            )
            if key in wrapper_outputs["cond"]:
                wrapper_outputs["cond"][key] = torch.cat(
                    [wrapper_outputs["cond"][key], conditioner_output[key]],
                    KEY2CATDIM[key],
                )
            else:
                wrapper_outputs["cond"][key] = conditioner_output[key]

    return wrapper_outputs
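To see the merging rule in isolation (a sketch with made-up shapes, not code from the repository): two conditioners that both emit a (batch, sequence, dim) tensor under the "crossattn" key get concatenated along dimension 2, as dictated by KEY2CATDIM:

cond_a = torch.randn(2, 77, 768)  ## hypothetical output of a first conditioner under "crossattn"
cond_b = torch.randn(2, 77, 512)  ## hypothetical output of a second conditioner under the same key
merged = torch.cat([cond_a, cond_b], KEY2CATDIM["crossattn"])  ## concat along dim 2
print(merged.shape)  ## torch.Size([2, 77, 1280])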