Published on

ControlNet代码及模型结构详解(3)—— 一次估计噪声的计算流程

Authors
  • avatar
    Name
    zmy
    Twitter

概述

  • paper:https://arxiv.org/abs/2302.05543

  • github:https://github.com/lllyasviel/ControlNet

  • 通过语义图控制文本引导图像生成的过程来自顶向下分析ControlNet的采样过程,以了解其代码和模型结构,分为以下几部分:

    • 首先看采样脚本 gradio_seg.py,将模型及DDIM过程都当作一个黑盒,分析传入的各个参数的作用
    • 仍然将模型视作一个黑盒,只需要知道模型每次接收噪声图像x、时间步t、条件c,输出预估噪声,着重看DDIM的采样过程
    • 研究从配置文件加载模型的过程,以此来了解模型结构及其计算过程
  • 前面分析了各个参数的作用,以及DDIM的采样过程,采样一步中涉及走模型估计噪声的部分当作黑盒,本篇将探究模型这个黑盒的计算过程

模型计算过程

  • 主要是 self.model.apply_model(x, t, c) 调用模型进行计算

  • self.model 是cldm/cldm.py下的ControlLDM这个类,调用的是其apply_model方法

ControlLDM.apply_model()

def apply_model(self, x_noisy, t, cond, *args, **kwargs):
    assert isinstance(cond, dict)
    diffusion_model = self.model.diffusion_model

    cond_txt = torch.cat(cond['c_crossattn'], 1)

    if cond['c_concat'] is None:
        eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
    else:
        control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
        control = [c * scale for c, scale in zip(control, self.control_scales)]
        if self.global_average_pooling:
            control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
        eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

    return eps
  • 分析上面这段代码:

    • self 指代的就是 ControlLDM,其由 StableDiffsion 和 ControlNet 组成,可以认为 ControlLDM = Stable Diffusion + ControNet
      • PS:这里的 stable diffusion是改进后的,添加了控制条件输入,即添加了ControNet输出的控制条件
    • 第三行的diffusion_model 即 stable diffusion
    • 第五行是取出cond这个字典的c_crossattn字段,也就是prompt字段,cond这个字典是由两个字段组成的,c_concat为控制条件,c_crossattn为文本prompt
    • 接下来的 if-else 判断是看是否传入了控制条件
      • 如果没有控制条件的话,则不需要ControlNet,直接调用stable diffusion估计噪声
      • 否则的话,先将噪声图像、控制条件、时间步、文本prompt送入ControlNet,计算出一个list表示连接到diffusion不同层的control信息,然后乘对应的control_scales;将此作为控制信息送入diffusion_model预估噪声
    • 这里需要继续往下关注的就是以下两个计算:
    control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
    eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
    

self.control_model

  • 主要关注ControlNet的计算,接收噪声图像、控制条件、时间步、文本prompt
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
  • 进入的是ControlNet的forward方法:
def forward(self, x, hint, timesteps, context, **kwargs):
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)

    guided_hint = self.input_hint_block(hint, emb, context)

    outs = []

    h = x.type(self.dtype)
    for module, zero_conv in zip(self.input_blocks, self.zero_convs):
        if guided_hint is not None:
            h = module(h, emb, context)
            h += guided_hint
            guided_hint = None
        else:
            h = module(h, emb, context)
        outs.append(zero_conv(h, emb, context))

    h = self.middle_block(h, emb, context)
    outs.append(self.middle_block_out(h, emb, context))

    return outs
  • 分析这个前向传播的流程:

    • 先处理时间步,通过 timestep_embedding 和 self.time_embed 计算时间步的embedding
    • 接下来通过 input_hint_block 处理控制条件
    • 然后就是遍历模型的每一个block,每一个block跟一个zero_conv,将每个zero_conv的输出结果添加到outs中
    • 最终返回的outs是一个长度为13的list
  • 时间步的处理过程:

  • 时间步骤timesteps维度为 b* 1,b为batch_size

  • time_steps 经过 timestep_embedding 处理后得到 t_emb,维度变为了 b * model_channels,为 b*320

  • t_emb 经过 self.time_embed 处理,self.time_embed 如下所示,维度变为了 b * (4 model_channels)

time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
    linear(model_channels, time_embed_dim),
    nn.SiLU(),
    linear(time_embed_dim, time_embed_dim),
)
  • input_hint_block 流程:

    • 如下,input_hint_block 是一系列二维卷积+SiLU的堆叠,最后是一个零卷积层
    • 作用是将hint从 3 * H * W 变为 model_channels * H // 8 * W // 8
    self.input_hint_block = TimestepEmbedSequential(
        conv_nd(dims, hint_channels, 16, 3, padding=1),
        nn.SiLU(),
        conv_nd(dims, 16, 16, 3, padding=1),
        nn.SiLU(),
        conv_nd(dims, 16, 32, 3, padding=1, stride=2),
        nn.SiLU(),
        conv_nd(dims, 32, 32, 3, padding=1),
        nn.SiLU(),
        conv_nd(dims, 32, 96, 3, padding=1, stride=2),
        nn.SiLU(),
        conv_nd(dims, 96, 96, 3, padding=1),
        nn.SiLU(),
        conv_nd(dims, 96, 256, 3, padding=1, stride=2),
        nn.SiLU(),
        zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
    )
    
  • self.input_blocks 和 self.zero_convs 流程:

    • 首先这两个blocks的第一个block初始化:
      • input_blocks第一个block是一个二维卷积,输入通道是4,输出通道是320,分析上面的代码可以看到,当遍历input_blocks时,guided_hint只在第一块起作用,也就是只与这里这个二维卷积的结果相加,这个二维卷积接收的是隐空间的噪声图像,维度为 b * 4 * H // 8 * W // 8,输出维度为 b * 320 * H // 8 * W // 8,控制条件(语义图)经过hint block后维度从 b * 3 * H * W 变为了b * 320 * H // 8 * W // 8,二者维度匹配,可以相加
      • zero_convs第一个了零卷积的卷积核大小为1,输入输出通道数也相等,其实是一个恒等变换
    self.input_blocks = nn.ModuleList(
        [
            TimestepEmbedSequential(
                conv_nd(dims, in_channels, model_channels, 3, padding=1)
            )
        ]
    )
    self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
    
    • 然后是循环添加 block 以及 zero_conv:
      • 一共添加了13个block以及zero_conv,其中12个EncoderBlock以及1个Middle Block,其中12个EncoderBlock中有四个DownSampleBlock,通过平均池化进行下采样
      • channel_mult 是在配置文件中指定的,值为 [ 1, 2, 4, 4 ] ,self.num_res_blocks = len(channel_mult) * [num_res_blocks],其值为 [ 2, 2, 2, 2 ]
      • 对channel_mult的遍历可以理解为创建的四个SD Encoder Block
    for level, mult in enumerate(channel_mult):
        for nr in range(self.num_res_blocks[level]):
            layers = [
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    out_channels=mult * model_channels,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ]
            ch = mult * model_channels
            if ds in attention_resolutions:
                if num_head_channels == -1:
                    dim_head = ch // num_heads
                else:
                    num_heads = ch // num_head_channels
                    dim_head = num_head_channels
                if legacy:
                    # num_heads = 1
                    dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                if exists(disable_self_attentions):
                    disabled_sa = disable_self_attentions[level]
                else:
                    disabled_sa = False
    
                if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint
                        )
                    )
            self.input_blocks.append(TimestepEmbedSequential(*layers))
            self.zero_convs.append(self.make_zero_conv(ch))
            self._feature_size += ch
            input_block_chans.append(ch)
        if level != len(channel_mult) - 1:
            out_ch = ch
            self.input_blocks.append(
                TimestepEmbedSequential(
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=out_ch,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        down=True,
                    )
                    if resblock_updown
                    else Downsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch
                    )
                )
            )
            ch = out_ch
            input_block_chans.append(ch)
            self.zero_convs.append(self.make_zero_conv(ch))
            ds *= 2
            self._feature_size += ch
    
    

diffusion_model

  • 前面已经简单分析了 control_model的结构,contro_model的input_blocks其实就是复制的diffusion_model的encoder部分,而decoder部分与encoder正好相反,这里不再赘述
  • 走的是 cldm.py 下的 ControlledUnetModel 的 forward 方法:
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
    hs = []
    with torch.no_grad():
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)

    if control is not None:
        h += control.pop()

    for i, module in enumerate(self.output_blocks):
        if only_mid_control or control is None:
            h = torch.cat([h, hs.pop()], dim=1)
        else:
            h = torch.cat([h, hs.pop() + control.pop()], dim=1)
        h = module(h, emb, context)

    h = h.type(x.dtype)
    return self.out(h)
  • 分析这段代码:
    • hs 存放的是 stable diffusion 的12个encoder block 的输出结果,用于跳跃连接
    • 第10行的h是middle_block的输出结果
    • 13行 h 与 control 的最后一条相加,即stable diffusion的middle block 和 ControlNet 的 middle block输出结果相加,作为第一个decoder block的输入
    • 15行这个循环就是进行decode部分,每个decode block 的输入为:对应的encoder的输出加上对应的ControlNet的输出,然后与上一层输出cat起来作为输入