DDPM¶
约 2473 个字 预计阅读时间 8 分钟
关于 DDPM 的三种理解方式和一些直观的解释
认识 DDPM¶
首先,认识 DDPM 最好的方式就是先去读一读大佬对于直观理解 DDPM 的讲解博客,这里引用【苏剑林](https://kexue.fm/archives/9119) 大佬的博客,建议大家先读完第一篇扩散模型漫谈,明白 DDPM 是做什么的,再回到本文。
在高维空间中游走¶
DDPM,按照原文的介绍,不只可以用于处理图像信息,但是原文的分析主要是从图像来讲的(是吧),所以我也从图像分析的角度来说一下吧。
什么是图像?按照最简单的角度理解,图像就是像素值矩阵,我们只要把一幅图从左上角到右下角的所有像素值(可能分 RGB 三个矩阵)都存储起来,就记录了这个个图像的所有信息。
如果我们再把这个图像从矩阵拼接成一个向量,图像就成了一个很高维度的空间中的一个点(比如 1024*1024 维),在这个空间中,有些点是有意义的图像(因为我们总能拿来一个有意义的图像,把图像映射到点上),大部分却是无意义的噪声图(想象一下电视花屏,我们随机在空间里取一个点大概就是这样的),DDPM 就是试图从这个空间的随机一个点出发,希望找到一个有意义的点,也就是在高维空间中游走。
图像的概率分布¶
我们很好理解作为一个点的图像,因为我们把这个点坐标拿出来,还原成矩阵,就可以还原到图像了,那么什么是图像的概率分布?
想象一个坐标为(1,2)的点,和一个均值为(1,2)的二维正态分布,这其实就是二者的区别,一个点可以准确还原到一幅图像,但概率分布并不准确,他只是告诉我们这幅图是(1,2)的概率最高,它也可能是其他图像,它是其他图像的概率可以通过概率分布算出,比如我们可以确定他是(1,3)的概率是多少。
那么,给定一个图像概率分布,我们如何获取图像呢?我们可以直接取均值的到图像,那就是(1,2),这么做的道理是:它是这幅图的概率最高,我们也可以按照概率去一个点,可能是(1,2),也可能是(2,4),但(1,2)的概率最高,一些很离谱的点(偏离均值很远)的概率就很低。
回到 DDPM 的向前过程,我们的起点是空间中的一个点,比如说(1,2),我们走了一步之后,就不再得到一个点,而是得到一个概率分布,比如说可以近似看成均值为(1,2)的概率分布,可以想像我们从空间中的一个点出发,走了一步后,可能在以其为圆心的任何一个地方上,不同的距离,概率都不同(可以类比电子云).我们再走第二步,距离就会更远一些,概率分布更分散一些。
按照原文走出一千步,我们的图像就会变成一个以原图为圆心的弥散的广阔的类似球的概率分布内了.因其广阔,我们大概最终会停在某一个无意义的点上.注意,这里我说的停在一个噪点上,是指从最终的分布中取均值或者随机采样得到一个点,以从分布跌回确定值。
原文中给出了每步怎么走的公式,其具有的良好性质就是:我们可以直接从给订的起点,\(\alpha\),\(\beta\)中直接推出第 N 步我们的分布,可以看到,1k 步后\(\bar \alpha\)接近于林,我们的分布取均值也是取不到原图的了,均值已经在 0 附近.总体分布接近高斯分布。
回到起点¶
从有意义的点走出 1000 步,到达的终点,近乎服从高斯分布,那么我们直接从高斯分布采样,就能得到近似从某个/任何一个有意义的点出走所达到的终点了,现在我们想回去,怎么回?
到这里我们就可以解答原模型中一个可能会令人困惑的点了,假如我们走了 500 步,到了一个点,我们走下一步,就是走到 500 为圆心的一个概率分布上,如果我们训练一个神经网络来拟合的话,不就是在拟合高斯吗,高斯又何须拟合呢?
问题就在这里了,现在我们随机采了一个点,我们说是终点,我们要走回去,当然不能以每一步都是高斯走,那最终走到的还是无意义的点,我们希望知道,我现在在这里,那我上一步最有可能是从哪里来的?我要回到那个地方去。
想象一下,如果我们采样的点在起点那个圆的上方,那么我们希望的回答是:我们最有可能从下面来,要回去的话,往下走.如果我们在下面,我们最有可能从上面来,往上走.所以,我们最有可能从哪里来会随着我现在在哪里而变化,所以,神经网络预测的时候需要我们现在的位置(\(x_t\))和时间,来拟合我们最有可能来的位置。
于是我们开始训练,第一次训练可能网络一步步往上走,使得模型倾向于在这些路径点上输出:嘿,我是从下面来的,往下走就可以回去.第二次,我们可能往左走,同样让神经网络在这条路径上懂得如何回归.最终,训练完毕后,神经网络就懂得如何从任意一个点走回去了。
所以我们要设计合适的损失函数让训练过程可以更好的领路回去,这就是文章中数学部分分析的工作了。
神奇的语意融合¶
我也不知道在这里适用于以融合对不对,但事实就是,输入的起点不唯一,我们最后走回来的话,不仅可以走回起点,而且可以走回起点之间的有意义的点,比如输入满月和半月,可以生成 ¾ 月,这两幅图在空间中的不同点上,相聚有一段距离,他们中间的点大部分都是无意义的噪点,但是却能被成功走到,实在是神奇。
进一步的学习¶
以上是对 DDPM 的一个理解的角度,大佬的博客还给出了对于 DDPM 的两种不同的理解的视角
我们先读第一篇,这篇可以跟读.首先文章说明了把问题分解成多步可以大幅度提高拟合能力,这是很有教义的观察。
联合散度的最消化目标也很好理解,首先要知道 KL 散度是衡量两个分布之间相近程度的非负函数,KL 散度为零当且仅当两个分布完全相同.我们希望走出来的路跟走回去的路是尽可能相近的,所以要最小化 KL 散度。
随后就是数学处理,去掉最小化目标中的无关常数项.并化简损失函数使之可用。
超参设置中,要求边缘分布相等,其实就是要求终点的分布要近似于高斯,那就是\(\bar \alpha\)近乎 0 就可以了。
有了上述的铺垫,我认为这篇文章的脉络是相当清晰的,我们进入下一篇。
第三篇的前面一大部分都跟第一篇完全一样,但到了贝叶斯的部分,开始有了不一样,我们要知道,在 DDPM 中,\(p(x_{t-1}|x_t) = p(x_{t-1}|x_t,x_0)\),因为这是一个马尔可夫过程,也就是每一个状态仅与前一状态有关,而与更前的状态无关,我们也可以理解成,我们只需要知道 500 步到哪里了,对于给定的参数,我们就能有 501 步的分布,第 0-499 步,给不给都无所谓,不改变分布,但是引入了起点后,我们是可以从起点直接推\(x_t,x_{t-1}\)的,这就让我们的贝叶斯公式中许多项成为已知。
随后我们就可以一路推导出损失函数.注意,在这个推导流程下\(x_0\)给定,我们能直接得到往回走的方向,所以我们唯一需要训练的是模型猜原点的能力,而这一损失函数,正如推导出来的那样,是跟前面几个视角一样的。
到了预估修正一节,我们的行为有可以被解释为,先猜出一个起点,往回走,并通过损失函数修正起点,然后走下一次,这样我们起点猜的越来越准,回到起点的能力也越来越强了。
此处的第三种理解是与后来的 DDIM 有很强的关联性的,比如说我们利用了贝叶斯公式之后其实只需要各个时间点对于\(x_0\)的分布而不需要其对于前一个时间点的分布了,但是要注意,我们从后往前推仍然需要一步一步走,如果我们一步走完的话,就会退化回 VAE 了。