核心集和以往数据压缩算法存在的问题
核心集选择(Coreset selection)算法基于一定的启发式准则(Heuristic criteria),如核心集到原数据集中心的距离最小、核心集包含的样本要尽可能地多样等,从原训练集中选择一个更小的核心训练集。这种方法选择了更具有代表性的数据来进行训练,因而降低了训练的开销,但是它也存在如下缺点:
- 核心集选择依赖的几乎都是贪心算法,这可能无法达到全局最优;
- 有效性极度依赖原训练集的有效性。
数据蒸馏(Dataset distillation,或数据缩合,Dataset condensation),一定程度上解决了核心集选择的问题,但无论是最原始的数据蒸馏算法,还是前面提到的基于梯度匹配的数据缩合算法都存在一定的缺陷:
- 模型参数和合成数据集的双重梯度下降十分消耗算力;
- 对合成数据集进行梯度下降时,要计算二阶混合偏导;
- 用于合成数据的网络的超参数不好调节(e.g. 和的梯度下降次数和)。
这些缺陷限制了其在大数据集上的应用。
分布匹配
分布匹配(Distribution matching)是这篇文章使用的用于数据缩合的方法。实际上,它更像是一种可以学习的核心集:
- 相比于纯核心集,分布匹配学习到的中的数据不一定存在于原训练集中;
- 相比于基于梯度匹配的数据缩合,分布匹配学习到的中的数据分布更加接近原训练集,且缩合的速度更快;
- 可以将分布匹配理解为用学习的方法去获得一个更小的、与原训练集同分布的合成集。

Fig. 1. Dataset condensation with distribution matching
要让学习到的与原训练集近似于同分布,要点在于有个确切的方法能衡量两个分布的近似情况。文中采用的是常用的最大均值差异法(Maximum mean discrepancy):
其中,是原训练集,是合成集;和分别是原训练集和合成集单一样本的特征;是一个带参数的函数(实际上是一个神经网络,为其模型参数,服从分布;也是再生希尔伯特空间中的一个向量),该函数将样本的特征映射到更低的维度,便于处理(实际上是求一个高阶矩)。
文中用的实验样例是图片分类,因此样本的特征就是图片。作者还在式的基础上考虑了数据增强(Data augmentation)以更好地适应训练图片数据的实际情况,因此式又可变为:
其中,是数据增强参数空间,则是相应的增强操作,其对和是一样的。
重要代码
以下只是对文中代码的简单实现(并不能运行),具体代码详见文章提供的开源部分Dataset Condensation with Distribution Matching。

Fig. 2. Pseudocode in paper
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
| num_exp = 5 num_classes = 10 ipc = 10 channel = 1 im_size = [28, 28] device = torch.device('cuda') lr_img = 1.0 K = 1000 batch_real = 256 batch_train = 256
for exp in range(num_exp): '''对原训练集和合成集的初始化'''
image_syn = torch.randn(size=(num_classes*ipc, channel, im_size[0], im_size[1]), dtype=torch.float32, requires_grad=True, device=device) label_syn = torch.tensor([np.ones(ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=device).view(-1)
'''优化算法、损失函数''' optimizer_image = torch.optim.SGD([image_syn, ], lr=lr_img, momentum=0.5) optimizer_image.zero_grad()
'''for k = 0, ... ,K - 1''' for k in range(K+1): '''sample \theta''' net = get_network().to(device) net.train() for param in list(net.parameters()): param.requires_grad = False embed = net.embed
loss = torch.tensor(0.0).to(device)
'''sample mini-batch pairs for each class''' for c in range(num_classes): img_real = get_images(c, batch_real) img_syn = image_syn[c*ipc:(c+1)*ipc].reshape(ipc, channel, im_size[0], im_size[1])
'''数据过网络''' output_real = embed(img_real).detach() output_syn = embed(img_syn) '''MMD''' loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)
'''更新合成图片''' optimizer_img.zero_grad() loss.backward() optimizer_img.step()
|
只有学习率一个可调参数。
实验设计
后续的内容几乎与Dataset Condensation一致,只不过用到了更大的数据集。实验结果也表明,相比于DC得到的合成集,DM得到的合成集的分布更加均匀、更加接近原训练集。