type
Post
status
Published
date
Mar 17, 2025
slug
ISP/ControlNet
summary
tags
study
tech
develop
category
Academy
icon
password
Modules
UNet
1️⃣ UNetModel 代码结构解析(in
ldm/modules/openaimodel.py)UNetModel 是一个标准的 U-Net 结构,包含:
- 编码器(input_blocks)
- 瓶颈层(middle_block)
- 解码器(output_blocks)
- 最终输出层(self.out)
📌 主要流程
1. input_blocks(编码器):
通过多个 ResBlock(残差块) 和 Downsample 逐层提取特征
attention_resolutions 控制是否加入注意力机制
特征 h 被存入 hs,用于后续跳跃连接
2. middle_block(瓶颈层):
通过 ResBlock + AttentionBlock 进一步处理特征
3. output_blocks(解码器):
通过 Upsample 层恢复分辨率
核心:与 hs 进行跳跃连接(skip connection)
如果有注意力机制,加入 AttentionBlock
4. self.out(最终输出层)
归一化 + SiLU 激活
通过 conv_nd 生成最终输出
2️⃣ ControlledUnetModel 代码解析
ControlledUnetModel 继承 UNetModel,并 增加了 control 变量,用于外部引导 U-Net 处理特定任务。
📌 主要改进点
1. 编码器部分(input_blocks)
- 加速优化:使用 torch.no_grad(),避免计算梯度,提高推理速度
- hs.append(h):存储编码器输出,用于跳跃连接
2. 中间层(middle_block)
- control 的第一次作用:如果提供了 control,则 在 middle_block 之后 进行相加 h += control.pop()
- 这里 control.pop() 代表 取出控制变量的最后一层
3. 解码器部分(output_blocks)
- 加入 control 的方式
- 如果 only_mid_control=True,则 control 仅在 middle_block 作用,后续解码过程正常否则,在解码过程中不断加入 control.pop()
h = torch.cat([h, hs.pop() + control.pop()], dim=1)- h 是当前特征图
- hs.pop() 是来自编码器的跳跃连接
- control.pop() 是控制信号(用于影响解码过程)
ControlNet
ControlNet 解析
ControlNet 是一种改进的 U-Net 结构,它通过引入额外的“控制”信息(如 hint 输入)来 引导 神经网络的生成过程,使其能够更加精准地遵循某些约束或先验信息。在扩散模型(Diffusion Models)或图像生成任务中,ControlNet 可以用于特定任务,如边缘检测、深度信息控制、姿态引导等。
🔹 ControlNet 相较于标准 U-Net 的主要改进
1️⃣ 额外的 hint 信息输入
- 传统的 U-Net 主要依赖输入 x 进行图像生成,而 ControlNet 额外引入 hint 作为辅助信息,提供某种先验引导。
- hint 通过 input_hint_block 进行处理,这个模块是一个深度卷积网络(CNN),它将 hint 逐步下采样并映射到 model_channels 维度,使其与主 U-Net 结构兼容。
2️⃣ zero_convs 额外控制分支
- zero_convs 是一系列 零初始化卷积层,用于让 ControlNet 直接在各个层级学习额外的偏差信息。
- 这意味着 ControlNet 不是直接干预主 U-Net 的计算,而是以 残差(Residual) 方式调整输出,这样可以在保持预训练 U-Net 结构的同时,让 ControlNet 提供新的信息。

3️⃣ 编码器 input_blocks 中引入 guided_hint
- 在前向传播过程中,ControlNet 逐层融合 hint 信息:
1. hint 经过 input_hint_block 处理后,变成 guided_hint
2. 在编码阶段的 第一层 input_blocks,guided_hint 直接加到 h 上,使得 ControlNet 受 hint 影响
3. 后续层不会重复加 guided_hint
🔹 代码解析
1️⃣ __init__ 构造函数
💡 新增的 hint 处理模块
- 这个 input_hint_block 负责 处理 hint 数据,它是一个深度卷积网络:
- 从 hint_channels 开始,逐步增加通道数,并进行三次 下采样 (stride=2),最终与 U-Net 主网络的 model_channels 维度匹配。
- 这样可以确保 hint 信息可以直接用于调整 U-Net 的隐藏层特征。
2️⃣ forward() 计算流程
- 这个 forward() 接受 4 个主要输入:
1. x - 原始图像
2. hint - 额外的控制信号
3. timesteps - 扩散模型的时间步
4. context - 额外的上下文信息(如文本)
💡 计算时间嵌入
- 时间步 timesteps 先经过 timestep_embedding() 变换,再经过 self.time_embed 提取时间特征。
💡 处理 hint 额外信息
- hint 经过 input_hint_block 处理,生成 guided_hint,它将在 ControlNet 结构中 融合到 U-Net 的编码过程中。
💡 编码时融合 guided_hint
- 逐层遍历 input_blocks 进行编码:
- 其他层正常执行 h = module(h, emb, context)。
3️⃣ middle_block 处理
- 在 瓶颈层 middle_block 继续处理,并存入 outs 作为最终控制信号。
🔹 ControlNet 主要作用
- 与 UNet 主要区别:
🔹 ControlNet VS ControlledUnetModel
DDIMSampler
DDIMSampler 主要用于 扩散模型(Diffusion Model) 的高效采样,采用 DDIM (Denoising Diffusion Implicit Models) 进行降噪,相比于传统 DDPM (Denoising Diffusion Probabilistic Model) 具有更快的采样速度,同时可以调整 eta 来控制生成样本的多样性。
🔹 DDIMSampler 主要结构
1️⃣ 初始化
- 绑定 扩散模型 model,并读取 总时间步长 ddpm_num_timesteps。
- schedule="linear" 代表时间步(timestep)的调度方式。
2️⃣ 预计算采样公式