ControlNet Code and Model Structure Explained (2): The DDIM Sampling Process

Authors
  • zmy

Overview

  • paper: https://arxiv.org/abs/2302.05543

  • github: https://github.com/lllyasviel/ControlNet

  • Taking text-guided image generation controlled by a semantic segmentation map as the running example, this series analyzes ControlNet's sampling process top-down in order to understand its code and model structure. It is split into the following parts:

    • First, look at the sampling script gradio_seg.py, treating both the model and the DDIM process as black boxes, and analyze what each argument passed in does
    • Still treating the model as a black box, knowing only that on each call it takes a noisy image x, a timestep t, and a condition c and outputs the estimated noise, focus on the DDIM sampling procedure
    • Study how the model is loaded from the config file, in order to understand the model structure and its computation
  • The previous post analyzed what each argument does; this post continues to treat the model as a black box and focuses on the DDIM sampling process

The DDIM Sampling Process

  • In gradio_seg.py, ddim_sampler.sample is called directly to draw samples, which are then decoded back to pixel space as the generated images:
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                             shape, cond, verbose=False, eta=eta,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=un_cond)

if config.save_memory:
    model.low_vram_shift(is_diffusing=False)

x_samples = model.decode_first_stage(samples)
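  • For context, decode_first_stage maps latents back to pixel space through the first-stage VAE decoder; stripped of edge cases it amounts to the following sketch (the standard LatentDiffusion implementation, which ControlLDM inherits):

# a sketch of LatentDiffusion.decode_first_stage (edge cases omitted)
def decode_first_stage(self, z):
    z = 1. / self.scale_factor * z           # undo the scaling applied when encoding into latent space
    return self.first_stage_model.decode(z)  # VAE decoder: (B, 4, H//8, W//8) -> (B, 3, H, W)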
  • A note on the arguments:

    • ddim_steps: the number of sampling steps
    • num_samples: the number of samples, i.e. the batch_size; note that it must match the number of conditions, one sample per condition
    • shape: the image size; here it is the size in latent space, (4, H // 8, W // 8)
    • cond: the guidance condition, containing the semantic-map control condition and the concatenation of prompt and add_prompt
    • unconditional_guidance_scale: the strength of classifier-free guidance
    • unconditional_conditioning: the condition for the unconditional branch, namely the negative_prompt and possibly the semantic-map control condition (depending on guess_mode); a sketch of how cond and un_cond are built follows below
  • The ddim_sampler class lives in the file cldm/ddim_hacked.py; the rest of this post walks through its execution flow

  • Note: wherever the analysis reaches the point where the model is applied, we treat the model as a black box; for now we do not care how the output is computed, only what goes in and what comes out
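  • For reference, a minimal sketch of how the script builds cond and un_cond (following the ControlNet Gradio demos; variable names match the previous post):

cond = {"c_concat": [control],  # semantic-map control condition
        "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + add_prompt] * num_samples)]}
un_cond = {"c_concat": None if guess_mode else [control],  # guess_mode drops the control from the unconditional branch
           "c_crossattn": [model.get_learned_conditioning([negative_prompt] * num_samples)]}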

The sample method

First we enter its sample method, which mainly validates the conditions and sets up the parameters used during sampling:

  • Check whether the number of conditions passed in (the size of dimension 0) equals the batch_size (num_samples)
if conditioning is not None:
    if isinstance(conditioning, dict):
        ctmp = conditioning[list(conditioning.keys())[0]]
        while isinstance(ctmp, list): ctmp = ctmp[0]
        cbs = ctmp.shape[0]
        if cbs != batch_size:
            print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

    elif isinstance(conditioning, list):
        for ctmp in conditioning:
            if ctmp.shape[0] != batch_size:
                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

    else:
        if conditioning.shape[0] != batch_size:
            print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
  • Set up some schedule information (attributes on the sampler instance) used to control the sampling process:

    • Based on the number of sampling steps passed in, select the corresponding subset of timestep values
    • Also set the remaining DDIM sampling parameters (the alphas, alphas_prev, and sigmas registered as buffers below); a sketch of the two helper functions follows the code
    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
    
    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
    
        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
    
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
    
        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
                                                              alphacums=alphas_cumprod.cpu(),
                                                              ddim_timesteps=self.ddim_timesteps,
                                                              eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
    
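  • The two helpers called above live in ldm/modules/diffusionmodules/util.py; a simplified sketch keeping only the default 'uniform' discretization:

import numpy as np

def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    # pick num_ddim_timesteps evenly spaced steps out of the num_ddpm_timesteps DDPM steps,
    # e.g. 1000 DDPM steps and 20 DDIM steps -> [1, 51, 101, ..., 951]
    c = num_ddpm_timesteps // num_ddim_timesteps
    ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    return ddim_timesteps + 1  # shift by one so the final alpha values line up during sampling

def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # cumulative alphas at the chosen steps, plus each step's predecessor
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
    # sigma formula from the DDIM paper (https://arxiv.org/abs/2010.02502):
    # eta = 0 gives deterministic sampling, eta = 1 recovers DDPM-like variance
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    return sigmas, alphas, alphas_prev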
  • Call self.ddim_sampling, paying attention to the following arguments:

    • conditioning
    • size: shape with the batch dimension prepended, i.e. CHW becomes BCHW
    • x_T: the starting noise image
    • unconditional_guidance_scale
    • unconditional_conditioning

The ddim_sampling method

  • First, if no x_T is passed in, sample pure random noise as the starting noisy image:
if x_T is None:
    img = torch.randn(shape, device=device)
else:
    img = x_T
  • Next, set the timesteps:
if timesteps is None:
    # self.ddim_timesteps was set in make_schedule based on the number of sampling steps passed in
    timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
    subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
    timesteps = self.ddim_timesteps[:subset_end]
  • Iterate over timesteps, calling self.p_sample_ddim once per step to denoise the whole batch (a condensed sketch of the loop follows this list); the key arguments to self.p_sample_ddim are:
    • img: the noisy image
    • cond
    • ts: the current timestep
    • unconditional_guidance_scale
    • unconditional_conditioning
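  • A condensed sketch of that loop (masks, callbacks, and intermediate logging left out; b, device, img, and cond come from the enclosing ddim_sampling method):

time_range = np.flip(timesteps)   # iterate from the noisiest step down to the cleanest
total_steps = timesteps.shape[0]
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

for i, step in enumerate(iterator):
    index = total_steps - i - 1   # index into the DDIM parameter buffers set by make_schedule
    ts = torch.full((b,), step, device=device, dtype=torch.long)  # same timestep for the whole batch
    img, pred_x0 = self.p_sample_ddim(img, cond, ts, index=index,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)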

The p_sample_ddim method

  • Apply the ControlLDM model to the noisy image x, the timestep t, and the condition c to estimate the noise
    • Note that the timestep t and the condition c here are raw, not yet embedded or encoded; those transformations happen inside the model
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    model_output = self.model.apply_model(x, t, c)
else:
    model_t = self.model.apply_model(x, t, c)
    model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
    model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
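  • The else branch is classifier-free guidance: with s = unconditional_guidance_scale, the final noise estimate extrapolates from the unconditional prediction towards the conditional one,

$$\tilde{\epsilon} = \epsilon_\theta(x_t, c_{uncond}) + s \cdot \big(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, c_{uncond})\big)$$

so s = 1 reduces to the purely conditional prediction, which is why a single conditional model call suffices in the first branch.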
  • Finally, the method uses the sampling parameters set up in make_schedule to take one DDIM step: it recovers a prediction of x_0 from the estimated noise, adds a "direction pointing to x_t" term, and, when eta > 0, injects fresh noise to obtain x_{t-1}
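  • A sketch of that update, assuming the model predicts noise (the eps parameterization; the v-prediction and score-corrector branches in the real code are omitted):

e_t = model_output  # predicted noise
b = x.shape[0]

# look up the schedule values for this step (index comes from the sampling loop)
a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=x.device)
a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=x.device)
sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=x.device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index], device=x.device)

# current prediction for x_0: invert x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
# "direction pointing to x_t"
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
# fresh noise scaled by sigma; zero when eta == 0, making DDIM fully deterministic
noise = sigma_t * torch.randn_like(x)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0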