EDM:扩散模型的设计空间(上)

本文最后更新于 2025年9月25日 凌晨

2025/08/04 - 至今
待完善

主要参考资料

本篇文章主要基于《Elucidating the Design Space of Diffusion-Based
Generative Models》以及以下几个视频、博文。

数学推导参考:B站讲解 Double童发发 - EDM论文讲解之扩散模型通用框架系列、苏剑林-科学空间-一般框架之SDE篇

整体脉络、相关图片参考:B站讲座 -【双语】NVIDIA - EDM 以及 论文相关博文

本文按照上述讲座顺序,分为四个部分:

  • 第一部分:通用框架
    • 识别现有研究中的可变部分
  • 第二部分:确定性采样
    • 高效求解 ODE
  • 第三部分:随机采样
    • 为什么用 SDE?如何进行随机步长计算?
  • 第四部分:预处理与训练
    • 如何训练用于评估单步的卷积神经网络(CNN)?

概率流ODE

边缘分布的变形

已有边缘分布

pt(x)=Rdp0t(xx0)pdata(x0)dx0p_t(\boldsymbol{x}) = \int_{\mathbb{R}^d} p_{0t}(\boldsymbol{x} \mid \boldsymbol{x}_0)\, p_{\text{data}}(\boldsymbol{x}_0)\,\mathrm{d}\boldsymbol{x}_0

转移核

p0t(x(t)x(0))=N(x(t); s(t)x(0), s(t)2σ(t)2I)p_{0t}(\boldsymbol{x}(t) \mid \boldsymbol{x}(0)) = \mathcal{N}\big(\boldsymbol{x}(t);\ s(t)\boldsymbol{x}(0),\ s(t)^2\sigma(t)^2\boldsymbol{I}\big)

<span style="color:red;">目标:将边缘分布转化为 s(t)s(t)σ(t)\sigma(t) 的形式:

pt(x)=Rdp0t(xx0)pdata(x0)dx0=Rdpdata(x0)N(x; s(t)x0, s(t)2σ(t)2I)dx0=Rdpdata(x0)[s(t)dN(x/s(t); x0, σ(t)2I)]dx0=s(t)dRdpdata(x0)N(x/s(t); x0, σ(t)2I)dx0=s(t)dRdpdata(x0)N(xs(t)x0; 0, σ(t)2I)dx0=s(t)d[pdataN(0, σ(t)2I)](xs(t))\begin{align*} p_t(\boldsymbol{x}) &= \int_{\mathbb{R}^d} p_{0t}(\boldsymbol{x} \mid \boldsymbol{x}_0)\, p_{\text{data}}(\boldsymbol{x}_0)\,\mathrm{d}\boldsymbol{x}_0 \\ &= \int_{\mathbb{R}^d} p_{\text{data}}(\boldsymbol{x}_0) \, \mathcal{N}\big(\boldsymbol{x};\ s(t)\boldsymbol{x}_0,\ s(t)^2\sigma(t)^2\boldsymbol{I}\big) \,\mathrm{d}\boldsymbol{x}_0 \\ &= \int_{\mathbb{R}^d} p_{\text{data}}(\boldsymbol{x}_0) \, \left[ s(t)^{-d} \mathcal{N}\big(\boldsymbol{x}/s(t);\ \boldsymbol{x}_0,\ \sigma(t)^2\boldsymbol{I}\big) \right] \,\mathrm{d}\boldsymbol{x}_0 \\ &= s(t)^{-d} \int_{\mathbb{R}^d} p_{\text{data}}(\boldsymbol{x}_0) \, \mathcal{N}\big(\boldsymbol{x}/s(t);\ \boldsymbol{x}_0,\ \sigma(t)^2\boldsymbol{I}\big) \,\mathrm{d}\boldsymbol{x}_0 \\ &= s(t)^{-d} \int_{\mathbb{R}^d} p_{\text{data}}(\boldsymbol{x}_0) \, \mathcal{N}\big(\frac{\boldsymbol{x}}{s(t)} - \boldsymbol{x}_0;\ \mathbf{0},\ \sigma(t)^2\boldsymbol{I}\big) \,\mathrm{d}\boldsymbol{x}_0\\ &= s(t)^{-d} \left[ p_{\text{data}} * \mathcal{N}\big(\boldsymbol{0},\ \sigma(t)^2\boldsymbol{I}\big) \right] \big(\frac{\boldsymbol{x}}{s(t)}\big) \end{align*}

其中 * 表示卷积,添加 s(t)ds(t)^{-d} 是为了进行归一化。方括号内的表达式对应对 pdatap_{\text{data}} 进行“平滑化”的结果(通过向样本添加独立同分布的高斯噪声实现)。我们将该分布记为 p(x;σ)p(\boldsymbol{x};\sigma)

p(x;σ)=pdataN(0, σ2I)  pt(x)=s(t)dp(xs(t); σ(t))p(\boldsymbol{x};\sigma) = p_{\text{data}} * \mathcal{N}\big(\boldsymbol{0},\ \sigma^2\boldsymbol{I}\big) \ \Rightarrow \ p_t(\boldsymbol{x}) = s(t)^{-d}\, p\big(\frac{\boldsymbol{x}}{s(t)};\ \sigma(t)\big)

现在,我们可以用 p(x;σ)p(\boldsymbol{x};\sigma) 代替 pt(x)p_t(\boldsymbol{x}),重新表达 概率流ODE

dx=[f(t)x12g(t)2xlog[pt(x)]]dt=[f(t)x12g(t)2xlog[s(t)dp(x/s(t); σ(t))]]dt=[f(t)x12g(t)2(xlogs(t)d+xlogp(x/s(t); σ(t)))]dt=[f(t)x12g(t)2xlogp(xs(t); σ(t))]dt\begin{align*} \mathrm{d}\boldsymbol{x} &= \left[ f(t)\boldsymbol{x} - \frac{1}{2}g(t)^2 \nabla_{\boldsymbol{x}} \log \big[ p_t(\boldsymbol{x}) \big] \right] \mathrm{d}t \\ &= \left[ f(t)\boldsymbol{x} - \frac{1}{2}g(t)^2 \nabla_{\boldsymbol{x}} \log \big[ s(t)^{-d}\, p\big(\boldsymbol{x}/s(t);\ \sigma(t)\big) \big] \right] \mathrm{d}t \\ &= \left[ f(t)\boldsymbol{x} - \frac{1}{2}g(t)^2 \left( \nabla_{\boldsymbol{x}} \log s(t)^{-d} + \nabla_{\boldsymbol{x}} \log p\big(\boldsymbol{x}/s(t);\ \sigma(t)\big) \right) \right] \mathrm{d}t \\ &= \left[ f(t)\boldsymbol{x} - \frac{1}{2}g(t)^2 \nabla_{\boldsymbol{x}} \log p\big(\frac{\boldsymbol{x}}{s(t)};\ \sigma(t)\big) \right] \mathrm{d}t \end{align*}

f(t)f(t)g(t)g(t) 的表示法

已有:

s(t)=exp(0tf(ξ)dξ),σ(t)=0tg2(ξ)s2(ξ)dξs(t) = \exp\left( \int_0^t f(\xi)\mathrm{d}\xi \right), \quad \sigma(t) = \sqrt{ \int_0^t \frac{g^2(\xi)}{s^2(\xi)} \mathrm{d}\xi }

<span style="color:red;">目标:利用上面两式,将 f(t)f(t)g(t)g(t) 变形为 s(t)s(t)σ(t)\sigma(t) 的形式:

exp(0tf(ξ)dξ)=s(t)0tf(ξ)dξ=logs(t)ddt[0tf(ξ)dξ]=ddt[logs(t)]f(t)=s˙(t)s(t)\begin{align*} \exp\left( \int_0^t f(\xi)\,\mathrm{d}\xi \right) = s(t) &\Rightarrow \int_0^t f(\xi)\,\mathrm{d}\xi = \log s(t) \\ &\Rightarrow \frac{\mathrm{d}}{\mathrm{d}t}\left[ \int_0^t f(\xi)\,\mathrm{d}\xi \right] = \frac{\mathrm{d}}{\mathrm{d}t}\left[ \log s(t) \right] \\ &\Rightarrow f(t) = \frac{\dot{s}(t)}{s(t)} \end{align*}

0tg(ξ)2s(ξ)2dξ=σ(t)0tg(ξ)2s(ξ)2dξ=σ(t)2ddt[0tg(ξ)2s(ξ)2dξ]=ddt[σ(t)2]g(t)2s(t)2=2σ˙(t)σ(t)g(t)s(t)=2σ˙(t)σ(t)g(t)=s(t)2σ˙(t)σ(t)\begin{align*} \sqrt{ \int_0^t \frac{g(\xi)^2}{s(\xi)^2}\,\mathrm{d}\xi } = \sigma(t) &\Rightarrow \int_0^t \frac{g(\xi)^2}{s(\xi)^2}\,\mathrm{d}\xi = \sigma(t)^2 \\ &\Rightarrow \frac{\mathrm{d}}{\mathrm{d}t}\left[ \int_0^t \frac{g(\xi)^2}{s(\xi)^2}\,\mathrm{d}\xi \right] = \frac{\mathrm{d}}{\mathrm{d}t}\left[ \sigma(t)^2 \right] \\ & \Rightarrow \frac{g(t)^2}{s(t)^2} = 2\,\dot{\sigma}(t)\,\sigma(t) \\ & \Rightarrow \frac{g(t)}{s(t)} = \sqrt{2\,\dot{\sigma}(t)\,\sigma(t)} \\ & \Rightarrow g(t) = s(t)\,\sqrt{2\,\dot{\sigma}(t)\,\sigma(t)} \end{align*}

最后,将 ffgg 代入 概率流ODE 中:

dx=[f(t)x12g(t)2xlogp(xs(t);σ(t))]dt=[s˙(t)s(t)x12(s(t)2σ˙(t)σ(t))2xlogp(xs(t);σ(t))]dt=[s˙(t)s(t)x12(2s(t)2σ˙(t)σ(t))xlogp(xs(t);σ(t))]dt=[s˙(t)s(t)xs(t)2σ˙(t)σ(t)xlogp(xs(t);σ(t))]dt\begin{align*} \mathrm{d}\boldsymbol{x} &= \left[ f(t)\,\boldsymbol{x} - \frac{1}{2}\,g(t)^2\,\nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) \right] \mathrm{d}t \\ &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - \frac{1}{2}\,\left( s(t)\sqrt{2\,\dot{\sigma}(t)\,\sigma(t)} \right)^2 \nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) \right] \mathrm{d}t \\ &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - \frac{1}{2}\,\left( 2\,s(t)^2\,\dot{\sigma}(t)\,\sigma(t) \right) \nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) \right] \mathrm{d}t \\ &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - s(t)^2\,\dot{\sigma}(t)\,\sigma(t)\,\nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) \right] \mathrm{d}t \end{align*}

由此,我们得到了(上)中的公式4;令 s(t)=1s(t) = 1 时,即可还原出(上)地公式1:

dx=σ˙(t)σ(t)xlogp(x;σ(t))dt\mathrm{d}\boldsymbol{x} = -\dot{\sigma}(t)\,\sigma(t)\,\nabla_{\boldsymbol{x}} \log p\left( \boldsymbol{x}; \sigma(t) \right) \mathrm{d}t

确定性采样

去噪网络 DθD_\theta 的推导

为了完整起见,我们推导有限数据集下分数匹配去噪之间的联系。

假设训练集由有限个样本 {y1,,yY}\{\boldsymbol{y}_1, \ldots, \boldsymbol{y}_Y\} 组成,则数据分布 pdata(x)p_{\text{data}}(\boldsymbol{x}) 可表示为狄拉克 delta 分布的混合:

pdata(x)=1Yi=1Yδ(xyi)(40)p_{\text{data}}(\boldsymbol{x}) = \frac{1}{Y} \sum_{i=1}^Y \delta(\boldsymbol{x} - \boldsymbol{y}_i) \tag{40}

基于公式20,我们可推导 $ p(\boldsymbol{x}; \sigma) $ 的闭式表达式:

首先,$ p(\boldsymbol{x}; \sigma) $ 是 $ p_{\text{data}} $ 与高斯分布 $ \mathcal{N}(0, \sigma(t)^2 \mathbf{I}) $ 的卷积:

p(x;σ)=pdataN(0,σ(t)2I)(41)p(\boldsymbol{x}; \sigma) = p_{\text{data}} * \mathcal{N}(0, \sigma(t)^2 \mathbf{I}) \tag{41}

展开卷积的积分形式:

p(x;σ)=Rdpdata(x0)N(x;x0,σ2I)dx0(42)p(\boldsymbol{x}; \sigma) = \int_{\mathbb{R}^d} p_{\text{data}}(\boldsymbol{x}_0) \mathcal{N}(\boldsymbol{x}; \boldsymbol{x}_0, \sigma^2 \mathbf{I}) \,\mathrm{d}\boldsymbol{x}_0 \tag{42}

代入 $ p_{\text{data}} $ 的表达式(公式40):

p(x;σ)=Rd[1Yi=1Yδ(x0yi)]N(x;x0,σ2I)dx0(43)p(\boldsymbol{x}; \sigma) = \int_{\mathbb{R}^d} \left[ \frac{1}{Y} \sum_{i=1}^Y \delta(\boldsymbol{x}_0 - \boldsymbol{y}_i) \right] \mathcal{N}(\boldsymbol{x}; \boldsymbol{x}_0, \sigma^2 \mathbf{I}) \,\mathrm{d}\boldsymbol{x}_0 \tag{43}

利用积分与求和的交换性(线性性):

p(x;σ)=1Yi=1YRdN(x;x0,σ2I)δ(x0yi)dx0(44)p(\boldsymbol{x}; \sigma) = \frac{1}{Y} \sum_{i=1}^Y \int_{\mathbb{R}^d} \mathcal{N}(\boldsymbol{x}; \boldsymbol{x}_0, \sigma^2 \mathbf{I}) \,\delta(\boldsymbol{x}_0 - \boldsymbol{y}_i) \,\mathrm{d}\boldsymbol{x}_0 \tag{44}

根据狄拉克 delta 函数的筛选性质,积分后得到:

p(x;σ)=1Yi=1YN(x;yi,σ2I)(45)p(\boldsymbol{x}; \sigma) = \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \tag{45}

接下来,考虑公式2的去噪分数匹配损失。通过展开期望,我们可将其重写为对含噪样本 $ \boldsymbol{x} $ 的积分:

损失函数定义为:

L(D;σ)=EypdataEnN(0,σ2I)D(y+n;σ)y22(46)\mathcal{L}(D; \sigma) = \mathbb{E}_{\boldsymbol{y} \sim p_{\text{data}}} \mathbb{E}_{\boldsymbol{n} \sim \mathcal{N}(0, \sigma^2 \mathbf{I})} \left\| D(\boldsymbol{y} + \boldsymbol{n}; \sigma) - \boldsymbol{y} \right\|_2^2 \tag{46}

令 $ \boldsymbol{x} = \boldsymbol{y} + \boldsymbol{n} $(则 $ \boldsymbol{x} \sim \mathcal{N}(\boldsymbol{y}, \sigma^2 \mathbf{I}) $),上式可改写为:

L(D;σ)=EypdataExN(y,σ2I)D(x;σ)y22(47)\mathcal{L}(D; \sigma) = \mathbb{E}_{\boldsymbol{y} \sim p_{\text{data}}} \mathbb{E}_{\boldsymbol{x} \sim \mathcal{N}(\boldsymbol{y}, \sigma^2 \mathbf{I})} \left\| D(\boldsymbol{x}; \sigma) - \boldsymbol{y} \right\|_2^2 \tag{47}

将外层期望展开为积分(利用 $ p_{\text{data}} $ 的离散混合形式):

L(D;σ)=EypdataRdN(x;y,σ2I)D(x;σ)y22dx(48)\mathcal{L}(D; \sigma) = \mathbb{E}_{\boldsymbol{y} \sim p_{\text{data}}} \int_{\mathbb{R}^d} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}, \sigma^2 \mathbf{I}) \left\| D(\boldsymbol{x}; \sigma) - \boldsymbol{y} \right\|_2^2 \,\mathrm{d}\boldsymbol{x} \tag{48}

代入 $ p_{\text{data}} $ 的表达式(公式40),将期望转换为求和:

L(D;σ)=1Yi=1YRdN(x;yi,σ2I)D(x;σ)yi22dx(49)\mathcal{L}(D; \sigma) = \frac{1}{Y} \sum_{i=1}^Y \int_{\mathbb{R}^d} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \left\| D(\boldsymbol{x}; \sigma) - \boldsymbol{y}_i \right\|_2^2 \,\mathrm{d}\boldsymbol{x} \tag{49}

再次交换求和与积分(有限和的交换性):

L(D;σ)=Rd1Yi=1YN(x;yi,σ2I)=:L(D;x,σ)D(x;σ)yi22dx(50)\mathcal{L}(D; \sigma) = \int_{\mathbb{R}^d} \underbrace{ \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) }_{=: \mathcal{L}(D; \boldsymbol{x}, \sigma)} \left\| D(\boldsymbol{x}; \sigma) - \boldsymbol{y}_i \right\|_2^2 \,\mathrm{d}\boldsymbol{x} \tag{50}

公式50表明:我们可通过独立最小化每个 $ \boldsymbol{x} $ 对应的 $ \mathcal{L}(D; \boldsymbol{x}, \sigma) $ 来最小化 $ \mathcal{L}(D; \sigma) $,即:

D(x;σ)=argminD(x;σ)L(D;x,σ)(51)D(\boldsymbol{x}; \sigma) = \arg\min_{D(\boldsymbol{x}; \sigma)} \mathcal{L}(D; \boldsymbol{x}, \sigma) \tag{51}

这是一个凸优化问题,其解可通过令 $ D(\boldsymbol{x}; \sigma) $ 的梯度为零唯一确定:

0=D(x;σ)[L(D;x,σ)](52)\mathbf{0} = \nabla_{D(\boldsymbol{x}; \sigma)} \left[ \mathcal{L}(D; \boldsymbol{x}, \sigma) \right] \tag{52}

代入 $ \mathcal{L}(D; \boldsymbol{x}, \sigma) $ 的表达式(公式50的被积函数):

0=D(x;σ)[1Yi=1YN(x;yi,σ2I)D(x;σ)yi22](53)\mathbf{0} = \nabla_{D(\boldsymbol{x}; \sigma)} \left[ \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \left\| D(\boldsymbol{x}; \sigma) - \boldsymbol{y}_i \right\|_2^2 \right] \tag{53}

提取常数和求和符号,利用梯度的线性性:

0=i=1YN(x;yi,σ2I)D(x;σ)[D(x;σ)yi22](54)\mathbf{0} = \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \, \nabla_{D(\boldsymbol{x}; \sigma)} \left[ \left\| D(\boldsymbol{x}; \sigma) - \boldsymbol{y}_i \right\|_2^2 \right] \tag{54}

对范数平方求梯度($ \nabla_a |a - b|_2^2 = 2(a - b) $):

0=i=1YN(x;yi,σ2I)[2D(x;σ)2yi](55)\mathbf{0} = \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \, \left[ 2\,D(\boldsymbol{x}; \sigma) - 2\,\boldsymbol{y}_i \right] \tag{55}

约去公因子2,整理项:

0=[i=1YN(x;yi,σ2I)]D(x;σ)i=1YN(x;yi,σ2I)yi(56)\mathbf{0} = \left[ \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \right] D(\boldsymbol{x}; \sigma) - \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \,\boldsymbol{y}_i \tag{56}

解出 $ D(\boldsymbol{x}; \sigma)$:

D(x;σ)=i=1YN(x;yi,σ2I)yii=1YN(x;yi,σ2I)(57)D(\boldsymbol{x}; \sigma) = \frac{ \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \,\boldsymbol{y}_i }{ \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) } \tag{57}

上式给出了理想去噪器 $ D(\boldsymbol{x}; \sigma) $ 的闭式解。需注意:公式57对小数据集可实际计算——我们在图1b中展示了CIFAR-10的结果。

接下来,考虑公式45定义的分布 ( p(\boldsymbol{x}; \sigma) ) 的分数(score):

xlogp(x;σ)=xp(x;σ)p(x;σ)(58)\nabla_{\boldsymbol{x}} \log p(\boldsymbol{x}; \sigma) = \frac{\nabla_{\boldsymbol{x}} p(\boldsymbol{x}; \sigma)}{p(\boldsymbol{x}; \sigma)} \tag{58}

代入公式45的 ( p(\boldsymbol{x}; \sigma) = \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) ),得:

xlogp(x;σ)=x[1Yi=1YN(x;yi,σ2I)]1Yi=1YN(x;yi,σ2I)(59)\nabla_{\boldsymbol{x}} \log p(\boldsymbol{x}; \sigma) = \frac{\nabla_{\boldsymbol{x}} \left[ \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \right]}{\frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})} \tag{59}

利用梯度的线性性(求和与梯度交换),分子分母的 ( 1/Y ) 约去,得到:

xlogp(x;σ)=i=1YxN(x;yi,σ2I)i=1YN(x;yi,σ2I)(60)\nabla_{\boldsymbol{x}} \log p(\boldsymbol{x}; \sigma) = \frac{\sum_{i=1}^Y \nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})}{\sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})} \tag{60}

我们可进一步简化公式60的分子:

高斯分布 ( \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) ) 的表达式为:

N(x;yi,σ2I)=(2πσ2)d/2exp(xyi222σ2)(61)\mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) = \left( 2\pi\sigma^2 \right)^{-d/2} \exp\left( \frac{-\|\boldsymbol{x} - \boldsymbol{y}_i\|_2^2}{2\sigma^2} \right) \tag{61}

对 ( \boldsymbol{x} ) 求梯度:

xN(x;yi,σ2I)=x[(2πσ2)d/2exp(xyi222σ2)](62)\nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) = \nabla_{\boldsymbol{x}} \left[ \left( 2\pi\sigma^2 \right)^{-d/2} \exp\left( \frac{-\|\boldsymbol{x} - \boldsymbol{y}_i\|_2^2}{2\sigma^2} \right) \right] \tag{62}

由于 ( \left( 2\pi\sigma^2 \right)^{-d/2} ) 与 ( \boldsymbol{x} ) 无关,可提出梯度外:

xN(x;yi,σ2I)=(2πσ2)d/2xexp(xyi222σ2)(63)\nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) = \left( 2\pi\sigma^2 \right)^{-d/2} \nabla_{\boldsymbol{x}} \exp\left( \frac{-\|\boldsymbol{x} - \boldsymbol{y}_i\|_2^2}{2\sigma^2} \right) \tag{63}

利用链式法则,指数函数的梯度等于自身乘以指数部分的梯度:

xN(x;yi,σ2I)=(2πσ2)d/2exp(xyi222σ2)x(xyi222σ2)(64)\nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) = \left( 2\pi\sigma^2 \right)^{-d/2} \exp\left( \frac{-\|\boldsymbol{x} - \boldsymbol{y}_i\|_2^2}{2\sigma^2} \right) \cdot \nabla_{\boldsymbol{x}} \left( \frac{-\|\boldsymbol{x} - \boldsymbol{y}_i\|_2^2}{2\sigma^2} \right) \tag{64}

注意到前两项正是 ( \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}i, \sigma^2 \mathbf{I}) ),而范数平方的梯度为 ( \nabla{\boldsymbol{x}} |\boldsymbol{x} - \boldsymbol{y}_i|_2^2 = 2(\boldsymbol{x} - \boldsymbol{y}_i) ),因此:

xN(x;yi,σ2I)=N(x;yi,σ2I)yixσ2(65)\nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) = \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \cdot \frac{\boldsymbol{y}_i - \boldsymbol{x}}{\sigma^2} \tag{65}

将上述结果代回公式60:

xlogp(x;σ)=i=1YxN(x;yi,σ2I)i=1YN(x;yi,σ2I)(66)\nabla_{\boldsymbol{x}} \log p(\boldsymbol{x}; \sigma) = \frac{\sum_{i=1}^Y \nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})}{\sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})} \tag{66}

代入公式65的分子:

xlogp(x;σ)=i=1YN(x;yi,σ2I)yixσ2i=1YN(x;yi,σ2I)(67)\nabla_{\boldsymbol{x}} \log p(\boldsymbol{x}; \sigma) = \frac{\sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \cdot \frac{\boldsymbol{y}_i - \boldsymbol{x}}{\sigma^2}}{\sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})} \tag{67}

提取公因子 ( 1/\sigma^2 ) 并整理分子:

xlogp(x;σ)=1σ2i=1YN(x;yi,σ2I)yixi=1YN(x;yi,σ2I)i=1YN(x;yi,σ2I)(68)\nabla_{\boldsymbol{x}} \log p(\boldsymbol{x}; \sigma) = \frac{1}{\sigma^2} \cdot \frac{ \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \boldsymbol{y}_i - \boldsymbol{x} \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) }{ \sum_{i=1}^Y \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) } \tag{68}

观察公式68的分子分数部分,发现其与公式57的 ( D(\boldsymbol{x}; \sigma) ) 完全一致(( D(\boldsymbol{x}; \sigma) = \frac{\sum_i \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I}) \boldsymbol{y}_i}{\sum_i \mathcal{N}(\boldsymbol{x}; \boldsymbol{y}_i, \sigma^2 \mathbf{I})} ))。因此,公式68可等价写为:

xlogp(x;σ)=D(x;σ)xσ2(69)\nabla_{\boldsymbol{x}} \log p(\boldsymbol{x}; \sigma) = \frac{D(\boldsymbol{x}; \sigma) - \boldsymbol{x}}{\sigma^2} \tag{69}

这与主论文中的公式3一致。

缩放ODE公式

假设 x\boldsymbol{x} 是原始无量纲变量 x^\hat{\boldsymbol{x}} 的缩放版本,将 x=s(t)x^\boldsymbol{x} = s(t)\,\hat{\boldsymbol{x}} 代入缩放ODE(公式4)中的分数项:

xlogp(xs(t);σ(t))=[s(t)x^]logp(s(t)x^s(t);σ(t))=s(t)x^logp(x^;σ(t))=1s(t)x^logp(x^;σ(t))\begin{align*} \nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) &= \nabla_{[s(t)\,\hat{\boldsymbol{x}}]} \log p\left( \frac{s(t)\,\hat{\boldsymbol{x}}}{s(t)}; \sigma(t) \right)\\ &= \nabla_{s(t)\,\hat{\boldsymbol{x}}} \log p\left( \hat{\boldsymbol{x}}; \sigma(t) \right) \\ &= \frac{1}{s(t)}\,\nabla_{\hat{\boldsymbol{x}}} \log p\left( \hat{\boldsymbol{x}}; \sigma(t) \right) \end{align*}

利用公式3,我们可进一步将其用 D()D(\cdot) 重写:

xlogp(xs(t);σ(t))=1s(t)σ(t)2(D(x^;σ(t))x^)\nabla_{\boldsymbol{x}} \log p\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) = \frac{1}{s(t)\,\sigma(t)^2} \left( D\left( \hat{\boldsymbol{x}}; \sigma(t) \right) - \hat{\boldsymbol{x}} \right)

接下来,代入公式4,用训练好的模型 Dθ()D_\theta(\cdot) 近似理想去噪器 D()D(\cdot)

dx=[s˙(t)s(t)xs(t)2σ˙(t)σ(t)(1s(t)σ(t)2(Dθ(x^;σ(t))x^))]dt=[s˙(t)s(t)xσ˙(t)s(t)σ(t)(Dθ(x^;σ(t))x^)]dt\begin{align*} \mathrm{d}\boldsymbol{x} &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - s(t)^2\,\dot{\sigma}(t)\,\sigma(t) \left( \frac{1}{s(t)\,\sigma(t)^2} \left( D_\theta\left( \hat{\boldsymbol{x}}; \sigma(t) \right) - \hat{\boldsymbol{x}} \right) \right) \right] \mathrm{d}t \\ &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - \frac{\dot{\sigma}(t)\,s(t)}{\sigma(t)} \left( D_\theta\left( \hat{\boldsymbol{x}}; \sigma(t) \right) - \hat{\boldsymbol{x}} \right) \right] \mathrm{d}t \end{align*}

最后,回代 x^=x/s(t)\hat{\boldsymbol{x}} = \boldsymbol{x}/s(t)

dx=[s˙(t)s(t)xσ˙(t)s(t)σ(t)(Dθ([x^];σ(t))[x^])]dt=[s˙(t)s(t)xσ˙(t)s(t)σ(t)(Dθ(xs(t);σ(t))xs(t))]dt=[s˙(t)s(t)xσ˙(t)s(t)σ(t)Dθ(xs(t);σ(t))+σ˙(t)σ(t)x]dt=[(σ˙(t)σ(t)+s˙(t)s(t))xσ˙(t)s(t)σ(t)Dθ(xs(t);σ(t))]dt\begin{align*} \mathrm{d}\boldsymbol{x} &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - \frac{\dot{\sigma}(t)\,s(t)}{\sigma(t)} \left( D_\theta\left( [\hat{\boldsymbol{x}}]; \sigma(t) \right) - [\hat{\boldsymbol{x}}] \right) \right] \mathrm{d}t \\ &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - \frac{\dot{\sigma}(t)\,s(t)}{\sigma(t)} \left( D_\theta\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) - \frac{\boldsymbol{x}}{s(t)} \right) \right] \mathrm{d}t \\ &= \left[ \frac{\dot{s}(t)}{s(t)}\,\boldsymbol{x} - \frac{\dot{\sigma}(t)\,s(t)}{\sigma(t)}\,D_\theta\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) + \frac{\dot{\sigma}(t)}{\sigma(t)}\,\boldsymbol{x} \right] \mathrm{d}t \\ &= \left[ \left( \frac{\dot{\sigma}(t)}{\sigma(t)} + \frac{\dot{s}(t)}{s(t)} \right)\,\boldsymbol{x} - \frac{\dot{\sigma}(t)\,s(t)}{\sigma(t)}\,D_\theta\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right) \right] \mathrm{d}t \end{align*}

我们可将上式等价写为:

dxdt=(σ˙(t)σ(t)+s˙(t)s(t))xσ˙(t)s(t)σ(t)Dθ(xs(t);σ(t))\frac{\mathrm{d}\boldsymbol{x}}{\mathrm{d}t} = \left( \frac{\dot{\sigma}(t)}{\sigma(t)} + \frac{\dot{s}(t)}{s(t)} \right)\,\boldsymbol{x} - \frac{\dot{\sigma}(t)\,s(t)}{\sigma(t)}\,D_\theta\left( \frac{\boldsymbol{x}}{s(t)}; \sigma(t) \right)

这与算法1的第4行和第7行一致。

随机采样

B.5 我们的SDE公式(公式6)

我们通过以下策略推导公式6的SDE:

  • 期望的边缘密度 ( p(\boldsymbol{x}; \sigma(t)) ) 是数据密度 ( p_{\text{data}} ) 与标准差为 ( \sigma(t) ) 的各向同性高斯密度的卷积(见公式20)。因此,作为时间 ( t ) 的函数,该密度遵循具有时变扩散率的热扩散PDE演化。第一步,我们先找到这个PDE。
  • 随后,我们使用Fokker-Planck方程推导出一族SDE,其密度演化符合该PDE。公式6通过对这一族SDE进行合理参数化得到。

B.5.1 通过热扩散生成边缘分布

我们考虑概率密度 ( q(\boldsymbol{x}, t) ) 的时间演化。目标是找到一个PDE,其初始值 ( q(\boldsymbol{x}, 0) := p_{\text{data}}(\boldsymbol{x}) ) 的解为 ( q(\boldsymbol{x}, t) = p(\boldsymbol{x}; \sigma(t)) )。即,该PDE应复现公式20中的边缘分布。

期望的边缘分布是 ( p_{\text{data}} ) 与标准差随时间变化的各向同性正态分布的卷积,因此可由具有时变扩散率 ( \kappa(t) ) 的热方程生成。在傅里叶域分析该情况最方便,此时边缘密度是高斯函数与变换后的数据密度的逐点乘积。为找到诱导正确标准差的扩散率,我们先写出热方程PDE:

q(x,t)t=κ(t)Δxq(x,t).(82)\frac{\partial q(\boldsymbol{x}, t)}{\partial t} = \kappa(t) \Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t). \tag{82}

公式82的傅里叶变换对应式(变换沿 ( \boldsymbol{x} ) 维度进行)为:

q^(ν,t)t=κ(t)ν2q^(ν,t).(83)\frac{\partial \hat{q}(\boldsymbol{\nu}, t)}{\partial t} = -\kappa(t) \|\boldsymbol{\nu}\|^2 \hat{q}(\boldsymbol{\nu}, t). \tag{83}

目标解 ( q(\boldsymbol{x}, t) ) 及其傅里叶变换 ( \hat{q}(\boldsymbol{\nu}, t) ) 由公式20给出:

q(x,t)=p(x;σ(t))=pdata(x)N(0,σ(t)2I)(84)q(\boldsymbol{x}, t) = p(\boldsymbol{x}; \sigma(t)) = p_{\text{data}}(\boldsymbol{x}) * \mathcal{N}\bigl( 0, \sigma(t)^2 \mathbf{I} \bigr) \tag{84}

q^(ν,t)=p^data(ν)exp(12ν2σ(t)2).(85)\hat{q}(\boldsymbol{\nu}, t) = \hat{p}_{\text{data}}(\boldsymbol{\nu}) \exp\left( -\frac{1}{2} \|\boldsymbol{\nu}\|^2 \sigma(t)^2 \right). \tag{85}

沿时间轴对目标解求导,我们有:

q^(ν,t)t=σ˙(t)σ(t)ν2p^data(ν)exp(12ν2σ(t)2)(86)\frac{\partial \hat{q}(\boldsymbol{\nu}, t)}{\partial t} = -\dot{\sigma}(t)\,\sigma(t)\,\|\boldsymbol{\nu}\|^2\,\hat{p}_{\text{data}}(\boldsymbol{\nu}) \exp\left( -\frac{1}{2} \|\boldsymbol{\nu}\|^2 \sigma(t)^2 \right) \tag{86}

q^(ν,t)t=σ˙(t)σ(t)ν2q^(ν,t).(87)\frac{\partial \hat{q}(\boldsymbol{\nu}, t)}{\partial t} = -\dot{\sigma}(t)\,\sigma(t)\,\|\boldsymbol{\nu}\|^2\,\hat{q}(\boldsymbol{\nu}, t). \tag{87}

公式83和公式87的左侧相同。令二者相等,可解出产生期望演化的 ( \kappa(t) ):

κ(t)ν2q^(ν,t)=σ˙(t)σ(t)ν2q^(ν,t)(88)-\kappa(t)\,\|\boldsymbol{\nu}\|^2\,\hat{q}(\boldsymbol{\nu}, t) = -\dot{\sigma}(t)\,\sigma(t)\,\|\boldsymbol{\nu}\|^2\,\hat{q}(\boldsymbol{\nu}, t) \tag{88}

κ(t)=σ˙(t)σ(t).(89)\kappa(t) = \dot{\sigma}(t)\,\sigma(t). \tag{89}

综上,对应噪声水平 ( \sigma(t) ) 的期望边缘密度由以下PDE生成:

q(x,t)t=σ˙(t)σ(t)Δxq(x,t)(90)\frac{\partial q(\boldsymbol{x}, t)}{\partial t} = \dot{\sigma}(t)\,\sigma(t)\,\Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) \tag{90}

其初始密度为 ( q(\boldsymbol{x}, 0) = p_{\text{data}}(\boldsymbol{x}) )。

B.5.2 我们的SDE推导

考虑如下SDE:

dx=f(x,t)dt+g(x,t)dωt,(91)\mathrm{d}\boldsymbol{x} = \boldsymbol{f}(\boldsymbol{x}, t)\,\mathrm{d}t + \boldsymbol{g}(\boldsymbol{x}, t)\,\mathrm{d}\omega_t, \tag{91}

Fokker-Planck PDE 描述了其解的概率密度 ( r(\boldsymbol{x}, t) ) 的时间演化:

r(x,t)t=x(f(x,t)r(x,t))+12xx:(D(x,t)r(x,t)),(92)\frac{\partial r(\boldsymbol{x}, t)}{\partial t} = -\nabla_{\boldsymbol{x}} \cdot \bigl( \boldsymbol{f}(\boldsymbol{x}, t)\,r(\boldsymbol{x}, t) \bigr) + \frac{1}{2} \nabla_{\boldsymbol{x}} \nabla_{\boldsymbol{x}} : \bigl( \mathbf{D}(\boldsymbol{x}, t)\,r(\boldsymbol{x}, t) \bigr), \tag{92}

其中 ( \mathbf{D}{ij} = \sum_k g{ik} g_{jk} ) 是扩散张量。我们考虑 ( \boldsymbol{g}(\boldsymbol{x}, t) = g(t),\mathbf{I} ) 的特殊情况(与 ( \boldsymbol{x} ) 无关的白噪声添加),此时方程简化为:

r(x,t)t=x(f(x,t)r(x,t))+12g(t)2Δxr(x,t).(93)\frac{\partial r(\boldsymbol{x}, t)}{\partial t} = -\nabla_{\boldsymbol{x}} \cdot \bigl( \boldsymbol{f}(\boldsymbol{x}, t)\,r(\boldsymbol{x}, t) \bigr) + \frac{1}{2} g(t)^2 \,\Delta_{\boldsymbol{x}} r(\boldsymbol{x}, t). \tag{93}

我们寻求一个SDE,其解的密度由公式90的PDE描述。令 ( r(\boldsymbol{x}, t) = q(\boldsymbol{x}, t) ),并令公式93和公式90相等,可得SDE必须满足的充分条件:

x(f(x,t)q(x,t))+12g(t)2Δxq(x,t)=σ˙(t)σ(t)Δxq(x,t)(94)-\nabla_{\boldsymbol{x}} \cdot \bigl( \boldsymbol{f}(\boldsymbol{x}, t)\,q(\boldsymbol{x}, t) \bigr) + \frac{1}{2} g(t)^2 \,\Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) = \dot{\sigma}(t)\,\sigma(t)\,\Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) \tag{94}

x(f(x,t)q(x,t))=(12g(t)2σ˙(t)σ(t))Δxq(x,t).(95)\nabla_{\boldsymbol{x}} \cdot \bigl( \boldsymbol{f}(\boldsymbol{x}, t)\,q(\boldsymbol{x}, t) \bigr) = \left( \frac{1}{2} g(t)^2 - \dot{\sigma}(t)\,\sigma(t) \right) \Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t). \tag{95}

任何满足该方程的函数 ( \boldsymbol{f}(\boldsymbol{x}, t) ) 和 ( g(t) ) 都构成所求的SDE。现在我们寻找这类解的一个具体族。关键思路源于恒等式 ( \nabla_{\boldsymbol{x}} \cdot \nabla_{\boldsymbol{x}} = \Delta_{\boldsymbol{x}} )。实际上,若令 ( \boldsymbol{f}(\boldsymbol{x}, t),q(\boldsymbol{x}, t) = v(t),\nabla_{\boldsymbol{x}} q(\boldsymbol{x}, t) )(对任意 ( v(t) ) 成立),则 ( \Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) ) 会出现在两边并抵消:

x(v(t)xq(x,t))=(12g(t)2σ˙(t)σ(t))Δxq(x,t)(96)\nabla_{\boldsymbol{x}} \cdot \bigl( v(t)\,\nabla_{\boldsymbol{x}} q(\boldsymbol{x}, t) \bigr) = \left( \frac{1}{2} g(t)^2 - \dot{\sigma}(t)\,\sigma(t) \right) \Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) \tag{96}

v(t)Δxq(x,t)=(12g(t)2σ˙(t)σ(t))Δxq(x,t)(97)v(t)\,\Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) = \left( \frac{1}{2} g(t)^2 - \dot{\sigma}(t)\,\sigma(t) \right) \Delta_{\boldsymbol{x}} q(\boldsymbol{x}, t) \tag{97}

v(t)=12g(t)2σ˙(t)σ(t).(98)v(t) = \frac{1}{2} g(t)^2 - \dot{\sigma}(t)\,\sigma(t). \tag{98}

上述 ( \boldsymbol{f}(\boldsymbol{x}, t) ) 实际上与分数函数成比例,因为该公式匹配了密度对数的梯度:

f(x,t)=v(t)xq(x,t)q(x,t)(99)\boldsymbol{f}(\boldsymbol{x}, t) = v(t)\,\frac{\nabla_{\boldsymbol{x}} q(\boldsymbol{x}, t)}{q(\boldsymbol{x}, t)} \tag{99}

f(x,t)=v(t)xlogq(x,t)(100)\boldsymbol{f}(\boldsymbol{x}, t) = v(t)\,\nabla_{\boldsymbol{x}} \log q(\boldsymbol{x}, t) \tag{100}

f(x,t)=(12g(t)2σ˙(t)σ(t))xlogq(x,t).(101)\boldsymbol{f}(\boldsymbol{x}, t) = \left( \frac{1}{2} g(t)^2 - \dot{\sigma}(t)\,\sigma(t) \right) \nabla_{\boldsymbol{x}} \log q(\boldsymbol{x}, t). \tag{101}


EDM:扩散模型的设计空间(下)
http://dbqdss.github.io/2025/08/06/EDM:扩散模型的设计空间(下)/
作者
失去理想的獾
发布于
2025年8月6日
许可协议