Published on

ControlNet代码及模型结构详解(1)—— 参数说明

Authors
  • avatar
    Name
    zmy
    Twitter

概述

  • paper:https://arxiv.org/abs/2302.05543
  • github:https://github.com/lllyasviel/ControlNet
  • 通过语义图控制文本引导图像生成的过程来自顶向下分析ControlNet的采样过程,以了解其代码和模型结构,分为以下几部分:
    • 首先看采样脚本 gradio_seg.py,将模型及DDIM过程都当作一个黑盒,分析传入的各个参数的作用
    • 仍然将模型视作一个黑盒,只需要知道模型每次接收噪声图像x、时间步t、条件c,输出预估噪声,着重看DDIM的采样过程
    • 研究从配置文件加载模型的过程,以此来了解模型结构及其计算过程

参数说明

  • 以语义分割图为控制条件为例
  • 主要通过生成图像时接收的所有参数以及参数所影响的地方进行理解
  • process 方法为接收参数,然后返回生成图像的方法,其所有参数如下:
def process(det, input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):

接下来逐个参数进行说明:

  1. det:指定 Detector 类型,即指定对输入的图像的预处理器类型,通过什么预处理器来从原始图像中提取出语义分割图来作为控制条件
global preprocessor

if det == 'Seg_OFCOCO':
  if not isinstance(preprocessor, OneformerCOCODetector):
      preprocessor = OneformerCOCODetector()
if det == 'Seg_OFADE20K':
  if not isinstance(preprocessor, OneformerADE20kDetector):
      preprocessor = OneformerADE20kDetector()
if det == 'Seg_UFADE20K':
  if not isinstance(preprocessor, UniformerDetector):
      preprocessor = UniformerDetector()
  1. input_image:输入图像, 通过 Detector 来生成 detected_map 来作为控制条件
  2. image_resolution:生成图像的尺寸
  3. detect_resolution:语义图尺寸,先将input_image按此尺寸进行resize,然后送入预处理器即返回此尺寸的语义分割图
input_image = HWC3(input_image)

if det == 'None':
    detected_map = input_image.copy()
else:
    detected_map = preprocessor(resize_image(input_image, detect_resolution))
    detected_map = HWC3(detected_map)

img = resize_image(input_image, image_resolution)
H, W, C = img.shape

detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
  1. seed参数为随机种子
if seed == -1:
    seed = random.randint(0, 65535)
seed_everything(seed)
  1. 三个prompt:

    • prompt:文本提示
    • add_prompt:附加的正面 prompt
    • negative_prompt:附加的负面 prompt,如果对生成结果的哪部分不满意,可以将不满意的地方放入此prompt
  2. guess_mode:

# gradio_seg.py
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}

# Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)

# cldm/cldm.py
cond_txt = torch.cat(cond['c_crossattn'], 1)
if cond['c_concat'] is None:
    eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
else:
    control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
    control = [c * scale for c, scale in zip(control, self.control_scales)]
    if self.global_average_pooling:
        control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
    eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

针对上面的代码:

  • guess_mode 显式作用在 gradio_seg.py 里,主要控制 un_cond 以及 模型的控制强度,如果Guess Mode选中,在ControlNet部分每层(共13层)的权重会递增,范围从0到1,否则权重全部为1
  • cond_txt 为 prompt 和 add_prompt
  • if cond['c_concat'] is None: 即没有额外控制条件的时候,ControlNet部分就不需要了,只走 diffusion_model
  • 否则的话说明有控制条件,先通过 control_model ,输入噪声图、 control 信息 和 prompt,输出一个长度为13的list,代表Control每一层的输出,然后每一层乘自己对应的权重获得 control list后作用于diffusion_model
  1. strength:控制强度,作用于 Control 每一层的输出(共13层)
  2. scale:无条件引导强度:
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    model_output = self.model.apply_model(x, t, c)
else:
    model_t = self.model.apply_model(x, t, c)
    model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
    model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
  1. num_sample:采样多少图片
  2. ddim_steps:采样步数
  3. eta:DDIM采样中的eta值。