- Published on
ControlNet代码及模型结构详解(3)—— 一次估计噪声的计算流程
- Authors
- Name
- zmy
概述
通过语义图控制文本引导图像生成的过程来自顶向下分析ControlNet的采样过程,以了解其代码和模型结构,分为以下几部分:
- 首先看采样脚本 gradio_seg.py,将模型及DDIM过程都当作一个黑盒,分析传入的各个参数的作用
- 仍然将模型视作一个黑盒,只需要知道模型每次接收噪声图像x、时间步t、条件c,输出预估噪声,着重看DDIM的采样过程
- 研究从配置文件加载模型的过程,以此来了解模型结构及其计算过程
前面分析了各个参数的作用,以及DDIM的采样过程,采样一步中涉及走模型估计噪声的部分当作黑盒,本篇将探究模型这个黑盒的计算过程
模型计算过程
主要是 self.model.apply_model(x, t, c) 调用模型进行计算
self.model 是cldm/cldm.py下的ControlLDM这个类,调用的是其apply_model方法
ControlLDM.apply_model()
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
cond_txt = torch.cat(cond['c_crossattn'], 1)
if cond['c_concat'] is None:
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
else:
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
control = [c * scale for c, scale in zip(control, self.control_scales)]
if self.global_average_pooling:
control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
return eps
分析上面这段代码:
- self 指代的就是 ControlLDM,其由 StableDiffsion 和 ControlNet 组成,可以认为 ControlLDM = Stable Diffusion + ControNet
- PS:这里的 stable diffusion是改进后的,添加了控制条件输入,即添加了ControNet输出的控制条件
- 第三行的diffusion_model 即 stable diffusion
- 第五行是取出cond这个字典的c_crossattn字段,也就是prompt字段,cond这个字典是由两个字段组成的,c_concat为控制条件,c_crossattn为文本prompt
- 接下来的 if-else 判断是看是否传入了控制条件
- 如果没有控制条件的话,则不需要ControlNet,直接调用stable diffusion估计噪声
- 否则的话,先将噪声图像、控制条件、时间步、文本prompt送入ControlNet,计算出一个list表示连接到diffusion不同层的control信息,然后乘对应的control_scales;将此作为控制信息送入diffusion_model预估噪声
- 这里需要继续往下关注的就是以下两个计算:
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt) eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
- self 指代的就是 ControlLDM,其由 StableDiffsion 和 ControlNet 组成,可以认为 ControlLDM = Stable Diffusion + ControNet
self.control_model
- 主要关注ControlNet的计算,接收噪声图像、控制条件、时间步、文本prompt
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
- 进入的是ControlNet的forward方法:
def forward(self, x, hint, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
分析这个前向传播的流程:
- 先处理时间步,通过 timestep_embedding 和 self.time_embed 计算时间步的embedding
- 接下来通过 input_hint_block 处理控制条件
- 然后就是遍历模型的每一个block,每一个block跟一个zero_conv,将每个zero_conv的输出结果添加到outs中
- 最终返回的outs是一个长度为13的list
时间步的处理过程:
时间步骤timesteps维度为 b* 1,b为batch_size
time_steps 经过 timestep_embedding 处理后得到 t_emb,维度变为了 b * model_channels,为 b*320
t_emb 经过 self.time_embed 处理,self.time_embed 如下所示,维度变为了 b * (4 model_channels)
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
input_hint_block 流程:
- 如下,input_hint_block 是一系列二维卷积+SiLU的堆叠,最后是一个零卷积层
- 作用是将hint从 3 * H * W 变为 model_channels * H // 8 * W // 8
self.input_hint_block = TimestepEmbedSequential( conv_nd(dims, hint_channels, 16, 3, padding=1), nn.SiLU(), conv_nd(dims, 16, 16, 3, padding=1), nn.SiLU(), conv_nd(dims, 16, 32, 3, padding=1, stride=2), nn.SiLU(), conv_nd(dims, 32, 32, 3, padding=1), nn.SiLU(), conv_nd(dims, 32, 96, 3, padding=1, stride=2), nn.SiLU(), conv_nd(dims, 96, 96, 3, padding=1), nn.SiLU(), conv_nd(dims, 96, 256, 3, padding=1, stride=2), nn.SiLU(), zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) )
self.input_blocks 和 self.zero_convs 流程:
- 首先这两个blocks的第一个block初始化:
- input_blocks第一个block是一个二维卷积,输入通道是4,输出通道是320,分析上面的代码可以看到,当遍历input_blocks时,guided_hint只在第一块起作用,也就是只与这里这个二维卷积的结果相加,这个二维卷积接收的是隐空间的噪声图像,维度为 b * 4 * H // 8 * W // 8,输出维度为 b * 320 * H // 8 * W // 8,控制条件(语义图)经过hint block后维度从 b * 3 * H * W 变为了b * 320 * H // 8 * W // 8,二者维度匹配,可以相加
- zero_convs第一个了零卷积的卷积核大小为1,输入输出通道数也相等,其实是一个恒等变换
self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
- 然后是循环添加 block 以及 zero_conv:
- 一共添加了13个block以及zero_conv,其中12个EncoderBlock以及1个Middle Block,其中12个EncoderBlock中有四个DownSampleBlock,通过平均池化进行下采样
- channel_mult 是在配置文件中指定的,值为 [ 1, 2, 4, 4 ] ,self.num_res_blocks = len(channel_mult) * [num_res_blocks],其值为 [ 2, 2, 2, 2 ]
- 对channel_mult的遍历可以理解为创建的四个SD Encoder Block
for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] else: disabled_sa = False if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self.zero_convs.append(self.make_zero_conv(ch)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) self.zero_convs.append(self.make_zero_conv(ch)) ds *= 2 self._feature_size += ch
- 首先这两个blocks的第一个block初始化:
diffusion_model
- 前面已经简单分析了 control_model的结构,contro_model的input_blocks其实就是复制的diffusion_model的encoder部分,而decoder部分与encoder正好相反,这里不再赘述
- 走的是 cldm.py 下的 ControlledUnetModel 的 forward 方法:
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
hs = []
with torch.no_grad():
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
for i, module in enumerate(self.output_blocks):
if only_mid_control or control is None:
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
- 分析这段代码:
- hs 存放的是 stable diffusion 的12个encoder block 的输出结果,用于跳跃连接
- 第10行的h是middle_block的输出结果
- 13行 h 与 control 的最后一条相加,即stable diffusion的middle block 和 ControlNet 的 middle block输出结果相加,作为第一个decoder block的输入
- 15行这个循环就是进行decode部分,每个decode block 的输入为:对应的encoder的输出加上对应的ControlNet的输出,然后与上一层输出cat起来作为输入