Generative Models from an SDE Perspective (Part 2)

This post was last updated on the evening of August 12, 2025.

The Euler-Maruyama Method

See the Wikipedia entry.

Euler method (for ODEs):

$$\frac{dx}{dt} = a(x(t))$$

The corresponding update rule is

$$x(t+\Delta t) = x(t) + a(x(t))\, \Delta t$$

Euler-Maruyama method (for SDEs):

$$d\mathbf{x} = \mathbf{f}(\mathbf{x},t)\,dt + g(t)\,dW$$

The corresponding update rule is

$$\mathbf{x}(t+\Delta t) = \mathbf{x}(t) + \mathbf{f}(\mathbf{x},t)\,\Delta t + g(t)\, \Delta W$$

Since $\Delta W \sim N(0, \Delta t)$, i.e. $\Delta W = \sqrt{\Delta t}\, Z$ with $Z \sim N(0, I)$, the paper writes the iteration above as

$$\mathbf{x}_{i+1} = \mathbf{x}_i + \mathbf{f}(\mathbf{x},t)\,\Delta t + g(t) \sqrt{\Delta t}\, Z$$

and, absorbing the step size into the coefficients (treating $\mathbf{f}(\mathbf{x},t)\Delta t$ as the per-step drift and $g(t)\sqrt{\Delta t}$ as the per-step noise scale), abbreviates it as

$$\mathbf{x}_{i+1} = \mathbf{x}_i + \mathbf{f}(\mathbf{x},t) + g(t) Z$$
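As a concrete illustration, the Euler-Maruyama scheme above can be sketched in a few lines of NumPy. This is a minimal sketch, not code from the original paper: the function name `euler_maruyama` and the Ornstein-Uhlenbeck test case are my own choices.

```python
import numpy as np

def euler_maruyama(f, g, x0, t0, t1, n_steps, rng=None):
    """Simulate dx = f(x, t) dt + g(t) dW with the Euler-Maruyama scheme."""
    rng = np.random.default_rng() if rng is None else rng
    dt = (t1 - t0) / n_steps
    x = np.asarray(x0, dtype=float).copy()
    t = t0
    for _ in range(n_steps):
        z = rng.standard_normal(x.shape)               # Z ~ N(0, I)
        x = x + f(x, t) * dt + g(t) * np.sqrt(dt) * z  # dW ~ sqrt(dt) Z
        t += dt
    return x

# Sanity check on an Ornstein-Uhlenbeck process, dx = -x dt + dW,
# whose stationary variance is 1/2: long runs should approach it.
samples = euler_maruyama(
    f=lambda x, t: -x,
    g=lambda t: 1.0,
    x0=np.zeros(10_000),
    t0=0.0, t1=8.0, n_steps=800,
    rng=np.random.default_rng(0),
)
```

Replacing `f` and `g` with the DDPM or SMLD coefficients derived below turns the same loop into those models' forward processes.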

The Unified SDE View of DDPM and SMLD

Forward process:

$$d\mathbf{x} = \mathbf{f}(\mathbf{x},t)\,dt + g(t)\,dW$$

Backward process:

$$d\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, t) - g^2(t)\, s_\theta(\mathbf{x}, t) \right]dt + g(t)\,d\bar{w}$$

where $\bar{w}$ is a Wiener process running backwards in time and $s_\theta(\mathbf{x}, t) \approx \nabla_\mathbf{x} \log p_t(\mathbf{x})$ is the learned score.

DDPM

Recap of DDPM and Its Continuous-Time Limit

Recall the noising (forward) process:

$$x_i = \sqrt{1 - \beta_i}\, x_{i-1} + \sqrt{\beta_i}\ \epsilon, \quad \epsilon \sim N(0, 1), \quad i = 1,\ldots,N$$

Define the following variables:

$$x\!\left(t = \tfrac{i}{N}\right) = x_i, \quad \beta\!\left(t = \tfrac{i}{N}\right) = N \beta_i, \quad \Delta t = \tfrac{1}{N}, \quad t \in [0, 1]$$

Then,

$$\begin{align*}
x(t+\Delta t) &= \sqrt{1 - \frac{\beta(t+\Delta t)}{N}}\, x(t) + \sqrt{\frac{\beta(t+\Delta t)}{N}}\ \epsilon\\
&= \sqrt{1 - \beta(t+\Delta t)\, \Delta t}\, x(t) + \sqrt{\beta(t+\Delta t)\, \Delta t}\ \epsilon \\
\text{(Taylor expansion)}\quad & \approx \left( 1 - \frac{1}{2} \beta(t+\Delta t)\, \Delta t \right)x(t) + \sqrt{\beta(t+\Delta t)\, \Delta t}\ \epsilon \\
\text{(taking the limit)}\quad & \approx \left( 1 - \frac{1}{2} \beta(t)\, \Delta t \right)x(t) + \sqrt{\beta(t)\, \Delta t}\ \epsilon
\end{align*}$$

Rearranging gives

$$x(t+\Delta t) - x(t) = -\frac{1}{2} \beta(t)\, \Delta t\, x(t) + \sqrt{\beta(t)\, \Delta t}\ \epsilon$$

and we obtain the <span style="color:red">SDE corresponding to DDPM</span>:

$$dx = -\frac{1}{2} \beta(t)\, x(t)\, dt + \sqrt{\beta(t)}\ dW$$

Matching terms with the Itô-process SDE:

  • $f(x,t) = -\frac{1}{2} \beta(t)\, x(t)$
  • $g(t) = \sqrt{\beta(t)}$
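As a numerical sanity check (my own illustrative setup, not from the paper), one can integrate this forward SDE with Euler-Maruyama and compare against the known VP perturbation kernel, whose mean is $x_0\, e^{-\frac{1}{2}\int_0^t \beta(s)\,ds}$ and whose variance is $1 - e^{-\int_0^t \beta(s)\,ds}$; the linear schedule below is one common assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
beta = lambda t: 0.1 + 19.9 * t   # assumed linear schedule: beta(0)=0.1, beta(1)=20

N = 1_000
dt = 1.0 / N
x = np.ones(20_000)               # start every path at x0 = 1
for i in range(N):
    t = i * dt
    z = rng.standard_normal(x.shape)
    # dx = -1/2 beta(t) x dt + sqrt(beta(t)) dW
    x = x - 0.5 * beta(t) * x * dt + np.sqrt(beta(t) * dt) * z

# Closed-form moments of p(x(1) | x(0)) for comparison.
B = 0.1 + 19.9 / 2                # integral of beta(t) over [0, 1]
mean_exact, var_exact = np.exp(-B / 2), 1.0 - np.exp(-B)
```

The simulated sample mean and variance land close to `mean_exact` and `var_exact`, confirming that the discrete DDPM chain and the continuous SDE describe the same diffusion.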

The reverse process reads:

$$dx = \left[ - \frac{1}{2} \beta(t)\, x(t) - \beta(t)\, s_\theta(x, t) \right]dt + \sqrt{\beta(t)}\, d\bar{w}$$

Recall from DDPM that

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\ \epsilon \ \Leftrightarrow \ \epsilon = \frac{x_t - \sqrt{\bar{\alpha}_t}\, x_0}{\sqrt{1 - \bar{\alpha}_t}}$$

Therefore, writing $p(x_t \mid x_0) = N(\mu_t, \sigma_t^2)$ with $\mu_t = \sqrt{\bar{\alpha}_t}\, x_0$ and $\sigma_t^2 = 1 - \bar{\alpha}_t$, we obtain the following relation:

$$\begin{align*}
\nabla_{x_t} \log p(x_t \mid x_0) &= - \nabla_{x_t} \frac{(x_t - \mu_t)^2}{2 \sigma_t^2} = - \frac{x_t - \mu_t}{\sigma_t^2} \\
&= - \frac{x_t - \sqrt{\bar{\alpha}_t}\, x_0}{1 - \bar{\alpha}_t} = - \frac{\epsilon}{\sqrt{1 - \bar{\alpha}_t}} \approx - \frac{\epsilon_\theta}{\sqrt{1 - \bar{\alpha}_t}} = s_\theta
\end{align*}$$

The DDPM sampling update can thus be rewritten in terms of the score function:

$$\begin{align*}
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta \right) + \sigma_t Z \\
&= \frac{1}{\sqrt{1-\beta_t}}\left(x_t - \frac{\epsilon_\theta}{\sqrt{1-\bar{\alpha}_t}}\, \beta_t \right) + \sqrt{\beta_t}\, Z \\
&= \frac{1}{\sqrt{1-\beta_t}}\left(x_t + s_\theta\, \beta_t \right) + \sqrt{\beta_t}\, Z
\end{align*}$$

(using $\alpha_t = 1 - \beta_t$ and $\sigma_t = \sqrt{\beta_t}$).

Now apply the Euler-Maruyama method to the DDPM reverse SDE, substituting the concrete $f$ and $g$ into the iteration (with the step size $\Delta t = 1/N$ absorbed as before):

$$\begin{align*}
\mathbf{x}_i &= \mathbf{x}_{i+1} - \mathbf{f}(\mathbf{x},t) + g^2(t)\, s_\theta + g(t) Z \\
&= \mathbf{x}_{i+1} + \frac{1}{2} \beta_{i+1} \mathbf{x}_{i+1} + \beta_{i+1} s_\theta + \sqrt{\beta_{i+1}}\, Z \\
&\approx \mathbf{x}_{i+1} + \frac{1}{2} \beta_{i+1} \mathbf{x}_{i+1} + \beta_{i+1} s_\theta + \frac{1}{2} \beta_{i+1}^2 s_\theta + \sqrt{\beta_{i+1}}\, Z \\
&= \left( 1 + \frac{1}{2} \beta_{i+1} + o(\beta_{i+1})\right) \left( \mathbf{x}_{i+1} + \beta_{i+1} s_\theta \right) + \sqrt{\beta_{i+1}}\, Z \\
&\approx \frac{1}{\sqrt{1-\beta_{i+1}}}\left(\mathbf{x}_{i+1} + \beta_{i+1} s_\theta\right) + \sqrt{\beta_{i+1}}\, Z
\end{align*}$$

This is exactly the DDPM sampling update. The third step holds because $\beta_{i+1} \to 0$, so adding the $\frac{1}{2}\beta_{i+1}^2 s_\theta$ term changes nothing at leading order; the fourth step factors the expression so as to match the Taylor expansion of $1/\sqrt{1-\beta_{i+1}}$.
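The final line can be exercised as a tiny sampler. In the toy setup below (my own check, not from the paper) the data distribution is taken to be $N(0, 1)$; the VP forward process then leaves every marginal at $N(0, 1)$, so the exact score is simply $s(x) = -x$, and the sampler should return standard-normal samples:

```python
import numpy as np

rng = np.random.default_rng(0)

# Discrete schedule beta_i = beta(i/N) / N with an assumed linear beta(t).
N = 1_000
betas = (0.1 + 19.9 * np.arange(1, N + 1) / N) / N

score = lambda x: -x                  # exact score when the data is N(0, 1)

x = rng.standard_normal(50_000)       # x_N ~ N(0, 1), the prior
for i in range(N - 1, -1, -1):
    b = betas[i]
    z = rng.standard_normal(x.shape) if i > 0 else 0.0  # no noise on the final step
    # x_i = (x_{i+1} + beta_{i+1} s_theta) / sqrt(1 - beta_{i+1}) + sqrt(beta_{i+1}) Z
    x = (x + b * score(x)) / np.sqrt(1.0 - b) + np.sqrt(b) * z
```

In a real model `score` would be a trained network; the closed-form score here only serves to verify the update rule itself.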

SMLD

Consider the perturbation (noising) process from the original paper:

$$\begin{align*}
&x_{i+1} = x_i + \sqrt{\sigma_{i+1}^2 - \sigma_i^2}\ \epsilon_i \\
\Rightarrow\ &x_{i+1} = x_{i-1} + \sqrt{\sigma_{i+1}^2 - \sigma_i^2}\ \epsilon_i + \sqrt{\sigma_i^2 - \sigma_{i-1}^2}\ \epsilon_{i-1} \\
\Rightarrow\ &x_{i+1} = x_{i-1} + \sqrt{\sigma_{i+1}^2 - \sigma_{i-1}^2}\ \epsilon_i^\prime \\
\Rightarrow\ &\cdots \\
\Rightarrow\ &x_{i+1} = x_0 + \sqrt{\sigma_{i+1}^2 - \sigma_0^2}\ \epsilon
\end{align*}$$

This yields the update rule

$$\begin{align*}
x(t+\Delta t) &= x(t) + \sqrt{\sigma^2(t+\Delta t) - \sigma^2(t)}\ Z \\
&= x(t) + \sqrt{\frac{\sigma^2(t+\Delta t) - \sigma^2(t)}{\Delta t}}\ \sqrt{\Delta t}\ Z
\end{align*}$$

Δt0\Delta t \to 0,得到 <span style="color:red">SMLD 对应的SDE,

$$dx = \sqrt{\frac{d \sigma^2(t)}{dt}}\ dW$$

Matching terms with the forward Itô-process SDE:

  • $f(x,t) = 0$
  • $g(t) = \sqrt{\frac{d \sigma^2(t)}{dt}}$

Following the same procedure as before, apply the Euler-Maruyama method to the SMLD reverse SDE and substitute the concrete functions into the iteration:

$$\begin{align*}
x_i &= x_{i+1} - \mathbf{f}(x,t) + g^2(t)\ s_\theta + g(t) Z \\
&= x_{i+1} + (\sigma_{i+1}^2 - \sigma_i^2)\ s_\theta + \sqrt{\sigma_{i+1}^2 - \sigma_i^2}\ Z
\end{align*}$$

This is exactly the SMLD sampling update.
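The same update can be checked end-to-end on a toy case (my own construction): if the data is a point mass at $0$, then $p_i = N(0, \sigma_i^2)$ exactly and the true score is $s(x) = -x/\sigma_i^2$, so the sampler should collapse samples from $N(0, \sigma_N^2)$ down to a spread of about $\sigma_0$. The geometric schedule is an assumption (the one SMLD typically uses):

```python
import numpy as np

rng = np.random.default_rng(0)

N = 1_000
sigmas = np.geomspace(0.01, 10.0, N + 1)      # assumed geometric noise schedule

# Point-mass data at 0: p_i = N(0, sigma_i^2), so the score is analytic.
score = lambda x, i: -x / sigmas[i] ** 2

x = sigmas[-1] * rng.standard_normal(50_000)  # x_N ~ N(0, sigma_N^2)
for i in range(N - 1, -1, -1):
    d2 = sigmas[i + 1] ** 2 - sigmas[i] ** 2  # plays the role of g^2(t) dt
    z = rng.standard_normal(x.shape)
    x = x + d2 * score(x, i + 1) + np.sqrt(d2) * z
```

After the loop, the sample standard deviation sits near $\sigma_0 = 0.01$, as the reverse SDE predicts.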

Predictor-Corrector Sampling

Predictor-Corrector (PC) sampling

The predictor can be any numerical solver for the reverse-time SDE with a fixed discretization strategy, and the corrector can be any score-based MCMC approach.

(Algorithm figure: the Predictor-Corrector sampling procedure)
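A minimal PC loop interleaves one reverse-time Euler-Maruyama step (predictor) with one Langevin MCMC step (corrector). The sketch below is illustrative rather than the paper's implementation: the data is again taken to be $N(0,1)$ so the exact VP score is $s(x) = -x$, and the `snr` value and the constant standing in for the paper's norm ratio $\|z\|/\|g\|$ are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)

N = 500
betas = (0.1 + 19.9 * np.arange(1, N + 1) / N) / N  # assumed linear VP schedule
score = lambda x: -x        # exact score for N(0, 1) data under the VP process
snr = 0.16                  # corrector signal-to-noise knob (arbitrary here)

x = rng.standard_normal(20_000)
for i in range(N - 1, -1, -1):
    # Predictor: one reverse-time Euler-Maruyama step of the reverse VP-SDE.
    b = betas[i]
    z = rng.standard_normal(x.shape)
    x = (x + b * score(x)) / np.sqrt(1.0 - b) + np.sqrt(b) * z
    # Corrector: one Langevin step, x <- x + eps * score(x) + sqrt(2 eps) z.
    eps = 2.0 * (snr * 1.0) ** 2       # 1.0 stands in for ||z|| / ||score||
    z = rng.standard_normal(x.shape)
    x = x + eps * score(x) + np.sqrt(2.0 * eps) * z
```

The corrector nudges samples toward the current marginal between predictor steps; with a learned, imperfect score this is where PC sampling improves on the plain reverse-SDE solver.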

Probability Flow ODE

Probability Flow ODE

Proof of Equivalence of the Marginal Distributions

The idea of the probability flow ODE was inspired by Maoutsa et al. (2020), where a derivation for a simplified case can be found. Below we derive the fully general ODE of Eq. (17).

Consider the SDE in Eq. (15), which takes the form

$$dx = \mathbf{f}(x, t)\,dt + \mathbf{G}(x, t)\,dw$$

where

  • $\mathbf{f}(\cdot, t): \mathbb{R}^d \to \mathbb{R}^d$
  • $\mathbf{G}(\cdot, t): \mathbb{R}^d \to \mathbb{R}^{d \times d}$

The marginal probability density $p_t(x(t))$ evolves according to the Kolmogorov forward equation (the Fokker-Planck equation)<sup><a href="#ref1">[1]</a></sup>:

$$\frac{\partial p_t(x)}{\partial t} = -\sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ f_i(x, t) p_t(x) \right] + \frac{1}{2} \sum_{i=1}^{d} \sum_{j=1}^{d} \frac{\partial^2}{\partial x_i \partial x_j} \left[ \sum_{k=1}^{d} G_{ik}(x, t) G_{jk}(x, t) p_t(x) \right]$$

We can rewrite this as

$$\begin{align*}
\frac{\partial p_t(x)}{\partial t} &= -\sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ f_i(x, t) p_t(x) \right] + \frac{1}{2} \sum_{i=1}^{d} \sum_{j=1}^{d} \frac{\partial^2}{\partial x_i \partial x_j} \left[ \sum_{k=1}^{d} G_{ik}(x, t) G_{jk}(x, t) p_t(x) \right] \\
&= -\sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ f_i(x, t) p_t(x) \right] + \frac{1}{2} \sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ \sum_{j=1}^{d} \frac{\partial}{\partial x_j} \left[ \sum_{k=1}^{d} G_{ik}(x, t) G_{jk}(x, t) p_t(x) \right] \right] \tag{1}
\end{align*}$$

Note that

$$\begin{align*}
&\sum_{j=1}^{d} \frac{\partial}{\partial x_j} \left[ \sum_{k=1}^{d} G_{ik}(x, t) G_{jk}(x, t) p_t(x) \right] \\
=\ &\sum_{j=1}^{d} \frac{\partial}{\partial x_j} \left[ \sum_{k=1}^{d} G_{ik}(x, t) G_{jk}(x, t) \right] p_t(x) + \sum_{j=1}^{d} \sum_{k=1}^{d} G_{ik}(x, t) G_{jk}(x, t) p_t(x) \frac{\partial}{\partial x_j} \log p_t(x) \\
=\ &p_t(x)\ \nabla \cdot \left[ G(x, t) G(x, t)^\top \right] + p_t(x)\, G(x, t) G(x, t)^\top \nabla_x \log p_t(x)
\end{align*}$$

Based on this, we can continue rewriting Eq. (1):

$$\begin{align*}
\frac{\partial p_t(x)}{\partial t} &= -\sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ f_i(x, t) p_t(x) \right] + \frac{1}{2} \sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ \sum_{j=1}^{d} \frac{\partial}{\partial x_j} \left[ \sum_{k=1}^{d} G_{ik}(x, t) G_{jk}(x, t) p_t(x) \right] \right] \\
&= -\sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ f_i(x, t) p_t(x) \right] + \frac{1}{2} \sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ p_t(x)\, \nabla \cdot \left[ G(x, t) G(x, t)^\top \right] + p_t(x)\, G(x, t) G(x, t)^\top \nabla_x \log p_t(x) \right] \\
&= -\sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left\{ f_i(x, t) p_t(x) - \frac{1}{2} \left[ \nabla \cdot \left[ G(x, t) G(x, t)^\top \right] + G(x, t) G(x, t)^\top \nabla_x \log p_t(x) \right] p_t(x) \right\} \\
&= -\sum_{i=1}^{d} \frac{\partial}{\partial x_i} \left[ \tilde{f}_i(x, t) p_t(x) \right]
\end{align*}$$

where we define:

$$\tilde{f}(x, t) := f(x, t) - \frac{1}{2} \nabla \cdot \left[ G(x, t) G(x, t)^\top \right] - \frac{1}{2} G(x, t) G(x, t)^\top \nabla_x \log p_t(x) \tag{2}$$

Inspecting Eq. (2), we see that the rewritten forward equation is exactly the Kolmogorov forward equation (in this case also called the Liouville equation) of the following SDE with $\tilde{G}(x, t) := 0$:

$$dx = \tilde{f}(x, t)\, dt + \tilde{G}(x, t)\, dw$$

which is essentially an ordinary differential equation:

$$dx = \tilde{f}(x, t)\, dt$$

This coincides with the probability flow ODE given in Eq. (17). We have therefore shown that the probability flow ODE (Eq. 17) induces the same marginal probability densities $p_t(x)$ as the SDE in Eq. (15).
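The equivalence can also be observed numerically in a case where everything is available in closed form (my own illustrative setup): under the VP-SDE with Gaussian data $N(0, s_0^2)$, the marginal variance is $v(t) = 1 + (s_0^2 - 1)e^{-\int_0^t \beta}$, so the exact score is $-x/v(t)$. Integrating the deterministic flow $dx = \tilde{f}(x, t)\,dt$ backwards from samples of $p_1$ should recover the data variance without injecting any noise:

```python
import numpy as np

rng = np.random.default_rng(0)

beta = lambda t: 0.1 + 19.9 * t                   # assumed linear schedule
B = lambda t: 0.1 * t + 19.9 * t ** 2 / 2         # integral of beta over [0, t]

s0 = 2.0                                          # data ~ N(0, s0^2)
v = lambda t: 1.0 + (s0 ** 2 - 1.0) * np.exp(-B(t))   # exact marginal variance

N = 2_000
dt = 1.0 / N
x = np.sqrt(v(1.0)) * rng.standard_normal(20_000) # samples from p_1 (~ the prior)
for i in range(N, 0, -1):                         # integrate the ODE from t=1 to 0
    t = i * dt
    # f_tilde = f - 1/2 g^2 * score = -1/2 beta x - 1/2 beta (-x / v(t))
    drift = -0.5 * beta(t) * x + 0.5 * beta(t) * (x / v(t))
    x = x - drift * dt                            # plain backward Euler step
```

The final sample standard deviation comes out near $s_0 = 2$: the same marginals as the stochastic reverse process, but reached along deterministic trajectories.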

References

  • [1] Bernt Øksendal. "Stochastic differential equations." In Stochastic Differential Equations, pp. 65–84. Springer, 2003.

http://dbqdss.github.io/2025/08/10/SDE视角下的生成模型(下)/
Author: 失去理想的獾 · Published on August 10, 2025