本文最后更新于 2025年8月7日 晚上
2025/08/04 - 至今
待完善
主要参考资料
本篇文章主要基于《Elucidating the Design Space of Diffusion-Based
Generative Models》以及以下几个视频、博文。
数学推导参考:B站讲解 Double童发发 - EDM论文讲解之扩散模型通用框架系列、苏剑林-科学空间-一般框架之SDE篇
整体脉络、相关图片参考:B站讲座 -【双语】NVIDIA - EDM 以及 论文相关博文
本文按照上述讲座顺序,分为四个部分:
- 第一部分:通用框架
- 第二部分:确定性采样
- 第三部分:随机采样
- 第四部分:预处理与训练
通用框架
通过连续时间随机过程将数据扰动至纯噪声
SDE视角下,扩散模型出现的两个问题
- 网络无法逼近数据的 真实得分:这导致采样轨迹朝着错误方向演进。上图当前轨迹(黑色)为错误采样链,在左侧与目标之间存在error;
- 通过有限步长逼近 理想轨迹:为了近似绿色轨迹,如果每步太长,则会显著偏离目标轨迹;如果每步太短,则会消耗过多计算资源。
事实上,以上两点分别对应于 Traning过程 和 Sampling过程。
轨迹
两种选择:
- 在目标附近采用更短的步长;
- 扭曲噪声调度,以在目标附近花费更多时间。
信号尺度差异过大会导致神经网络难以训练
确定性采样
不改变训练过程,只关注于采样过程。
改进Heun步
dx±=probability flow ODE−σ˙(t)σ(t)∇xlogp(x;σ(t))dt±Langevin diffusion SDE(deterministic noise decay + noise injection)β(t)σ(t)2∇xlogp(x;σ(t))dt+2β(t)σ(t)dωt
纠偏过程
随机采样
预处理与训练
回顾两个误差来源:
Preconditioning
-
为了让CNN更易处理:
- (A) 始终向网络输入 单位标准差的输入(unit stdev inputs)
- (B) ……并使用 单位标准差的目标(unit stdev targets)进行训练
-
网络会产生误差。我们应当:
-
我们的噪声水平变化极大,因此这一点至关重要!
改进Heun步
数学推导
Song等人将其前向SDE定义为:
dx=f(x,t)dt+g(t)dωt
其中
- ωt 是标准维纳过程(Wiener process);
- f(⋅,t):Rd→Rd 为漂移系数;
- g(⋅):R→R 为扩散系数
- d 为数据集维度。
在 方差保持(VP) 和 方差爆炸(VE) 公式中,这两个系数的选择不同;且 f(⋅,t) 始终具有 f(x,t)=f(t)x 的形式(其中 f(⋅):R→R)。因此,该SDE可等价改写为:
dx=f(t)xdt+g(t)dωt
该SDE的扰动核具有如下一般形式:
p0t(x(t)∣x(0))=N(x(t); s(t)x(0), s2(t)σ2(t)I)
其中 N(x; μ, Σ) 表示均值为 μ、协方差为 Σ 的高斯分布在 x 处的概率密度函数,且:
s(t)=exp(∫0tf(ξ)dξ),σ(t)=∫0ts2(ξ)g2(ξ)dξ
边缘分布 pt(x) 可通过对初始状态 x(0) 积分扰动核得到:
pt(x)=∫Rdp0t(x∣x0)pdata(x0)dx0
Song等人[[49]]定义了 概率流ODE,使其服从上述相同的边缘分布 pt(x):
dx=[f(t)x−21g(t)2∇xlogpt(x)]dt
原始ODE公式(式14)围绕函数 f 和 g 构建,这两个函数直接对应公式中的特定项;而边缘分布的性质(式12)只能通过这些函数间接推导。然而,f 和 g 本身的实际意义有限,相比之下,边缘分布在 模型训练、采样过程初始化,以及理解ODE实际行为等方面都至关重要。
由于概率流ODE的核心思想是匹配一组特定的边缘分布,将边缘分布视为“一等公民”,并直接基于 σ(t) 和 s(t) 定义ODE(无需依赖 f(t) 和 g(t)),是更合理的做法。
我们从式13的边缘分布闭式表达入手推导:
pt(x)=∫Rdp0t(x∣x0)pdata(x0)dx0=∫Rdpdata(x0)N(x; s(t)x0, s(t)2σ(t)2I)dx0=∫Rdpdata(x0)[s(t)−dN(x/s(t); x0, σ(t)2I)]dx0=s(t)−d∫Rdpdata(x0)N(x/s(t); x0, σ(t)2I)dx0=s(t)−d[pdata∗N(0, σ(t)2I)](x/s(t))
其中 pa∗pb 表示概率密度函数 pa 和 pb 的卷积。方括号内的表达式对应对 pdata 进行“平滑化”的结果(通过向样本添加独立同分布的高斯噪声实现)。我们将该分布记为 p(x;σ):
p(x;σ)=pdata∗N(0, σ2I),pt(x)=s(t)−dp(x/s(t); σ(t))
现在,我们可以用 p(x;σ) 代替 pt(x),重新表达概率流ODE(式14):
dx=[f(t)x−21g(t)2∇xlog[pt(x)]]dt=[f(t)x−21g(t)2∇xlog[s(t)−dp(x/s(t); σ(t))]]dt=[f(t)x−21g(t)2(∇xlogs(t)−d+∇xlogp(x/s(t); σ(t)))]dt=[f(t)x−21g(t)2∇xlogp(x/s(t); σ(t))]dt