一、概述
大量现实世界的数据集以异质图(heterogeneous graphs )的形式存储,这促使在PyG(PyTorch Geometric )中为它们引入专门的功能。例如,推荐领域中的大多数图,如社交图,都是异质的,因为它们存储了关于不同类型实体及其不同类型关系的信息。本教程将介绍异质图如何映射到PyG,以及它们如何用作图神经网络(Graph Neural Network)模型的输入 。
异构图的节点和边附有不同类型的信息。因此,由于类型和维度的差异,单个节点或边特征张量无法容纳整个图的所有节点或边特征。相反,需要分别为节点和边指定一组类型,每种类型有其自身的数据张量。由于数据结构不同,消息传递公式也会相应改变,从而能够根据节点或边类型计算消息和更新函数 。
二、示例
所给的异构图有 1,939,743 个节点,分为四种节点类型:作者、论文、机构和研究领域。它还有 21,111,007 条边,这些边也属于四种类型之一:
writes:一位作者撰写一篇特定的论文affiliated with:一位作者隶属于一个特定的机构cites:一篇论文引用另一篇论文has topic:一篇论文具有特定研究领域的话题
该图表的任务是根据图表中存储的信息推断每篇论文的发表场所(会议或期刊)。
三、异构图的数据结构
首先,我们可以创建一个类型为torch_geometric.data.HeteroData的数据对象,针对该对象,我们需要为每种类型分别定义节点特征张量、边索引张量和边特征张量:
from torch_geometric.data import HeteroData
data = HeteroData()
data['paper'].x = ... # [num_papers, num_features_paper]
data['author'].x = ... # [num_authors, num_features_author]
data['institution'].x = ... # [num_institutions, num_features_institution]
data['field_of_study'].x = ... # [num_field, num_features_field]
data['paper', 'cites', 'paper'].edge_index = ... # [2, num_edges_cites]
data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges_writes]
data['author', 'affiliated_with', 'institution'].edge_index = ... # [2, num_edges_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_index = ... # [2, num_edges_topic]
data['paper', 'cites', 'paper'].edge_attr = ... # [num_edges_cites, num_features_cites]
data['author', 'writes', 'paper'].edge_attr = ... # [num_edges_writes, num_features_writes]
data['author', 'affiliated_with', 'institution'].edge_attr = ... # [num_edges_affiliated, num_features_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_attr = ... # [num_edges_topic, num_features_topic]
节点或边的张量会在首次访问时自动创建,并通过字符串键进行索引。节点类型由单个字符串标识,而边类型则通过字符串三元组(源节点类型、边类型、目标节点类型)来标识。
节点类型由一个字符串标识,而边类型则通过一个字符串三元组 (source_node_type, edge_type, destination_node_type) 进行标识。同一对节点类型之间可存在多种边关系且数据对象允许每种类型具有不同的特征维度。
按属性名称而非节点或边类型分组的异构信息字典可直接通过data.{attribute_name}_dict访问,并随后用作图神经网络(GNN)模型的输入:
model = HeteroGNN(...)
output = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)
如果该数据集存在于PyTorch Geometric的数据集列表中,则可以直接导入并使用。具体而言,它会被下载到root目录并自动进行处理。
from torch_geometric.datasets import OGB_MAG
dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
data = dataset[0]
data 对象可以打印出来进行查看:
HeteroData(
paper={
x=[736389, 128],
y=[736389],
train_mask=[736389],
val_mask=[736389],
test_mask=[736389]
},
author={
x=[1134649, 128] },
institution={
x=[8740, 128] },
field_of_study={
x=[59965, 128] },
(author, affiliated_with, institution)={
edge_index=[2, 1043998] },
(author, writes, paper)={
edge_index=[2, 7145660] },
(paper, cites, paper)={
edge_index=


922

被折叠的 条评论
为什么被折叠?



