Post

Latent Bridge Matching 实现解析

首先看lbm/config.py下的BaseConfig类,

1
2
3
4
5
6
7
@dataclass
class BaseConfig:

    name: str = field(init=False)

    def __post_init__(self):
        self.name = self.__class__.__name__

其中 dataclass是Python中一个方便的装饰器,用于自动创建数据类;field(init=False) 方法表示属性 name 被标记为不应该在对象创建时通过参数进行初始化, 也就是说属性 name 不能作为初始化参数传入值; __post_init__ 方法是dataclass 特有的hook,能够在 dataclass 初始化对象之后自动调用这个方法。这里的 __post_init__ 方法表示在完成类的初始化之后,立即调用 __post_init__ 方法,创建成员属性 name,并将其值设置为当前实例所属类的名称(dataclass参考);比如执行如下指令,将会得到类名 BaseConfig

1
2
baseconfig = BaseConfig()
print(baseconfg.name) ## 输出为BaseConfig

from_dict 方法如下所示,用于从给定字典创建一个BaseConfig实例

1
2
3
4
5
6
7
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "BaseConfig":
    try:
        config = cls(**config_dict)
    except (ValidationError, TypeError) as e:
        raise e
    return config

from_yaml 方法用于从一个yaml路径中读取配置文件,然后根据配置文件返回一个BaseConfig实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
@classmethod
def from_yaml(cls, yaml_path: str) -> "BaseConfig":
    with open(yaml_path, "r") as f:
        ## 异常处理
        try:
            config_dict = safe_load(f)
        except yaml.YAMLError as e:
            raise yaml.YAMLError(
                f"File {yaml_path} not loadable. Maybe not yaml ? \n"
                f"Catch Exception {type(e)} with message: " + str(e)
            ) from e

    ## 从配置字典中取出键名为“name”的值
    config_name = config_dict.pop("name")

    if cls.__name__ != config_name:
        warnings.warn(
            f"You are trying to load a "
            f"`{ cls.__name__}` while a "
            f"`{config_name}` is given."
        )

    ## 根据配置字典返回BaseConfig实例
    return cls.from_dict(config_dict)

##

再来看 lbm/models/base/model_config.py ,这个文件定义了一个如下所示配置类 ModelConfig,继承了上面提到的配置基类 BaseConfig,那么ModelConfig的成员属性 name 的值就是ModelConfig

1
2
3
@dataclass
class ModelConfig(BaseConfig):
    input_key: str = "image"

然后看向,这个文件定义了LBM中所有模块的基类 BaseModel,该基类继承自 nn,Module,初始化的时候赋值四个成员变量,config, input_key, device, dtype,其中 input_key 默认值在 ModelConfig 被指定为 images

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class BaseModel(nn.Module):
    def __init__(self, config: ModelConfig):
        nn.Module.__init__(self)
        self.config = config
        self.input_key = config.input_key
        self.device = torch.device("cpu")
        self.dtype = torch.float32

    def on_fit_start(self, device: torch.device | None = None, *args, **kwargs):
        """Called when the training starts

        Args:
            device (Optional[torch.device], optional): The device to use. Usefull to set
                relevant parameters on the model and embedder to the right device only
                once at the start of the training. Defaults to None.
        """
        if device is not None:
            self.device = device
        self.to(self.device)

    def forward(self, batch: Dict[str, Any], *args, **kwargs):
        raise NotImplementedError("forward method is not implemented")

    def freeze(self):
        """Freeze the model"""
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        self = super().to(
            device=device,
            dtype=dtype,
            non_blocking=non_blocking,
        )

        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self

    def compute_metrics(self, batch: Dict[str, Any], *args, **kwargs):
        """Compute the metrics"""
        return {}

    def sample(self, batch: Dict[str, Any], *args, **kwargs):
        """Sample from the model"""
        return {}

    def log_samples(self, batch: Dict[str, Any], *args, **kwargs):
        """Log the samples"""
        return None

    def on_train_batch_end(self, batch: Dict[str, Any], *args, **kwargs):
        """Update the model an optimization is perforned on a batch."""
        pass

##

看完了base目录,再看embedders目录。embedders目录下面包含了定义embedders的基类的 base 目录和如何进行潜变量连接的 latents_concat 目录,以及一个如何封装embedders的 conditioners_wrapper.py 文件。按规矩,先看 base 目录,该目录下面定义了条件基类配置 base_conditioner_config.py 和条件基类 base_conditioner.py 两个文件,条件基类 BaseConditionerConfig 如下所示,该类继承了配置基类 BaseConfig ,定义了两个成员属性,分别为:默认的输入键名 text ,以及在训练过程中丢取条件的概率 unconditional_conditioning_rate ,然后覆写了 ` post_init 方法,并增添了一个用于判断 unconditional_conditioning_rate` 是否符合规范的断言

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@dataclass
class BaseConditionerConfig(BaseConfig):
    """This is the ClipEmbedderConfig class which defines all the useful parameters to instantiate the model

    Args:

        input_key (str): The key for the input. Defaults to "text".
        unconditional_conditioning_rate (float): Drops the conditioning with this probability during training. Defaults to 0.0.
    """

    input_key: str = "text"
    unconditional_conditioning_rate: float = 0.0

    def __post_init__(self):
        super().__post_init__()

        assert (
            self.unconditional_conditioning_rate >= 0.0
            and self.unconditional_conditioning_rate <= 1.0
        ), "Unconditional conditioning rate should be between 0 and 1"

条件基类 base_conditioner.py 文件首先定义了一个维度到条件类型的字典 DIM2CONDITIONING ,然后定义了一个继承自 BaseModel 的条件基类 BaseConditioner ,相比于 BaseModel 又添加了两个成员属性,分别为维度到条件类型的字典dim2outputkey 和 丢弃条件概率 ucg_rate ,然后还有一个需要被覆写的前向传播函数 forward

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
DIM2CONDITIONING = {
    2: "vector",
    3: "crossattn",
    4: "concat",
}

class BaseConditioner(BaseModel):
    """This is the base class for all the conditioners. This absctacts the conditioning process

    Args:

        config (BaseConditionerConfig): The configuration of the conditioner

    """

    def __init__(self, config: BaseConditionerConfig):
        BaseModel.__init__(self, config)
        self.config = config
        self.input_key = config.input_key
        self.dim2outputkey = DIM2CONDITIONING
        self.ucg_rate = config.unconditional_conditioning_rate

    def forward(
        self, batch: Dict[str, Any], force_zero_embedding: bool = False, *args, **kwargs
    ):
        """
         Forward pass of the embedder.

        Args:

            batch (Dict[str, Any]): A dictionary containing the input data.
            force_zero_embedding (bool): Whether to force zero embedding.
                This will return an embedding with all entries set to 0. Defaults to False.
        """
        raise NotImplementedError("Forward pass must be implemented in child class")

##

lbm/models/embedders/latents_concat 目录下同样也有模型文件latents_concat_embedder_model.py 和配置文件 latents_concat_embedder_config.py,配置文件中增加了两个成员属性 image_keys, mask_keys ;模型文件则定义了一个继承自BaseConditioner, 在给定输入图像以及掩码以计算VAE嵌入的条件类 LatentsConcatEmbedder。其前向传播函数首先检查输入图像batch和输入掩码是否符合规范,具体地通过每一个输入批量都应该有相同的通道维度来判断,具体实现如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class LatentsConcatEmbedder(BaseConditioner):
    """
    Class computing VAE embeddings from given images and resizing the masks.
    Then outputs are then concatenated to the noise in the latent space.

    Args:
        config (LatentsConcatEmbedderConfig): Configs to create the embedder
    """

    def __init__(self, config: LatentsConcatEmbedderConfig):
        BaseConditioner.__init__(self, config)

    def forward(
        self, batch: Dict[str, Any], vae: AutoencoderKLDiffusers, *args, **kwargs
    ) -> dict:
        """
        Args:
            batch (dict): A batch of images to be processed by this embedder. In the batch,
            the images must range between [-1, 1] and the masks range between [0, 1].
            vae (AutoencoderKLDiffusers): VAE

        Returns:
            output (dict): outputs
        """

        # Check if image are of the same size
        ## 检查所有图像输入的高宽是否相同
        dims_list = []
        for image_key in self.config.image_keys:
            dims_list.append(batch[image_key].shape[-2:])
        for mask_key in self.config.mask_keys:
            dims_list.append(batch[mask_key].shape[-2:])
        assert all(
            dims == dims_list[0] for dims in dims_list
        ), "All images and masks must have the same dimensions."

        # Find the latent dimensions
        ## 计算潜变量的维度
        if len(self.config.image_keys) > 0:
            latent_dims = (
                batch[self.config.image_keys[0]].shape[-2] // vae.downsampling_factor,
                batch[self.config.image_keys[0]].shape[-1] // vae.downsampling_factor,
            )
        else:
            latent_dims = (
                batch[self.config.mask_keys[0]].shape[-2] // vae.downsampling_factor,
                batch[self.config.mask_keys[0]].shape[-1] // vae.downsampling_factor,
            )

        outputs = []

        # Resize the masks and concat them
        ## 根据前面确定的潜变量维度来调整掩码的尺寸,然后添加依次添加到一个列表(最后在通道维度上相加)
        for mask_key in self.config.mask_keys:
            curr_latents = F.resize(
                batch[mask_key],
                size=latent_dims,
                interpolation=F.InterpolationMode.BILINEAR,
            )
            outputs.append(curr_latents)

        # Compute VAE embeddings from the images
        ## 计算输入图像的VAE嵌入
        for image_key in self.config.image_keys:
            vae_embs = vae.encode(batch[image_key])
            outputs.append(vae_embs)

        # Concat all the outputs
        ## 在通道维度上相加
        outputs = torch.concat(outputs, dim=1)

        ## 根据维度-条件类型字典来确定outputs字典的键名
        outputs = {self.dim2outputkey[outputs.dim()]: outputs}

        return outputs

接着来看路径 lbm/models/embedders/conditioners_wrapper.py首先该文件下有一个字典 KEY2CATDIM,用于指示在哪一个维度上进行Concat;然后该文件下定义了一个用于分装条件器的类 ConditionerWrapper 的前向传播方法 forward,该方法会传递所有的条件器 conditioner,并且返回一个封装后的字典 wrapper_outputs,其有一个键 cond,对应的值也是一个字典,该字典的键名为条件类型,值是条件张量。前向传播函数的返回字典wrapper_outputs["cond"]的值是所有conditioners的输出conditioner_output中的键key的类型,在映射字典KEY2CATDIM中取出指定的CONCAT维度后,再在该维度上进行CONCAT得到的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
KEY2CATDIM = {
    "vector": 1,
    "crossattn": 2,
    "concat": 1,
}

def forward(
    self,
    batch: Dict[str, Any],
    ucg_keys: List[str] = None,
    set_ucg_rate_zero=False,
    *args,
    **kwargs,
):
    """
    Forward pass through all conditioners

    Args:

        batch: batch of data
        ucg_keys: keys to use for ucg. This will force zero conditioning in all the
            conditioners that have input_keys in ucg_keys
        set_ucg_rate_zero: set the ucg rate to zero for all the conditioners except the ones in ucg_keys

    Returns:

    Dict[str, Any]: The output of the conditioner. The output of the conditioner is a dictionary with the main key "cond" and value
        is a dictionary with the keys as the type of conditioning and the value as the conditioning tensor.
    """
    if ucg_keys is None:
        ucg_keys = []
    wrapper_outputs = dict(cond={})
    for conditioner in self.conditioners:
        if conditioner.input_key in ucg_keys:
            force_zero_embedding = True
        elif conditioner.ucg_rate > 0 and not set_ucg_rate_zero:
            force_zero_embedding = bool(torch.rand(1) < conditioner.ucg_rate)
        else:
            force_zero_embedding = False

        conditioner_output = conditioner.forward(
            batch, force_zero_embedding=force_zero_embedding, *args, **kwargs
        )
        logging.debug(
            f"conditioner:{conditioner.__class__.__name__}, input_key:{conditioner.input_key}, force_ucg_zero_embedding:{force_zero_embedding}"
        )
        for key in conditioner_output:
            logging.debug(
                f"conditioner_output:{key}:{conditioner_output[key].shape}"
            )
            if key in wrapper_outputs["cond"]:
                wrapper_outputs["cond"][key] = torch.cat(
                    [wrapper_outputs["cond"][key], conditioner_output[key]],
                    KEY2CATDIM[key],
                )
            else:
                wrapper_outputs["cond"][key] = conditioner_output[key]

    return wrapper_outputs

LBM将UNet和VAE各自使用了wrapper进行了封装,这里不过多地阐述,下面看路径 lbm/models/lbm/lbm_config.py,里面主要初始化了LBM要用到的一些配置信息,定义了一些需要用到的参数信息以及进行断言对这些参数进行了合法性检查,内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@dataclass
class LBMConfig(ModelConfig):
    """This is the Config for LBM Model class which defines all the useful parameters to be used in the model.
    """
    source_key: str = "lr_image" 
    target_key: str = "image"
    mask_key: Optional[str] = None
    latent_loss_weight: float = 1.0
    latent_loss_type: Literal["l2", "l1"] = "l2" ## 隐空间中使用的损失类型,默认为l2损失
    pixel_loss_type: Literal["l2", "l1", "lpips"] = "l2" ## 像素空间中使用的损失类型,默认为l2损失
    pixel_loss_max_size: int = 512 ## 像素空间的最大尺度,默认值为512
    pixel_loss_weight: float = 0.0 ## 像素空间损失的权重,默认值为0.0
    timestep_sampling: Literal["uniform", "log_normal", "custom_timesteps"] = "uniform" ## timestep sampling的类型
    logit_mean: Optional[float] = 0.0 ## log(正态分布)的均值
    logit_std: Optional[float] = 1.0 ## log(正态分布)的标准差
    selected_timesteps: Optional[List[float]] = None ## 如果使用custom_steps的timestep sampling,该参数定义被选择的timestep列表
    prob: Optional[List[float]] = None ## 如果使用custom_steps的timestep sampling,该参数定义被选择的timestep列表的概率列表
    bridge_noise_sigma: float = 0.001

    def __post_init__(self):
        super().__post_init__()
        if self.timestep_sampling == "log_normal":
            assert isinstance(self.logit_mean, float) and isinstance(
                self.logit_std, float
            ), "logit_mean and logit_std should be float for log_normal timestep sampling"

        if self.timestep_sampling == "custom_timesteps":
            assert isinstance(self.selected_timesteps, list) and isinstance(
                self.prob, list
            ), "timesteps and prob should be list for custom_timesteps timestep sampling"
            assert len(self.selected_timesteps) == len(
                self.prob
            ), "timesteps and prob should be of same length for custom_timesteps timestep sampling"
            assert (
                sum(self.prob) == 1
            ), "prob should sum to 1 for custom_timesteps timestep sampling"

然后来看定义了模型主要框架的文件lbm/models/lbm/lbm_model.py中的类LBMModel,该类定义了LBMModel,初始化函数中定义了需要传入的几个模块,分别是:属于模型的配置信息类LBMConfigconfig,去噪网络denoiser,FlowMatching加噪器asmpling_noise_scheduler,使用KL散度的VAE编码器vae以及属于条件封装类ConditionWrapper的’conditioner’。

该类中定义了一个获取条件的方法_get_conditioning,一个获取$sigma$的方法_get_sigmas以及一个采样方法sample_get_conditioning。方法_get_conditioning会调用属性条件类来获得条件,方法_get_sigmas用于获得时间布调度器的$\sigma$参数。此外该类中还定义了LBM的采样方法sample,具体内容如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
def _get_conditioning(
    self,
    batch: Dict[str, Any],
    ucg_keys: List[str] = None,
    set_ucg_rate_zero=False,
    *args,
    **kwargs,
):
    """
    Get the conditionings
    """
    if self.conditioner is not None:
        return self.conditioner(
            batch,
            ucg_keys=ucg_keys,
            set_ucg_rate_zero=set_ucg_rate_zero,
            vae=self.vae,
            *args,
            **kwargs,
        )
    else:
        return None

def _get_sigmas(
    self, scheduler, timesteps, n_dim=4, dtype=torch.float32, device="cpu"
):
    sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

@torch.no_grad()
def sample(
    self,
    z: torch.Tensor,
    num_steps: int = 20,
    guidance_scale: float = 1.0,
    conditioner_inputs: Optional[Dict[str, Any]] = None,
    max_samples: Optional[int] = None,
    verbose: bool = False,
):
    self.sampling_noise_scheduler.set_timesteps(
        sigmas=np.linspace(1, 1 / num_steps, num_steps)
    )

    sample = z

    # Get conditioning
    conditioning = self._get_conditioning(
        conditioner_inputs, set_ucg_rate_zero=True, device=z.device
    )

    # If max_samples parameter is provided, limit the number of samples
    if max_samples is not None:
        sample = sample[:max_samples]

    if conditioning:
        conditioning["cond"] = {
            k: v[:max_samples] for k, v in conditioning["cond"].items()
        }

    for i, t in tqdm(
        enumerate(self.sampling_noise_scheduler.timesteps), disable=not verbose
    ):
        if hasattr(self.sampling_noise_scheduler, "scale_model_input"):
            denoiser_input = self.sampling_noise_scheduler.scale_model_input(
                sample, t
            )

        else:
            denoiser_input = sample

        # Predict noise level using denoiser using conditionings
        pred = self.denoiser(
            sample=denoiser_input,
            timestep=t.to(z.device).repeat(denoiser_input.shape[0]),
            conditioning=conditioning,
        )

        # Make one step on the reverse diffusion process
        sample = self.sampling_noise_scheduler.step(
            pred, t, sample, return_dict=False
        )[0]
        if i < len(self.sampling_noise_scheduler.timesteps) - 1:
            timestep = (
                self.sampling_noise_scheduler.timesteps[i + 1]
                .to(z.device)
                .repeat(sample.shape[0])
            )
            sigmas = self._get_sigmas(
                self.sampling_noise_scheduler, timestep, n_dim=4, device=z.device
            )
            sample = sample + self.bridge_noise_sigma * (
                sigmas * (1.0 - sigmas)
            ) ** 0.5 * torch.randn_like(sample)
            sample = sample.to(z.dtype)

    if self.vae is not None:
        decoded_sample = self.vae.decode(sample)

    else:
        decoded_sample = sample

    return decoded_sample
This post is licensed under CC BY 4.0 by the author.

Trending Tags