07.异构图神经网络

一、概述

大量现实世界的数据集以异质图(heterogeneous graphs )的形式存储,这促使在PyG(PyTorch Geometric )中为它们引入专门的功能。例如,推荐领域中的大多数图,如社交图,都是异质的,因为它们存储了关于不同类型实体及其不同类型关系的信息。本教程将介绍异质图如何映射到PyG,以及它们如何用作图神经网络(Graph Neural Network)模型的输入 。

异构图的节点和边附有不同类型的信息。因此,由于类型和维度的差异,单个节点或边特征张量无法容纳整个图的所有节点或边特征。相反,需要分别为节点和边指定一组类型,每种类型有其自身的数据张量。由于数据结构不同,消息传递公式也会相应改变,从而能够根据节点或边类型计算消息和更新函数 。

二、示例

image-20250707154634353

所给的异构图有 1,939,743 个节点,分为四种节点类型:作者、论文、机构和研究领域。它还有 21,111,007 条边,这些边也属于四种类型之一:

  • writes:一位作者撰写一篇特定的论文
  • affiliated with:一位作者隶属于一个特定的机构
  • cites:一篇论文引用另一篇论文
  • has topic:一篇论文具有特定研究领域的话题

该图表的任务是根据图表中存储的信息推断每篇论文的发表场所(会议或期刊)。

三、异构图的数据结构

首先,我们可以创建一个类型为torch_geometric.data.HeteroData的数据对象,针对该对象,我们需要为每种类型分别定义节点特征张量、边索引张量和边特征张量:

from torch_geometric.data import HeteroData

data = HeteroData()

data['paper'].x = ... # [num_papers, num_features_paper]
data['author'].x = ... # [num_authors, num_features_author]
data['institution'].x = ... # [num_institutions, num_features_institution]
data['field_of_study'].x = ... # [num_field, num_features_field]

data['paper', 'cites', 'paper'].edge_index = ... # [2, num_edges_cites]
data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges_writes]
data['author', 'affiliated_with', 'institution'].edge_index = ... # [2, num_edges_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_index = ... # [2, num_edges_topic]

data['paper', 'cites', 'paper'].edge_attr = ... # [num_edges_cites, num_features_cites]
data['author', 'writes', 'paper'].edge_attr = ... # [num_edges_writes, num_features_writes]
data['author', 'affiliated_with', 'institution'].edge_attr = ... # [num_edges_affiliated, num_features_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_attr = ... # [num_edges_topic, num_features_topic]

节点或边的张量会在首次访问时自动创建,并通过字符串键进行索引。节点类型由单个字符串标识,而边类型则通过字符串三元组(源节点类型、边类型、目标节点类型)来标识。

节点类型由一个字符串标识,而边类型则通过一个字符串三元组 (source_node_type, edge_type, destination_node_type) 进行标识。同一对节点类型之间可存在多种边关系且数据对象允许每种类型具有不同的特征维度。

按属性名称而非节点或边类型分组的异构信息字典可直接通过data.{attribute_name}_dict访问,并随后用作图神经网络(GNN)模型的输入:

model = HeteroGNN(...)

output = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)

如果该数据集存在于PyTorch Geometric的数据集列表中,则可以直接导入并使用。具体而言,它会被下载到root目录并自动进行处理。

from torch_geometric.datasets import OGB_MAG

dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
data = dataset[0]

data 对象可以打印出来进行查看:

HeteroData(
  paper={
   
   
    x=[736389, 128],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={
   
    x=[1134649, 128] },
  institution={
   
    x=[8740, 128] },
  field_of_study={
   
    x=[59965, 128] },
  (author, affiliated_with, institution)={
   
    edge_index=[2, 1043998] },
  (author, writes, paper)={
   
    edge_index=[2, 7145660] },
  (paper, cites, paper)={
   
    edge_index=
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

kaiaaaa

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值