Paper002: Dataset Condensation with Distribution Matching

核心集和以往数据压缩算法存在的问题

核心集选择(Coreset selection)算法基于一定的启发式准则(Heuristic criteria),如核心集到原数据集中心的距离最小、核心集包含的样本要尽可能地多样等,从原训练集中选择一个更小的核心训练集。这种方法选择了更具有代表性的数据来进行训练,因而降低了训练的开销,但是它也存在如下缺点:

  1. 核心集选择依赖的几乎都是贪心算法,这可能无法达到全局最优;
  2. 有效性极度依赖原训练集的有效性。

数据蒸馏(Dataset distillation,或数据缩合,Dataset condensation),一定程度上解决了核心集选择的问题,但无论是最原始的数据蒸馏算法,还是前面提到的基于梯度匹配的数据缩合算法都存在一定的缺陷:

  1. 模型参数$\theta ^{\mathcal{S}}$和合成数据集$\mathcal{S}$的双重梯度下降十分消耗算力;
  2. 对合成数据集$\mathcal{S}$进行梯度下降时,要计算二阶混合偏导;
  3. 用于合成数据的网络的超参数不好调节(e.g. $\theta ^{\mathcal{S}}$和$\mathcal{S}$的梯度下降次数$\varsigma ^{\mathcal{\theta}}$和$\varsigma ^{\mathcal{S}}$)。

这些缺陷限制了其在大数据集上的应用。

分布匹配

分布匹配(Distribution matching)是这篇文章使用的用于数据缩合的方法。实际上,它更像是一种可以学习的核心集:

  1. 相比于纯核心集,分布匹配学习到的$\mathcal{S}$中的数据不一定存在于原训练集中;
  2. 相比于基于梯度匹配的数据缩合,分布匹配学习到的$\mathcal{S}$中的数据分布更加接近原训练集,且缩合的速度更快;
  3. 可以将分布匹配理解为用学习的方法去获得一个更小的、与原训练集同分布的合成集。

1

Fig. 1. Dataset condensation with distribution matching

要让学习到的$\mathcal{S}$与原训练集$\mathcal{T}$近似于同分布,要点在于有个确切的方法能衡量两个分布的近似情况。文中采用的是常用的最大均值差异法(Maximum mean discrepancy)

$$
\begin{align*}
\sup _{\psi _{\theta}\in\mathcal{H},||\psi _{\theta}||_\mathcal{H}\le1}(\text{E}[\psi _{\theta}(\mathcal{T})]&-\text{E}[\psi _{\theta}(\mathcal{S})])\tag{1}\\
\min \text{E} _{\theta\sim P _{\theta}} ||\frac{1}{|\mathcal{T}|}\sum\limits _{i=1} ^{|\mathcal{T}|}\psi _{\theta}(x_i)&-\frac{1}{|\mathcal{S}|}\sum\limits _{j=1} ^{|\mathcal{S}|}\psi _{\theta}(\mathcal{s}_j)|| ^2\tag{2}
\end{align*}
$$

其中,$\mathcal{T}$是原训练集,$\mathcal{S}$是合成集;$x$和$\mathcal{s}$分别是原训练集和合成集单一样本的特征;$\psi_\theta$是一个带参数的函数(实际上是一个神经网络,$\theta$为其模型参数,服从分布$P_\theta$;也是再生希尔伯特空间中的一个向量),该函数将样本的特征映射到更低的维度,便于处理(实际上是求一个高阶矩)。

文中用的实验样例是图片分类,因此样本的特征就是图片。作者还在式$(2)$的基础上考虑了数据增强(Data augmentation)以更好地适应训练图片数据的实际情况,因此式$(2)$又可变为:

$$
\min _{\omega\sim\Omega} \text{E} _{\theta\sim P _{\theta}} ||\frac{1}{|\mathcal{T}|}\sum\limits _{i=1} ^{|\mathcal{T}|}\psi _{\theta}(\mathcal{A}(x_i,\omega))-\frac{1}{|\mathcal{S}|}\sum\limits _{j=1} ^{|\mathcal{S}|}\psi _{\theta}(\mathcal{A}(\mathcal{s}_j,\omega))|| ^2\tag{3}
$$

其中,$\Omega$是数据增强参数空间,$\mathcal{A}(x,\omega)$则是相应的增强操作,其对$x$和$\mathcal{s}$是一样的。

重要代码

以下只是对文中代码的简单实现(并不能运行),具体代码详见文章提供的开源部分Dataset Condensation with Distribution Matching

2

Fig. 2. Pseudocode in paper

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# 参数定义
num_exp = 5 # 重复实验次数
num_classes = 10 # 实验为图片分类,使用数据集MNIST,共10个类
ipc = 10 # 为每个类训练十张合成图片
channel = 1 # 对黑白图片,输入通道为1
im_size = [28, 28] # MNIST图片尺寸为28x28
device = torch.device('cuda') # 用GPU训练
lr_img = 1.0 # 合成图片学习率
K = 1000 # 最外层循环,使得合成集S能适应模型参数的不同初始化方式,实际上作用同epoch
batch_real = 256 # 原训练集的批量大小
batch_train = 256 # 训练模型参数时的批量大小

# 重复实验
for exp in range(num_exp):
'''对原训练集和合成集的初始化'''
# 一些获取训练集T的操作,最终得到的特征和标签依次为(MNIST):
# images_all: Tensor,shape (6000, 1, 28, 28),其中6000为样本数,1为输入通道数(因为是黑白图片),28x28是图片像素
# labels_all: Tensor,shape (6000),6000为样本数,类别为10,包含0~9的手写数字

# 可训练的合成图片,shape (10, 1, 28, 28)
image_syn = torch.randn(size=(num_classes*ipc, channel, im_size[0], im_size[1]), dtype=torch.float32, requires_grad=True, device=device)
# 不可训练的合成图片标签,shape (10)
label_syn = torch.tensor([np.ones(ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, device=device).view(-1)

'''优化算法、损失函数'''
# 对合成图片的优化算法为使用动量法的SGD
optimizer_image = torch.optim.SGD([image_syn, ], lr=lr_img, momentum=0.5)
optimizer_image.zero_grad() # 梯度清零
# 此处先不定义损失函数,因为用的是MMD

'''for k = 0, ... ,K - 1'''
for k in range(K+1):
# 源代码在此处存在在特定结点分析训练效果的代码,此处省略

'''sample \theta'''
# 获取某个指定的网络同时随机初始化模型参数,原文存在参数,此处省略,仅作为示意
net = get_network().to(device)
net.train() # 使网络进入训练模式
# PyTorch生成网络的模型参数默认是记录梯度的,但在DM中,网络只是映射函数,因此不需要训练,也就无需梯度
for param in list(net.parameters()):
param.requires_grad = False
# net.embed是作者定义的网络类中的函数,实际上是个简化版的forward,一般是去掉最后一层的全连接层
embed = net.embed

# 将所有类别的梯度差异相加再计算梯度
loss = torch.tensor(0.0).to(device)

# 原文还分有BatchNorm和无BatchNorm的情况,此处只考虑无BatchNorm
'''sample mini-batch pairs for each class'''
# 对不同的类别单独训练
for c in range(num_classes):
# 从原训练集的c类中随机获取batch_real张图片,注意,get_images是原文定义的内联函数,因此可以直接访问images_all,而此时images_all已经被放到GPU中了
img_real = get_images(c, batch_real)
# 因为网络只是映射特征,故不需要labels
# 同样地获取合成集中对应类别的数据
img_syn = image_syn[c*ipc:(c+1)*ipc].reshape(ipc, channel, im_size[0], im_size[1])

# 原文此处有数据增强,此处省略

'''数据过网络'''
# 原训练集
output_real = embed(img_real).detach() # 以防万一,不要让其进入计算图中
# 合成集
output_syn = embed(img_syn)
'''MMD'''
# 计算MMD,实际值为矩阵每列平均值(行一般代表样本数)差的平方和
loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)

'''更新合成图片'''
optimizer_img.zero_grad() # 合成图片矩阵的梯度清零
loss.backward() # 计算对合成图片矩阵的二阶混合偏导
optimizer_img.step() # 梯度下降

只有学习率一个可调参数。

实验设计

后续的内容几乎与Dataset Condensation一致,只不过用到了更大的数据集。实验结果也表明,相比于DC得到的合成集,DM得到的合成集的分布更加均匀、更加接近原训练集。