ControlNet Code and Model Structure Explained (2): The DDIM Sampling Process
Author: zmy
Overview
This series analyzes ControlNet's sampling process top-down, taking text-guided image generation controlled by a semantic map as the running example, in order to understand its code and model structure. It is split into the following parts:
- First, look at the sampling script gradio_seg.py, treating both the model and the DDIM process as black boxes, and analyze what each parameter passed in does
- Still treating the model as a black box (all we need to know is that at each step it takes a noisy image x, a timestep t and conditioning c, and outputs an estimate of the noise), focus on the DDIM sampling procedure
- Study how the model is loaded from its config file, to understand the model structure and its computation
The previous post analyzed what each parameter does. This post keeps treating the model as a black box and focuses on the DDIM sampling procedure.
The DDIM sampling process
- In gradio_seg.py, ddim_sampler.sample is called directly to draw samples, which are then decoded back to pixel space as the generated images:
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                             shape, cond, verbose=False, eta=eta,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=un_cond)

if config.save_memory:
    model.low_vram_shift(is_diffusing=False)

x_samples = model.decode_first_stage(samples)
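decode_first_stage returns images in the [-1, 1] range; to display them, the demo script rescales them to 8-bit images, roughly as follows (a sketch following the post-processing in ControlNet's gradio demos):

import einops
import numpy as np

# map [-1, 1] BCHW tensors to uint8 HWC images
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]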
A few words on the parameters:
- ddim_steps: the number of sampling steps
- num_samples: the number of samples, i.e. the batch_size; note that it must equal the number of conditions, one sample per condition
- shape: the image size, here the size in latent space, namely (4, H // 8, W // 8)
- cond: the guidance conditions, containing the semantic-map control condition and the concatenation of prompt and add_prompt
- unconditional_guidance_scale: the strength of the unconditional (classifier-free) guidance
- unconditional_conditioning: the conditioning for the unconditional branch: the negative_prompt, and possibly the semantic-map control condition as well (depending on guess_mode)
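For reference, the two conditioning dicts are assembled in the demo script roughly as follows (a sketch following the ControlNet gradio demos, where control is the preprocessed semantic-map tensor and a_prompt / n_prompt play the roles of add_prompt / negative_prompt above):

cond = {"c_concat": [control],
        "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
# in guess_mode the unconditional branch receives no control signal
un_cond = {"c_concat": None if guess_mode else [control],
           "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}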
The ddim_sampler object (class DDIMSampler) is defined in cldm/ddim_hacked.py; the following walks through its execution flow.
Note: whenever we reach the point where the model is applied, we keep treating it as a black box; for now we do not care how it computes anything, only what its inputs and outputs are.
The sample method
We enter its sample method first. It mainly validates the conditioning and sets up the parameters used during sampling:
- It checks whether the number of conditions passed in (the size of dimension 0) equals batch_size (the num_samples from above)
if conditioning is not None:
    if isinstance(conditioning, dict):
        ctmp = conditioning[list(conditioning.keys())[0]]
        while isinstance(ctmp, list): ctmp = ctmp[0]
        cbs = ctmp.shape[0]
        if cbs != batch_size:
            print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
    elif isinstance(conditioning, list):
        for ctmp in conditioning:
            if ctmp.shape[0] != batch_size:
                # upstream bug: cbs is undefined in this branch (should be ctmp.shape[0])
                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
    else:
        if conditioning.shape[0] != batch_size:
            print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
It then sets up schedule information (stored as attributes of the sampler object) that controls the sampling process:
- Based on the number of sampling steps passed in, it picks the matching timestep values (self.ddim_timesteps)
- It also precomputes the remaining DDIM sampling parameters: the cumulative alphas with their square-root variants, and the per-step ddim_alphas, ddim_alphas_prev and ddim_sigmas used by the update rule
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
    self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                              num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
    alphas_cumprod = self.model.alphas_cumprod
    assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
    to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

    self.register_buffer('betas', to_torch(self.model.betas))
    self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
    self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
    self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
    self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
    self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
    self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

    # ddim sampling parameters
    ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                               ddim_timesteps=self.ddim_timesteps,
                                                                               eta=ddim_eta, verbose=verbose)
    self.register_buffer('ddim_sigmas', ddim_sigmas)
    self.register_buffer('ddim_alphas', ddim_alphas)
    self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
    self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
    sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
        (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
    self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
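To make the schedule concrete, here is a minimal runnable sketch of what the two helpers compute for the "uniform" discretization (simplified from the helpers in ldm/modules/diffusionmodules/util.py; error handling and the other discretization mode are omitted, and the function names here are my own shorthand):

import numpy as np

def make_ddim_timesteps_uniform(num_ddim_timesteps, num_ddpm_timesteps):
    # take every c-th DDPM step, then shift by one so the final alpha is correct
    c = num_ddpm_timesteps // num_ddim_timesteps
    return np.asarray(list(range(0, num_ddpm_timesteps, c))) + 1

def ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
    # cumulative alphas at the selected steps, and at the previous selected steps
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
    # DDIM paper eq. (16): eta == 0 makes sampling fully deterministic
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    return sigmas, alphas, alphas_prev

# e.g. 20 DDIM steps over 1000 DDPM steps -> [1, 51, 101, ..., 951]
print(make_ddim_timesteps_uniform(20, 1000))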
It then calls self.ddim_sampling; pay attention to the following parameters:
- conditioning
- size: shape with the batch_size dimension prepended, going from CHW to BCHW
- x_T: the initial noisy image
- unconditional_guidance_scale
- unconditional_conditioning
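Inside sample, size is assembled from shape roughly like this (paraphrasing the repo code):

C, H, W = shape
size = (batch_size, C, H, W)  # e.g. (num_samples, 4, H // 8, W // 8)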
The ddim_sampling method
- If no x_T is passed in, it first samples pure random noise as the initial noisy image:
if x_T is None:
    img = torch.randn(shape, device=device)
else:
    img = x_T
- Next it sets up the timesteps:
if timesteps is None:
    # self.ddim_timesteps was set in make_schedule from the requested number of sampling steps
    timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
    subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
    timesteps = self.ddim_timesteps[:subset_end]
- It then iterates over timesteps from the largest down to the smallest, calling self.p_sample_ddim once per step to denoise the whole batch (see the sketch after this list); the key arguments to self.p_sample_ddim are:
  - img: the current noisy image
  - cond
  - ts: the current timestep
  - unconditional_guidance_scale
  - unconditional_conditioning
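For orientation, the loop looks roughly like this (a condensed sketch of the body of ddim_sampling; progress bar, callbacks and mask handling are omitted, and b is the batch size):

time_range = np.flip(timesteps)          # denoise from the largest timestep down
total_steps = timesteps.shape[0]
for i, step in enumerate(time_range):
    index = total_steps - i - 1          # index into the ddim_* schedule buffers
    ts = torch.full((b,), step, device=device, dtype=torch.long)
    img, pred_x0 = self.p_sample_ddim(img, cond, ts, index=index,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)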
The p_sample_ddim method
- It applies the ControlLDM model to the noisy image x, timestep t and conditioning c to estimate the noise; under classifier-free guidance, the conditional and unconditional estimates are combined using unconditional_guidance_scale
- Note that the timestep t and conditioning c here have not yet been embedded or encoded; those transformations all happen inside the model
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    model_output = self.model.apply_model(x, t, c)
else:
    model_t = self.model.apply_model(x, t, c)
    model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
    # classifier-free guidance: push the estimate away from the unconditional one
    model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
- Using the schedule parameters set up earlier, it then performs the actual DDIM update: it recovers a prediction of x_0 from the noise estimate, adds a "direction pointing to x_t" term, and, when eta > 0, injects fresh noise scaled by sigma_t, producing the less noisy latent for the next step
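A condensed sketch of that update, following the standard DDIM step as implemented in p_sample_ddim (the v-prediction, quantization and score-corrector branches are omitted; alphas, alphas_prev, sigmas and sqrt_one_minus_alphas are the schedule buffers registered by make_schedule):

a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

e_t = model_output                                    # predicted noise
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()  # current estimate of the clean latent
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t    # direction pointing to x_t
noise = sigma_t * torch.randn_like(x)                 # zero when eta == 0 (deterministic DDIM)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise     # the denoised latent x_{t-1}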