【AIGC专题】DMD:分布匹配蒸馏法,单步Diffusion效果堪比SD

在这里插入图片描述


本文将对《One-step Diffusion with Distribution Matching Distillation》这篇文章进行解读。扩散模型生成高质量图像但需数十次前向计算。该论文提出分布匹配蒸馏法(Distribution Matching Distillation,DMD),可将扩散模型转化为单步图像生成器,质量媲美Stable Diffusion(SD)但速度快数个数量级。
参考资料如下:
[1]. DMD论文地址
[2]. 项目主页


专题介绍

在当今数字时代,图像与视频已成为信息传播的核心载体,而人工智能生成内容(AIGC)技术正以前所未有的力量重塑着这一领域。本专题将深入剖析图像视频 AIGC 技术的前沿动态、技术架构、应用场景以及未来趋势。AIGC 图像视频技术的核心在于强大的算法和先进的模型架构,从生成对抗网络(GAN)的巧妙对抗,到扩散模型的逐步演化,这些技术如何在像素的海洋中创造出逼真且富有创意的图像与视频?一起来领略这场视觉革命的无限魅力,欢迎探讨交流。

一、研究背景

扩散模型虽能生成高保真图像,但迭代采样过程(10-1000步)计算成本高,难以满足实时应用需求。因此有不少针对加速扩散模型的研究。

  • 第一种是快速扩散采样器,它可以显著减少预训练扩散模型所需的采样步骤数,从千级降到 20–50,但再降则质量骤降。
  • 第二种是知识蒸馏方法,如Progressive Distillation、Consistency Models、InstaFlow。需拟合高维噪声与图像映射,且依赖完整去噪轨迹计算,成本高昂。也有方法试图压缩为单步生成,但存在性能损失或训练不稳定问题。

因此,该论文提出了分布匹配蒸馏法(Distribution Matching Distillation,DMD),该方法可将扩散模型转化为单步图像生成器,同时对图像质量影响较小。通过最小化一个近似KL散度实现单步生成器在分布层面上与多步扩散模型匹配,该散度的梯度可表示为两个评分函数的差值:一个来自目标分布,另一个来自单步生成器合成的分布。这两个评分函数分别通过在不同分布上独立训练的两个扩散模型进行参数化。结合与多步扩散输出的大尺度结构相匹配的简单回归损失,使得其效果优于所有已发布的少步扩散方案,质量“媲美”Stable Diffusion但速度快数个数量级。

一句话概括该论文方法:结合分布匹配与扩散模型优势,通过动态分数建模和回归正则化,解决单步蒸馏的质量-速度权衡问题。

先看个效果,这里原文作者卖了个关子,“猜猜哪些是SD生成的,哪些是DMD生成的”。答案在文末,反正博主第一次看没有全猜对。

在这里插入图片描述

二、方法细节

目标很明确:基于预训练扩散模型 μ b a s e μ_{base} μbase(教师模型),蒸馏训练出一个单步生成器 G θ G_θ Gθ
如何得到一个高质量的 G θ G_θ Gθ是难题,该论文的整体方案框架如下图所示,

在这里插入图片描述
这里可以拆分成两部分来看,

  • 先看左边红框部分,主体为单步生成器 G θ G_θ Gθ。其中paired dataset是由教师模型离线生成的成对数据(噪声&生成结果)。每次训练时,会跑两次 G θ G_θ Gθ,第一次是随机噪声输入,第二次是paired dataset的噪声输入。前者用于计算分布匹配损失,后者用于计算回归损失。
  • 右边蓝框部分,主体为real教师模型和fake教师的模型,前者固定,后者可调(通过diffusion loss微调),两者联合起来计算分布匹配损失,用于更新 G θ G_θ Gθ的参数。

我们先来看下重点部分,Distribution Matching Loss。

2.1 分布匹配损失

论文的思路是在生成器的输出分布上与教师模型去对齐,因此可以用KL散度作为分布匹配的目标函数来约束,公式如下:

D K L ( p fake ∥ p real ) = E x ∼ p fake ( log ⁡ ( p fake ( x ) p real ( x ) ) ) = E z ∼ N ( 0 , I ) x = G θ ( z ) ( − ( log ⁡ p real ( x ) − log ⁡ p fake ( x ) ) ) \begin{aligned} D_{KL} \left( p_{\text{fake}} \parallel p_{\text{real}} \right) &= \mathbb{E}_{x \sim p_{\text{fake}}} \left( \log \left( \frac{p_{\text{fake}}(x)}{p_{\text{real}}(x)} \right) \right) \\ &= \mathbb{E}_{\substack{z \sim \mathcal{N}(0, I) \\ x = G_\theta(z)}} \left( -(\log p_{\text{real}}(x) - \log p_{\text{fake}}(x)) \right) \end{aligned} DKL(pfakepreal)=Expfake(log(preal(x)pfake(x)))=EzN(0,I)x=Gθ(z)((logpreal(x)logpfake(x)))

由于直接计算概率分布难度大,因此转为计算梯度,变为下式(推导公式可见原文附录F):
∇ θ D K L = E s u b a r r a y z ∼ N ( 0 , I ) x = G θ ( z ) [ − ( s real ( x ) − s fake ( x ) ) d G d θ ] \nabla_\theta D_{KL} = \mathbb{E}_{subarray{z \sim \mathcal{N}(0, I) \\ x = G_\theta(z)}} \left[ -(s_{\text{real}}(x) - s_{\text{fake}}(x)) \frac{dG}{d\theta} \right] θDKL=EsubarrayzN(0,I)x=Gθ(z)[(sreal(x)sfake(x))dθdG]

其中

  • s real ( x ) = ∇ x log ⁡ p real ( x ) s_{\text{real}}(x) = \nabla_x \log p_{\text{real}}(x) sreal(x)=xlogpreal(x):真实分布的分数函数(指向真实数据密度更高的方向)
  • s fake ( x ) = ∇ x log ⁡ p fake ( x ) s_{\text{fake}}(x) = \nabla_x \log p_{\text{fake}}(x) sfake(x)=xlogpfake(x):生成分布的分数函数(指向生成数据密度更高的方向)

s real s_{\text{real}} sreal将x移向 p real p_{\text{real}} preal,而 − s fake -s_{\text{fake}} sfake将他们分开。论文中采用了一对扩散模型,分别计算 s real s_{\text{real}} sreal s fake s_{\text{fake}} sfake,这俩代表了梯度场。

然而计算梯度也是有难度的,主要有两个问题

  • 结果为低概率样本时,概率密度很小,分数易发散。特别是初始化的real分布和fake分布无重叠时,会存在一方的概率接近0,分数可能消失,因为log(0)是无定义的。
  • 用扩散模型来估计分数,只能学习噪声扰动后分布的分数。当噪声很低的时候,学习意义就不大了。

Score-SDE(参考文献)给了解决方案。通过给生成器的输出添加不同标准差的随机高斯噪声扰动,使得real和fake的分布有重叠,那么分布匹配的目标才是明确的。下图就比较直观地给出了加噪的作用。

在这里插入图片描述
加噪的方式正好与扩散过程一致,因此两个分数可以用扩散模型来计算,具体公式如下,
s real ( x t , t ) = − x t − α t μ base ( x t , t ) σ t 2 s_{\text{real}}(x_t, t) = -\frac{x_t - \alpha_t \mu_{\text{base}}(x_t, t)}{\sigma_t^2} sreal(xt,t)=σt2xtαtμbase(xt,t) s fake ( x t , t ) = − x t − α t μ fake ϕ ( x t , t ) σ t 2 s_{\text{fake}}(x_t, t) = -\frac{x_t - \alpha_t \mu_{\text{fake}}^\phi(x_t, t)}{\sigma_t^2} sfake(xt,t)=σt2xtαtμfakeϕ(xt,t)

其中 t t t为扩散时间步,随机选取, x t x_t xt为生成器加噪后的结果, α t \alpha_t αt σ t \sigma_t σt为扩散参数。

其中可调的fake模型,通过diffusion loss来进行更新,
L denoise ϕ = ∥ μ fake ϕ ( x t , t ) − x 0 ∥ 2 2 \mathcal{L}_{\text{denoise}}^\phi = \left\| \mu_{\text{fake}}^\phi(x_t, t) - x_0 \right\|_2^2 Ldenoiseϕ= μfakeϕ(xt,t)x0 22

在实际计算分布匹配梯度时,作者又加入了一个小巧思,设计了权重因子 w t w_t wt来归一化不同噪声水平下的梯度幅值,使得公式变为,
w t = σ t 2 α t C S ∥ μ base ( x t , t ) − x ∥ 1 w_t = \frac{\sigma_t^2}{\alpha_t} \frac{CS}{\|\mu_{\text{base}}(x_t, t) - x\|_1} wt=αtσt2μbase(xt,t)x1CS ∇ θ D K L ≃ E z , t , x , x t [ w t α t ( s fake ( x t , t ) − s real ( x t , t ) ) d G d θ ] \nabla_\theta D_{KL} \simeq \mathbb{E}_{z, t, x, x_t} \left[ w_t \alpha_t \left( s_{\text{fake}}(x_t, t) - s_{\text{real}}(x_t, t) \right) \frac{dG}{d\theta} \right] θDKLEz,t,x,xt[wtαt(sfake(xt,t)sreal(xt,t))dθdG]

S为空间位置数,C为通道数。原作者认为不同噪声水平下,扩散模型预测的误差量级差异较大,直接使用未加权的梯度可能导致某些 t t t 的更新主导训练,而其他 t t t 的贡献被忽略。权重 w t ​ w_t​ wt 通过归一化梯度幅值,确保所有噪声水平的更新均衡。误差越大,权重越小,从而抑制高误差区域的梯度,类似于自适应学习率。

在附录中给出分布匹配损失的算法说明,还是比较直观的,如下图
在这里插入图片描述
其中权重因子公式稍有偏差,但不影响其作用。最后一步,则用 ( x − g r a d ) (x-grad) xgrad作为目标,MSE 作为损失。是把“做一步梯度下降”这个优化步骤巧妙地变成了监督信号,可以用普通 SGD/Adam 继续训练,而不必手写更新规则。 而 stopgrad保证了只学习如何预测,却不让预测目标本身继续产生梯度, 从而把优化步骤无缝嵌入到端到端训练里,让网络能够“预测一次梯度下降之后的位置”。

2.2 回归损失

分布匹配损失对于受噪声影响大的生成样本是有明显作用的,但小噪声样本 s real s_{\text{real}} sreal会不太可靠,优化容易受到模式崩溃/下降的影响,其中假分布为模式的子集分配了更高的总体密度。为了避免此问题,采用了额外的回归损失来确保所有的模式都被保留,稳定训练和缓解模式崩溃。下图非常直观地体现了两个损失的作用。

在这里插入图片描述

  • 只用real分数,容易模式崩溃,直接聚集在一个小分布范围内。
  • 加上fake分数,能够较好地学习到一个模式。
  • 加上回归损失,能够学习到多个模式。

那么回归损失采用了LPIPS loss。 即衡量在相同噪声输入下,生成器与预训练扩散模型输出之间的感知距离。
L reg = E ( z , y ) ∼ D [ ℓ ( G θ ( z ) , y ) ] \mathcal{L}_{\text{reg}} = \mathbb{E}_{(z, y) \sim \mathcal{D}} \left[ \ell(G_\theta(z), y) \right] Lreg=E(z,y)D[(Gθ(z),y)]

最终过程就是

  • fake模型用 L denoise ϕ \mathcal{L}_{\text{denoise}}^\phi Ldenoiseϕ训练
  • 生成器模型用 ∇ θ D K L \nabla_\theta D_{KL} θDKL+ L reg \mathcal{L}_{\text{reg}} Lreg训练

除此之外,还使用了CFG(classifier-free guidance)策略,用于提升生成图像质量,这里的引导尺寸是固定的。

三、实验论证

论证方面主要用了FID和CLIP Score这两个评价指标。

  • 基准测试

    • ImageNet-64×64:DMD FID 2.62,比Consistency Model提升2.4倍,速度提升512倍。
    • 零样本文生图测试:在加速型diffusion方案里效果最佳,FID 11.49,接近Stable Diffusion1.5(FID 8.78),速度提升接近30倍(。
      在这里插入图片描述
      在这里插入图片描述
  • 消融实验

    移除分布匹配损失会导致图像结构失真,且FID上升。
    在这里插入图片描述
    移除回归损失会导致模式崩溃,多样性降低,生成了类似的小车。
    在这里插入图片描述

采用自适应权重 w t w_t wt,FID提升0.9。
在这里插入图片描述

感知效果上,DMD能够产生与Stable Diffusion相媲美的高质量图像,在相同速度下,效果明显优于其他单步Diff或者少步Diff。从下图上看确实在纹理细节上的表现还是不错的。
在这里插入图片描述

图1中DMD的结果 (left to right): bottom, top, bottom, bottom, top

四、总结和思考

DMD首次将分布匹配思想引入扩散蒸馏,为实时图像生成提供了高效解决方案,有一定通用性,适用于EDM、Stable Diffusion等主流扩散模型。其价值在于拉近了单步Diff与真实数据的分布。

但也有一定局限性,与千步采样相比仍存在细微质量差距;训练内存消耗大(需同时更新生成器和fake模型)。这点可以采用LoRA技术来规避。后续借鉴DMD的几个技术方案也都采用了LoRA,如OSEDiff


感谢阅读,欢迎留言或私信,一起探讨和交流。
如果对你有帮助的话,也希望可以给博主点一个关注,感谢。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值