Contrastive Learning

CV领域的对比学习发展路径

什么是对比学习?对比学习是属于无监督/自监督学习范式的。在监督学习的分类问题中,我们希望模型能够精确地预测输入属于的类别,而在对比学习中,模型不需要知道数据的真实标签,只要最终的输出能够把不同的类别区分开就好。因而,对比学习的模型就是一个特征提取器,其模型将输入的特征提取出来作为输出,使得在输出的特征空间中,相似的数据尽可能地相邻,而不相似的数据尽可能地远离,如下图所示

1

Fig. 1. Contrastive Learning

对比学习的典型范式是代理任务+目标函数

  • 代理任务:代理任务是一些不像分类、目标检测一样具有实际意义的应用场景,但我们假定该模型是为了解决这个代理问题而训练的,而实际上它只是用于生成自监督信号以更新特征提取器,从而能够让我们获得更好的预训练模型。在NLP中,BERT预训练中用到的填词等任务就可以被视为代理任务。在CV中,如下面会提到的九宫格图像相对位置预测、图片着色等都属于代理任务。不过,在CV的对比学习中,更常用的代理任务是个体判别(Instance discrimination),即将同类的个体与其他个体区分开来。> 更通俗地来说,代理任务是为了生成类似监督学习的“标签”,使得无监督学习也有比较的对象(像监督学习中的Ground Truth和prediction一样),有了比较对象,我们才能用合适的metric构建目标函数。
  • 目标函数:产生梯度。
    1. 生成式网络:用生成的图片与原图片做对比,可以是$L1$或$L2$ losses。
    2. 判别式网络:对图片本身做划分,如作九宫格划分,用一个格子预测另一个格子在其哪个方位,实际上转化为了一个交叉熵损失。
    3. 对比式:衡量被提取的数据特征间在特征空间的相似性,不同于前两种的是(特别是生成式),由于编码器是在不断更新的,被提取的数据特征也是在不断被更新的,因而对比的对象不像前两者是固定的。
    4. 对抗性:衡量概率分布的差异。(不太懂)

百花齐放

  1. InstDisc: Memory bank。字典内容一致性不好。
  2. InvaSpread: 端到端,两个编码器都梯度下降。字典大小受限。
  3. CPC: InfoNCE。
  4. CMC: 多视角。

CV双雄

  1. MoCoV1: 动态编码器、Memory bank变队列。
  2. SimCLRV1: 端到端。
  3. MoCoV2
  4. SimCLRV2
  5. SWaV

不用负样本

  1. BYOL: MSE LOSS,一个编码器预测另一个编码器。
  2. SimSiam

基于Transformer

  1. MoCoV3
  2. DINO

骨干网络由ResNet换为ViT。

1

Fig. 1.

MoCo

对比学习是一次字典查询的过程。

  1. 字典要大;

  2. 字典的内容连续性要好。

  3. 队列作为字典的数据结构:每个mini batch,老key出去,更新后的key作为new key进来。

  4. 动量编码器:$\theta _k = m*\theta _k+(1-m)*\theta _q$,保证字典中key的一致性。

$m$很大,文中取$0.99$或$0.999$。

NCE(Noise Contrastive Estimation)

当分类任务的类别很多时,交叉熵的计算时间是难以承受的,因为交叉熵的分母必须对样本在所有类别上出现的可能进行求和。对于Instance discrimination,每个样本就是一个类,在这种情况下,用交叉熵是不现实的。
NCE(Noise Contrastive Estimation)将多分类问题转化为了多个二分类问题,所有的样本都只有两类:来自data samples的正类和来自noise samples的负类。

InfoNCE

InfoNCE是对NCE的改进,它比NCE更加接近交叉熵。InfoNCE将正例视作一类,将单个的负例也视作一类。因此,对于$1$正例$K$负例的采样,总类别数是$K+1$。InfoNCE实际上就是类别数为$K+1$且带温度参数$\tau$的交叉熵。还有一点特殊的是,由于我们想匹配的只是$q$和正例$k _+$,所以$\mathcal{L} _q$的分子永远都只会是$\exp(q\cdot k _+ / \tau)$。

$$
\mathcal{L} _q=-\log \frac{\exp(q\cdot k _+ / \tau)}{\sum _ {i=0} ^K\exp(q\cdot k _i / \tau)}
$$

上式中,$k _0$即为$k _+$。$\mathcal{L} _q$很好地体现了我们的优化目标:**$q$和正例的相似性出现在分子,所以越大越好,相应地,分母上$q$与负例的相似性越小越好**。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# f_q, f_k: encoder networks for query and key
# queue: dictionary as a queue of K keys (CxK)
# m: momentum
# t: temperature

f_k.params = f_q.params # initialize
for x in loader: # load a minibatch x with N samples
x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version

q = f_q.forward(x_q) # queries: NxC 256x128
k = f_k.forward(x_k) # keys: NxC 256x128
k = k.detach() # no gradient to keys

# positive logits: Nx1 256x1
l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
# negative logits: NxK 256x65536
l_neg = mm(q.view(N,C), queue.view(C,K))

# logits: Nx(1+K) 256x65537
logits = cat([l_pos, l_neg], dim=1)

# contrastive loss, Eqn.(1)
labels = zeros(N) # positives are the 0-th
loss = CrossEntropyLoss(logits/t, labels)

# SGD update: query network
loss.backward()
update(f_q.params)

# momentum update: key network
f_k.params = m*f_k.params+(1-m)*f_q.params

# update dictionary
enqueue(queue, k) # enqueue the current minibatch
dequeue(queue) # dequeue the earliest minibatch