Link Prediction

链接预测的基本方法

当前,在图机器学习领域,链接预测的主流方法有三种:

  1. 启发式方法(Heuristic methods)
    这种方法认为存在链接关系的结点的特征存在某种共同的特性,这种特性使得它们间有更大的相似度。启发式方法通过定义一种映射来衡量这样的相似度,如结点的共同邻居(Common neighor)、Katz Index等,但是这样的映射不一定对所有的图都有效。
  2. 结点嵌入(Node embedding)
    这种方法与结点分类(Node classification)中的方法一致,即基于游走的方法学习结点的embedding,常见的方法有:DeepWalk、Node2vec等。这样的方法,没有直接将链接预测任务嵌入到有监督学习的流程中,并且无法较好的利用用户的节点属性,无法达到较好的预测精度。
  3. 图神经网络(GNN)
    基于基本的图神经网络架构,如GCN、GraphSAGE、GAT等,通过对邻居结点的聚合得到融合了图结构信息的结点表示,再以类似于启发式方法的方式求任意两结点的“相似度”,进而判断两结点间边的有无。最后的判断本质上是个逻辑回归,即二分类问题(有边为1,无边为0)。

基于PyG的GNN链接预测

PyG(PyTorch Geometric)是个建立在PyTorch基础上的图神经网络库,它为训练不同任务(结点分类、链接预测、图分类等)、不同架构(GCN、GraphSAGE、GAT等)的图神经网络提供了方便。使用PyG来完成链接预测,可以省去很多复杂的操作,特别是在数据预处理阶段。事实上,结点分类和链接预测最大的区别就在数据预处理阶段。

数据集的划分

前面提到过,链接预测本质上是一个二分类问题,只不过它二分类的单位不是一个结点,而是一对结点,分类的结果是这对结点间存在(1)或不存在(0)边。这样的不同决定了我们需要对数据集进行额外的处理,以获得训练样本,包括:正采样和负采样,前者采样存在边的结点对,后者采样不存在边的结点对。在结点分类用到的数据集的基础上,使用PyG提供的torch_geometric.transforms.RandomLinkSplit函数可以很方便地完成对训练集、验证集和测试集的采样。

torch_geometric.transforms.RandomLinkSplit(num_val, num_test,...),以下为几个常用参数的说明:

  1. num_val:验证集中边占所有边比例,默认为0.1。
  2. num_test:测试集中边占所有边的比例,默认为0.1。
  3. is_undirectedTrue则假定图是无向图,反之为有向图。
  4. add_negative_train_samples:是否为训练集添加负训练样本。一般设置为False(默认也是False),也就是使得训练集中不包含负样本,这样每一轮训练时在训练集中可以重新采样负样本进行训练,由此可以保证每一轮训练中采样得到的负样本都是不一样的,可以有效提高模型泛化能力。验证集和测试集则默认会自动完成负样本的采样。
  5. neg_sampling_ratio:采样中正负样本的比例,默认为1。即正负样本个数一致(对验证集、测试集和add_negative_train_samples设置为True后的训练集)。

更多资料见RandomLinkSplit

RandomLinkSplit的使用方法很简单,因为它被包含在transforms中,所以可以在读数据的时候作为一个参数传入读数据的函数中,以读取Cora数据集为例:

1
2
3
4
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
train_data, val_data, test_data = Planetoid(root='./data/Cora', name="Cora",
transform=T.RandomLinkSplit(num_val=0.1, num_test=0.1, is_undirected=True, add_negative_train_samples=False))[0]

而一般结点分类任务用到的数据集形式为:

1
data = Planetoid(root='./data/Cora', name="Cora")[0]

注意Planetoid是PyG定义的InMemoryDataset的子类,InMemoryDataset所返回的是数据集中所有的图结构,因此要用下标来读取指定的图。Planetoid能加载的数据集都只有一张图,所以用Planetoid()[0]即可。

我们不妨来看看这四种data的区别:

1
2
3
4
5
>>> print(train_data, "\n", val_data, "\n", test_data, "\n", data)
Data(x=[2708, 1433], edge_index=[2, 8448], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[4224], edge_label_index=[2, 4224])
Data(x=[2708, 1433], edge_index=[2, 8448], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[1054], edge_label_index=[2, 1054])
Data(x=[2708, 1433], edge_index=[2, 9502], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[1054], edge_label_index=[2, 1054])
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

不难看出,图都是同一张图,只不过链接预测用到的数据集中多了edge_labeledge_label_index这两个量:

  • edge_labelTensor,数据集的标签,即y,值为1或0,对于val_datatest_datatrain_data未进行负采样),既包含了正样本的标签,又包含了负样本的标签,因此对于上文的数据集,正样本数为527,负样本数也为527。
  • edge_label_indexTensor,数据集的样本,即X,但是,此处X的内容并不是边的特征,而是边的两个结点的编号,因而也不是用来训练的,其中edge_label_index[0]为起始结点的编号,edge_label_index[1]为终点结点的编号。

train_data的负采样可以使用PyG提供的torch_geometric.utils.negative_sampling函数,见torch_geometric.utils,此处不再赘述。要特别注意把负样本加入edge_labeledge_label_index中。

训练

链接预测图神经网络的训练与结点分类图神经网络的训练并没有本质的不同。若将Node embedding的过程称为编码过程,最后的链接预测/结点分类称为解码过程,那么两者的编码过程是一致的,即卷积和消息传递的过程是一致的。两者的不同在解码过程。对于结点分类任务,解码是对得到Node embedding做一次softmax,因此最后的结点的特征维度也被限定为了结点类别的数量;对于链接预测任务,解码是对edge_label_index所选定的边两端的结点计算相似度后求sigmoid,即:

1
2
3
4
def decoder(self, X, edge_label_index):  # X为编码后的结点特征
src_node = X[edge_label_index[0]] # 取出边起始结点的特征,shape (N x F)
end_node = X[edge_label_index[1]] # 取出边终点结点的特征,shape (N x F)
return (src_node * end_node).sum(dim=1) # 向量内积求相似度,shape (N),此处没有直接sigmoid,放在网络外面做也是一样的

精度分析

由于链接预测是一个二分类问题,因此精度分析采用AUC。

AUC的含义和计算方法见F1 score and ROC & AUC

结点分类和链接预测的测试集泄露问题

测试集泄露,即测试集中的样本特征或者标签在训练网络的过程中被使用了。对于GNN来说,测试集泄露是一个很普遍的问题,因为GNN的消息传递过程不可避免地要用到其他结点的信息,而这些结点就有可能包括测试集的结点。

结点分类

对于最初的GCN,其在训练模型参数时用到的是整个图的结点特征和邻接矩阵,因而不可避免地传递了测试集结点的特征信息。虽然可以把用来计算损失的训练集结点及其边、用于测试效果的测试集结点及其边从原图中单独拎出来形成两个互不相交的子图来从源头上避免数据泄露(即在训练过程中,只让消息在训练集结点间传递),但是这样做会导致精度下降,因为原图的整体结构被破坏了。

GraphSAGE的出现使得大家对GNN的消息传递有了全新的理解,即:只要把待预测结点周围结点的信息传递给待预测结点即可。在这样的视角下,原本耦合于全图的各个结点实际上变成了一个个独立的子图,子图的中心结点是待预测结点,而其他结点是要将信息传递给中心结点的辅助结点。于是,单一结点的结点分类成为了可能(原先必须要将整个图喂进模型中),只要为该结点随机取样一个子图即可。相应地,数据泄露的影响也降到最低(此时网络学习的是一种利用邻居结点信息获得中心结点Embedding的方法/模式,而不是学习怎样为每个结点生成一个Embedding)。这样的子图/采样的方法对于GCN、SGC、GAT等也是成立的,因为它们本质上都是消息传递型的GNN。

链接预测

类似地,对于链接预测,其在消息传递的过程中会用到结点间的连接信息,也就是结点间的边。但是,不同于结点分类的是,链接预测的对象是边,而边的泄露在消息传递的过程中是可以被规避掉的,因为被传递的消息是结点信息而不是边的信息。

具体来说,如不对前面提到的Cora数据集进行负采样,则其链接预测的训练集、验证集、测试集(比例8:1:1)和全集依次为:

1
2
3
4
5
>>> print(train_data, "\n", val_data, "\n", test_data, "\n", data)
Data(x=[2708, 1433], edge_index=[2, 8448], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[4224], edge_label_index=[2, 4224])
Data(x=[2708, 1433], edge_index=[2, 8448], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[527], edge_label_index=[2, 527])
Data(x=[2708, 1433], edge_index=[2, 9502], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_label=[527], edge_label_index=[2, 527])
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

事实上,上面四个数据集都代表着四张不同的图,其中前面三张图是最后一张(原图)的生成子图,即前面三张图包含原图的所有结点和部分边。这些信息体现在edge_index上。

  • 对于训练集,其不包含负样本的edge_label_indexedge_index的无向图形式(即在edge_index中,无向图的一条边被视为有向图的两条边,而edge_label_index只将其视为一条边,所以edge_label_index的边数是edge_index的一半)。在训练模型过程中,训练集的消息传递会用到所有的训练集边(edge_index),最终的预测也会用到所有的训练集边(edge_label_index & edge_label)+负采样边;
  • 对于验证集,其消息传递的过程中不能使用验证集中特有的边,因此消息传递用到的是训练集的边,这也是为什么验证集的edge_index和训练集的edge_index是一样的。最终的分类过程只会用到验证集特有的边(edge_label_index & edge_label)+负采样边;
  • 对于测试集,其消息传递的过程中也不能使用测试集中特有的边,但是可以使用验证集中特有的边,因此消息传递用到的是训练集和边+验证集特有的边(8448+527*2=9502)。最终的分类过程同样只会用到测试集特有的边+负采样边。

需要特别注意的是,虽然负采样的边不会参与消息传递的过程,但是训练集、验证集和测试集在分类阶段的负采样边同样不能重叠。这很容易实现,只要让负采样在正采样的边所特有的结点间进行即可(e.g. train_data的在8848条边的结点间进行,val_data的在527条边的结点间进行)。

综合来说,链接预测是可以保证数据完全不泄露的,只要保证训练集、验证集和测试集拥有label的边不同以及除训练集外的消息传递不用有label的边即可。

在更严格的算法中(e.g. SEAL),训练集的边还被进一步划分为training_supervision_edgestraining_message_edges,其中前者是训练集特有的带label的边,而后者是用于消息传递的边。而一般的算法(e.g. GAE),对这两者不做区分。

参考