本文最后更新于 2025年9月25日 凌晨
                  
                
              
            
            
              
                
                
2025/08/04 - 至今
待完善
主要参考资料
本篇文章主要基于《Elucidating the Design Space of Diffusion-Based
Generative Models》以及以下几个视频、博文。
数学推导参考:B站讲解 Double童发发 - EDM论文讲解之扩散模型通用框架系列、苏剑林-科学空间-一般框架之SDE篇
整体脉络、相关图片参考:B站讲座 -【双语】NVIDIA - EDM 以及 论文相关博文
本文按照上述讲座顺序,分为四个部分:
- 第一部分:通用框架
- 第二部分:确定性采样
- 第三部分:随机采样
- 第四部分:预处理与训练
概率流ODE
边缘分布的变形
已有边缘分布
pt(x)=∫Rdp0t(x∣x0)pdata(x0)dx0
转移核
p0t(x(t)∣x(0))=N(x(t); s(t)x(0), s(t)2σ(t)2I)
<span style="color:red;">目标:将边缘分布转化为 s(t) 和 σ(t) 的形式:
pt(x)=∫Rdp0t(x∣x0)pdata(x0)dx0=∫Rdpdata(x0)N(x; s(t)x0, s(t)2σ(t)2I)dx0=∫Rdpdata(x0)[s(t)−dN(x/s(t); x0, σ(t)2I)]dx0=s(t)−d∫Rdpdata(x0)N(x/s(t); x0, σ(t)2I)dx0=s(t)−d∫Rdpdata(x0)N(s(t)x−x0; 0, σ(t)2I)dx0=s(t)−d[pdata∗N(0, σ(t)2I)](s(t)x)
其中 ∗ 表示卷积,添加 s(t)−d 是为了进行归一化。方括号内的表达式对应对 pdata 进行“平滑化”的结果(通过向样本添加独立同分布的高斯噪声实现)。我们将该分布记为 p(x;σ):
p(x;σ)=pdata∗N(0, σ2I) ⇒ pt(x)=s(t)−dp(s(t)x; σ(t))
现在,我们可以用 p(x;σ) 代替 pt(x),重新表达 概率流ODE:
dx=[f(t)x−21g(t)2∇xlog[pt(x)]]dt=[f(t)x−21g(t)2∇xlog[s(t)−dp(x/s(t); σ(t))]]dt=[f(t)x−21g(t)2(∇xlogs(t)−d+∇xlogp(x/s(t); σ(t)))]dt=[f(t)x−21g(t)2∇xlogp(s(t)x; σ(t))]dt
f(t) 与 g(t) 的表示法
已有:
s(t)=exp(∫0tf(ξ)dξ),σ(t)=∫0ts2(ξ)g2(ξ)dξ
<span style="color:red;">目标:利用上面两式,将 f(t) 和 g(t) 变形为 s(t) 和 σ(t) 的形式:
exp(∫0tf(ξ)dξ)=s(t)⇒∫0tf(ξ)dξ=logs(t)⇒dtd[∫0tf(ξ)dξ]=dtd[logs(t)]⇒f(t)=s(t)s˙(t)
∫0ts(ξ)2g(ξ)2dξ=σ(t)⇒∫0ts(ξ)2g(ξ)2dξ=σ(t)2⇒dtd[∫0ts(ξ)2g(ξ)2dξ]=dtd[σ(t)2]⇒s(t)2g(t)2=2σ˙(t)σ(t)⇒s(t)g(t)=2σ˙(t)σ(t)⇒g(t)=s(t)2σ˙(t)σ(t)
最后,将 f 和 g 代入 概率流ODE 中:
dx=[f(t)x−21g(t)2∇xlogp(s(t)x;σ(t))]dt=[s(t)s˙(t)x−21(s(t)2σ˙(t)σ(t))2∇xlogp(s(t)x;σ(t))]dt=[s(t)s˙(t)x−21(2s(t)2σ˙(t)σ(t))∇xlogp(s(t)x;σ(t))]dt=[s(t)s˙(t)x−s(t)2σ˙(t)σ(t)∇xlogp(s(t)x;σ(t))]dt
由此,我们得到了(上)中的公式4;令 s(t)=1 时,即可还原出(上)地公式1:
dx=−σ˙(t)σ(t)∇xlogp(x;σ(t))dt
确定性采样
去噪网络 Dθ 的推导
为了完整起见,我们推导有限数据集下分数匹配与去噪之间的联系。
假设训练集由有限个样本 {y1,…,yY} 组成,则数据分布 pdata(x) 可表示为狄拉克 delta 分布的混合:
pdata(x)=Y1i=1∑Yδ(x−yi)(40)
基于公式20,我们可推导 $ p(\boldsymbol{x}; \sigma) $ 的闭式表达式:
首先,$ p(\boldsymbol{x}; \sigma) $ 是 $ p_{\text{data}} $ 与高斯分布 $ \mathcal{N}(0, \sigma(t)^2 \mathbf{I}) $ 的卷积:
p(x;σ)=pdata∗N(0,σ(t)2I)(41)
展开卷积的积分形式:
p(x;σ)=∫Rdpdata(x0)N(x;x0,σ2I)dx0(42)
代入 $ p_{\text{data}} $ 的表达式(公式40):
p(x;σ)=∫Rd[Y1i=1∑Yδ(x0−yi)]N(x;x0,σ2I)dx0(43)
利用积分与求和的交换性(线性性):
p(x;σ)=Y1i=1∑Y∫RdN(x;x0,σ2I)δ(x0−yi)dx0(44)
根据狄拉克 delta 函数的筛选性质,积分后得到:
p(x;σ)=Y1i=1∑YN(x;yi,σ2I)(45)
接下来,考虑公式2的去噪分数匹配损失。通过展开期望,我们可将其重写为对含噪样本 $ \boldsymbol{x} $ 的积分:
损失函数定义为:
L(D;σ)=Ey∼pdataEn∼N(0,σ2I)∥D(y+n;σ)−y∥22(46)
令 $ \boldsymbol{x} = \boldsymbol{y} + \boldsymbol{n} $(则 $ \boldsymbol{x} \sim \mathcal{N}(\boldsymbol{y}, \sigma^2 \mathbf{I}) $),上式可改写为:
L(D;σ)=Ey∼pdataEx∼N(y,σ2I)∥D(x;σ)−y∥22(47)
将外层期望展开为积分(利用 $ p_{\text{data}} $ 的离散混合形式):
L(D;σ)=Ey∼pdata∫RdN(x;y,σ2I)∥D(x;σ)−y∥22dx(48)
代入 $ p_{\text{data}} $ 的表达式(公式40),将期望转换为求和:
L(D;σ)=Y1i=1∑Y∫RdN(x;yi,σ2I)∥D(x;σ)−yi∥22dx(49)
再次交换求和与积分(有限和的交换性):
L(D;σ)=∫Rd=:L(D;x,σ)Y1i=1∑YN(x;yi,σ2I)∥D(x;σ)−yi∥22dx(50)
公式50表明:我们可通过独立最小化每个 $ \boldsymbol{x} $ 对应的 $ \mathcal{L}(D; \boldsymbol{x}, \sigma) $ 来最小化 $ \mathcal{L}(D; \sigma) $,即:
D(x;σ)=argD(x;σ)minL(D;x,σ)(51)
这是一个凸优化问题,其解可通过令 $ D(\boldsymbol{x}; \sigma) $ 的梯度为零唯一确定:
0=∇D(x;σ)[L(D;x,σ)](52)
代入 $ \mathcal{L}(D; \boldsymbol{x}, \sigma) $ 的表达式(公式50的被积函数):
0=∇D(x;σ)[Y1i=1∑YN(x;yi,σ2I)∥D(x;σ)−yi∥22](53)
提取常数和求和符号,利用梯度的线性性:
0=i=1∑YN(x;yi,σ2I)∇D(x;σ)[∥D(x;σ)−yi∥22](54)
对范数平方求梯度($ \nabla_a |a - b|_2^2 = 2(a - b) $):
0=i=1∑YN(x;yi,σ2I)[2D(x;σ)−2yi](55)
约去公因子2,整理项:
0=[i=1∑YN(x;yi,σ2I)]D(x;σ)−i=1∑YN(x;yi,σ2I)yi(56)
解出 $ D(\boldsymbol{x}; \sigma)$:
D(x;σ)=∑i=1YN(x;yi,σ2I)∑i=1YN(x;yi,σ2I)yi(57)
上式给出了理想去噪器 $ D(\boldsymbol{x}; \sigma) $ 的闭式解。需注意:公式57对小数据集可实际计算——我们在图1b中展示了CIFAR-10的结果。
接下来,考虑公式45定义的分布 ( p(\boldsymbol{x}; \sigma) ) 的分数(score):
∇xlogp(x;σ)=p(x;σ)∇xp(x;σ)(58)
代入公式45的 ( p(\boldsymbol{x}; \sigma) = \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) ),得:
∇xlogp(x;σ)=Y1∑i=1YN(x;yi,σ2I)∇x[Y1∑i=1YN(x;yi,σ2I)](59)
利用梯度的线性性(求和与梯度交换),分子分母的 ( 1/Y ) 约去,得到:
∇xlogp(x;σ)=∑i=1YN(x;yi,σ2I)∑i=1Y∇xN(x;yi,σ2I)(60)
我们可进一步简化公式60的分子:
高斯分布 ( \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) ) 的表达式为:
N(x;yi,σ2I)=(2πσ2)−d/2exp(2σ2−∥x−yi∥22)(61)
对 ( \boldsymbol{x} ) 求梯度:
∇xN(x;yi,σ2I)=∇x[(2πσ2)−d/2exp(2σ2−∥x−yi∥22)](62)
由于 ( \left( 2\pi\sigma^2 \right)^{-d/2} ) 与 ( \boldsymbol{x} ) 无关,可提出梯度外:
∇xN(x;yi,σ2I)=(2πσ2)−d/2∇xexp(2σ2−∥x−yi∥22)(63)
利用链式法则,指数函数的梯度等于自身乘以指数部分的梯度:
∇xN(x;yi,σ2I)=(2πσ2)−d/2exp(2σ2−∥x−yi∥22)⋅∇x(2σ2−∥x−yi∥22)(64)
注意到前两项正是 ( \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}i, \sigma^2 \mathbf{I}) ),而范数平方的梯度为 ( \nabla{\boldsymbol{x}} |\boldsymbol{x} - \boldsymbol{y}_i|_2^2 = 2(\boldsymbol{x} - \boldsymbol{y}_i) ),因此:
∇xN(x;yi,σ2I)=N(x;yi,σ2I)⋅σ2yi−x(65)
将上述结果代回公式60:
∇xlogp(x;σ)=∑i=1YN(x;yi,σ2I)∑i=1Y∇xN(x;yi,σ2I)(66)
代入公式65的分子:
∇xlogp(x;σ)=∑i=1YN(x;yi,σ2I)∑i=1YN(x;yi,σ2I)⋅σ2yi−x(67)
提取公因子 ( 1/\sigma^2 ) 并整理分子:
∇xlogp(x;σ)=σ21⋅∑i=1YN(x;yi,σ2I)∑i=1YN(x;yi,σ2I)yi−x∑i=1YN(x;yi,σ2I)(68)
观察公式68的分子分数部分,发现其与公式57的 ( D(\boldsymbol{x}; \sigma) ) 完全一致(( D(\boldsymbol{x}; \sigma) = \frac{\sum_i \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \boldsymbol{y}_i}{\sum_i \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})} ))。因此,公式68可等价写为:
∇xlogp(x;σ)=σ2D(x;σ)−x(69)
这与主论文中的公式3一致。
缩放ODE公式
假设 x 是原始无量纲变量 x^ 的缩放版本,将 x=s(t)x^ 代入缩放ODE(公式4)中的分数项:
∇xlogp(s(t)x;σ(t))=∇[s(t)x^]logp(s(t)s(t)x^;σ(t))=∇s(t)x^logp(x^;σ(t))=s(t)1∇x^logp(x^;σ(t))
利用公式3,我们可进一步将其用 D(⋅) 重写:
∇xlogp(s(t)x;σ(t))=s(t)σ(t)21(D(x^;σ(t))−x^)
接下来,代入公式4,用训练好的模型 Dθ(⋅) 近似理想去噪器 D(⋅):
dx=[s(t)s˙(t)x−s(t)2σ˙(t)σ(t)(s(t)σ(t)21(Dθ(x^;σ(t))−x^))]dt=[s(t)s˙(t)x−σ(t)σ˙(t)s(t)(Dθ(x^;σ(t))−x^)]dt
最后,回代 x^=x/s(t):
dx=[s(t)s˙(t)x−σ(t)σ˙(t)s(t)(Dθ([x^];σ(t))−[x^])]dt=[s(t)s˙(t)x−σ(t)σ˙(t)s(t)(Dθ(s(t)x;σ(t))−s(t)x)]dt=[s(t)s˙(t)x−σ(t)σ˙(t)s(t)Dθ(s(t)x;σ(t))+σ(t)σ˙(t)x]dt=[(σ(t)σ˙(t)+s(t)s˙(t))x−σ(t)σ˙(t)s(t)Dθ(s(t)x;σ(t))]dt
我们可将上式等价写为:
dtdx=(σ(t)σ˙(t)+s(t)s˙(t))x−σ(t)σ˙(t)s(t)Dθ(s(t)x;σ(t))
这与算法1的第4行和第7行一致。
随机采样
B.5 我们的SDE公式(公式6)
我们通过以下策略推导公式6的SDE:
- 期望的边缘密度 ( p(\boldsymbol{x}; \sigma(t)) ) 是数据密度 ( p_{\text{data}} ) 与标准差为 ( \sigma(t) ) 的各向同性高斯密度的卷积(见公式20)。因此,作为时间 ( t ) 的函数,该密度遵循具有时变扩散率的热扩散PDE演化。第一步,我们先找到这个PDE。
- 随后,我们使用Fokker-Planck方程推导出一族SDE,其密度演化符合该PDE。公式6通过对这一族SDE进行合理参数化得到。
B.5.1 通过热扩散生成边缘分布
我们考虑概率密度 ( q(\boldsymbol{x}, t) ) 的时间演化。目标是找到一个PDE,其初始值 ( q(\boldsymbol{x}, 0) := p_{\text{data}}(\boldsymbol{x}) ) 的解为 ( q(\boldsymbol{x}, t) = p(\boldsymbol{x}; \sigma(t)) )。即,该PDE应复现公式20中的边缘分布。
期望的边缘分布是 ( p_{\text{data}} ) 与标准差随时间变化的各向同性正态分布的卷积,因此可由具有时变扩散率 ( \kappa(t) ) 的热方程生成。在傅里叶域分析该情况最方便,此时边缘密度是高斯函数与变换后的数据密度的逐点乘积。为找到诱导正确标准差的扩散率,我们先写出热方程PDE:
∂t∂q(x,t)=κ(t)Δxq(x,t).(82)
公式82的傅里叶变换对应式(变换沿 ( \boldsymbol{x} ) 维度进行)为:
∂t∂q^(ν,t)=−κ(t)∥ν∥2q^(ν,t).(83)
目标解 ( q(\boldsymbol{x}, t) ) 及其傅里叶变换 ( \hat{q}(\boldsymbol{\nu}, t) ) 由公式20给出:
q(x,t)=p(x;σ(t))=pdata(x)∗N(0,σ(t)2I)(84)
q^(ν,t)=p^data(ν)exp(−21∥ν∥2σ(t)2).(85)
沿时间轴对目标解求导,我们有:
∂t∂q^(ν,t)=−σ˙(t)σ(t)∥ν∥2p^data(ν)exp(−21∥ν∥2σ(t)2)(86)
∂t∂q^(ν,t)=−σ˙(t)σ(t)∥ν∥2q^(ν,t).(87)
公式83和公式87的左侧相同。令二者相等,可解出产生期望演化的 ( \kappa(t) ):
−κ(t)∥ν∥2q^(ν,t)=−σ˙(t)σ(t)∥ν∥2q^(ν,t)(88)
κ(t)=σ˙(t)σ(t).(89)
综上,对应噪声水平 ( \sigma(t) ) 的期望边缘密度由以下PDE生成:
∂t∂q(x,t)=σ˙(t)σ(t)Δxq(x,t)(90)
其初始密度为 ( q(\boldsymbol{x}, 0) = p_{\text{data}}(\boldsymbol{x}) )。
B.5.2 我们的SDE推导
考虑如下SDE:
dx=f(x,t)dt+g(x,t)dωt,(91)
Fokker-Planck PDE 描述了其解的概率密度 ( r(\boldsymbol{x}, t) ) 的时间演化:
∂t∂r(x,t)=−∇x⋅(f(x,t)r(x,t))+21∇x∇x:(D(x,t)r(x,t)),(92)
其中 ( \mathbf{D}{ij} = \sum_k g{ik} g_{jk} ) 是扩散张量。我们考虑 ( \boldsymbol{g}(\boldsymbol{x}, t) = g(t),\mathbf{I} ) 的特殊情况(与 ( \boldsymbol{x} ) 无关的白噪声添加),此时方程简化为:
∂t∂r(x,t)=−∇x⋅(f(x,t)r(x,t))+21g(t)2Δxr(x,t).(93)
我们寻求一个SDE,其解的密度由公式90的PDE描述。令 ( r(\boldsymbol{x}, t) = q(\boldsymbol{x}, t) ),并令公式93和公式90相等,可得SDE必须满足的充分条件:
−∇x⋅(f(x,t)q(x,t))+21g(t)2Δxq(x,t)=σ˙(t)σ(t)Δxq(x,t)(94)
∇x⋅(f(x,t)q(x,t))=(21g(t)2−σ˙(t)σ(t))Δxq(x,t).(95)
任何满足该方程的函数 ( \boldsymbol{f}(\boldsymbol{x}, t) ) 和 ( g(t) ) 都构成所求的SDE。现在我们寻找这类解的一个具体族。关键思路源于恒等式 ( \nabla_{\boldsymbol{x}} \cdot \nabla_{\boldsymbol{x}} = \Delta_{\boldsymbol{x}} )。实际上,若令 ( \boldsymbol{f}(\boldsymbol{x}, t),q(\boldsymbol{x}, t) = v(t),\nabla_{\boldsymbol{x}} q(\boldsymbol{x}, t) )(对任意 ( v(t) ) 成立),则 ( \Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) ) 会出现在两边并抵消:
∇x⋅(v(t)∇xq(x,t))=(21g(t)2−σ˙(t)σ(t))Δxq(x,t)(96)
v(t)Δxq(x,t)=(21g(t)2−σ˙(t)σ(t))Δxq(x,t)(97)
v(t)=21g(t)2−σ˙(t)σ(t).(98)
上述 ( \boldsymbol{f}(\boldsymbol{x}, t) ) 实际上与分数函数成比例,因为该公式匹配了密度对数的梯度:
f(x,t)=v(t)q(x,t)∇xq(x,t)(99)
f(x,t)=v(t)∇xlogq(x,t)(100)
f(x,t)=(21g(t)2−σ˙(t)σ(t))∇xlogq(x,t).(101)