Paper002: Dataset Condensation with Distribution Matching
核心集和以往数据压缩算法存在的问题
核心集选择(Coreset selection)算法基于一定的启发式准则(Heuristic criteria),如核心集到原数据集中心的距离最小、核心集包含的样本要尽可能地多样等,从原训练集中选择一个更小的核心训练集。这种方法选择了更具有代表性的数据来进行训练,因而降低了训练的开销,但是它也存在如下缺点:
- 核心集选择依赖的几乎都是贪心算法,这可能无法达到全局最优;
- 有效性极度依赖原训练集的有效性。
数据蒸馏(Dataset distillation,或数据缩合,Dataset condensation),一定程度上解决了核心集选择的问题,但无论是最原始的数据蒸馏算法,还是前面提到的基于梯度匹配的数据缩合算法都存在一定的缺陷:
- 模型参数$\theta ^{\mathcal{S}}$和合成数据集$\mathcal{S}$的双重梯度下降十分消耗算力;
- 对合成数据集$\mathcal{S}$进行梯度下降时,要计算二阶混合偏导;
- 用于合成数据的网络的超参数不好调节(e.g. $\theta ^{\mathcal{S}}$和$\mathcal{S}$的梯度下降次数$\varsigma ^{\mathcal{\theta}}$和$\varsigma ^{\mathcal{S}}$)。
这些缺陷限制了其在大数据集上的应用。
分布匹配
分布匹配(Distribution matching)是这篇文章使用的用于数据缩合的方法。实际上,它更像是一种可以学习的核心集:
- 相比于纯核心集,分布匹配学习到的$\mathcal{S}$中的数据不一定存在于原训练集中;
- 相比于基于梯度匹配的数据缩合,分布匹配学习到的$\mathcal{S}$中的数据分布更加接近原训练集,且缩合的速度更快;
- 可以将分布匹配理解为用学习的方法去获得一个更小的、与原训练集同分布的合成集。
要让学习到的$\mathcal{S}$与原训练集$\mathcal{T}$近似于同分布,要点在于有个确切的方法能衡量两个分布的近似情况。文中采用的是常用的最大均值差异法(Maximum mean discrepancy):
$$
\begin{align*}
\sup _{\psi _{\theta}\in\mathcal{H},||\psi _{\theta}||_\mathcal{H}\le1}(\text{E}[\psi _{\theta}(\mathcal{T})]&-\text{E}[\psi _{\theta}(\mathcal{S})])\tag{1}\\
\min \text{E} _{\theta\sim P _{\theta}} ||\frac{1}{|\mathcal{T}|}\sum\limits _{i=1} ^{|\mathcal{T}|}\psi _{\theta}(x_i)&-\frac{1}{|\mathcal{S}|}\sum\limits _{j=1} ^{|\mathcal{S}|}\psi _{\theta}(\mathcal{s}_j)|| ^2\tag{2}
\end{align*}
$$
其中,$\mathcal{T}$是原训练集,$\mathcal{S}$是合成集;$x$和$\mathcal{s}$分别是原训练集和合成集单一样本的特征;$\psi_\theta$是一个带参数的函数(实际上是一个神经网络,$\theta$为其模型参数,服从分布$P_\theta$;也是再生希尔伯特空间中的一个向量),该函数将样本的特征映射到更低的维度,便于处理(实际上是求一个高阶矩)。
文中用的实验样例是图片分类,因此样本的特征就是图片。作者还在式$(2)$的基础上考虑了数据增强(Data augmentation)以更好地适应训练图片数据的实际情况,因此式$(2)$又可变为:
$$
\min _{\omega\sim\Omega} \text{E} _{\theta\sim P _{\theta}} ||\frac{1}{|\mathcal{T}|}\sum\limits _{i=1} ^{|\mathcal{T}|}\psi _{\theta}(\mathcal{A}(x_i,\omega))-\frac{1}{|\mathcal{S}|}\sum\limits _{j=1} ^{|\mathcal{S}|}\psi _{\theta}(\mathcal{A}(\mathcal{s}_j,\omega))|| ^2\tag{3}
$$
其中,$\Omega$是数据增强参数空间,$\mathcal{A}(x,\omega)$则是相应的增强操作,其对$x$和$\mathcal{s}$是一样的。
重要代码
以下只是对文中代码的简单实现(并不能运行),具体代码详见文章提供的开源部分Dataset Condensation with Distribution Matching。
1 | # 参数定义 |
只有学习率一个可调参数。
实验设计
后续的内容几乎与Dataset Condensation一致,只不过用到了更大的数据集。实验结果也表明,相比于DC得到的合成集,DM得到的合成集的分布更加均匀、更加接近原训练集。