超大数据集类的创建

之前,我们只涉及可以将所有数据存储在内存中的数据集。这些数据集对应的数据集类在创建对象时会将所有数据加载到内存中。但是,如果数据集超大,我们很难有足够的内存完整存储所有数据。因此,您需要一个数据集类来按需将样本加载到内存中。

数据集类

在 PyG 中,我们继承了 torch_ 几何。数据。数据集基类定义了按需将样本加载到内存中的数据集类。继承 torch_ 几何。数据。 inmemorydataset 基类要实现的方法。继承这个基类也需要实现。此外,还需要实现以下方法:

  • len():返回数据集中的样本数

  • get():实现加载单个图的操作。在内部,getitem() 返回调用 get() 得到的 Data 对象,并根据 transform 参数有选择地进行转换。

我们可以不定义Dataset类,直接生成Dataloader对象,通过以下方法进行训练:

从 torch_geometric.data 导入数据、DataLoader

数据_list u003d [数据(...),...,数据(...)]

loader u003d DataLoader(data_list, batch_sizeu003d32)

我们还可以通过以下方式从 Data 对象列表中形成批处理:

从 torch_geometric.data 导入数据,批处理

数据_list u003d [数据(...),...,数据(...)]

loader u003d Batch.from\data\list(data\list, batch_sizeu003d32)

Graph样本封装成batch和DataLoader类

合并小图纸形成大图纸

图可以有任意数量的节点和边。它不是常规的数据结构。因此,将图数据封装成批次的操作与将图像和序列数据封装成批次的操作是不同的。 Pytorch 几何通过将小图合并为连接组件来构建大图,将多个图封装成批次。因此,小图的邻接矩阵存储在大图的邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵和预测目标矩阵分别为:

这种方法具有以下主要优点:

  • 依赖消息传递方案的GNN操作不需要修改,因为属于不同图的两个节点之间仍然不能交换消息。

  • 没有额外的计算或内存开销。

小图的属性增量和拼接

将小图存入大图时,需要修改小图的属性。最重要的例子之一是增加节点序列号的价值。最一般的形式,pytorch几何的DataLoader类会自动更新edge_index张量的值增加,增加的值是当前处理的图的前一个图中节点的累积数量。 Pytorch 几何允许我们覆盖 torch_geometric.data.inc() 和 torch_geometric.data.cat_dim() 函数以实现所需的行为。

图匹配

如果要在一个 Data 对象中存储多个图,例如用于图匹配等应用,我们需要确保将所有这些图正确封装成批处理行为。例如,考虑在 Data 类中存储两个图,一个源图 Gs 和一个目标图 Gt,即

类对数据(数据):

def __init__(self, edge_index_s, x_s, edge_index_t,x_t):

super(PairData, self).__init__()

self.edge_index_s u003d 边_index_s

自我.x_s u003d x_s

self.edge_index_t u003d 边_index_t

自我.x_t u003d x_t

在这种情况下,edge_index_s 应根据源图 Gs 中的节点数添加,即 x_s.size(0),而 edge_index_t 应根据源图 Gs 中的节点数添加目标图 Gt,即 x_t.size(0)。

我们通过一个例子来看看节点增值:

边缘_index_s u003d torch.tensor([

[0, 0, 0, 0],

[1, 2, 3, 4],

])

x_s u003d torch.randn(5, 16) # 5 个节点。

边_index_t u003d torch.tensor([

[0, 0, 0],

[1, 2, 3],

])

x_t u003d torch.randn(4, 16) # 4 个节点。

数据 u003d PairData(edge_index_s, x_s, edge_index_t, x_t)

数据_list u003d [数据,数据]

loader u003d DataLoader(data_list, batch_sizeu003d2)

批处理 u003d 下一个(迭代器(加载器))

打印(批量)

批处理(边_index_su003d[2, 8], x_su003d[10, 16],

边_index_tu003d[2, 6], x_tu003d[8, 16])

打印(batch.edge_index_s)

张量([[0, 0, 0, 0, 5, 5, 5, 5],

[1, 2, 3, 4, 6, 7, 8, 9]])

打印(batch.edge_index_t)

张量([[0, 0, 0, 4, 4, 4],

[1, 2, 3, 5, 6, 7]])

我们可以在 DataLoader_Batch 参数中使用 follow 来维护批处理属性。

二分图

二分图的邻接矩阵定义了两类节点之间的连接关系。不同类型的节点数量不需要相同,所以边的源节点和目标节点的增值操作应该不同。我们需要告诉 pytorch geometry,它应该在 edge_index 处独立地对边缘的源节点和目标节点执行增值操作。

def __inc__(自我、键、值):

如果键 u003du003d 'edge_index':

返回 torch.tensor([[self.x_s.size(0)],[self.x_t.size(0)]])

其他:

返回超级()。__inc__(键,值)

其中,edge_index[0]根据x_s.size(0)(边的源节点)进行增值操作,而edge_index[1](边的目标节点)根据 x_t.size(0) 进行增值操作。

拼接新维度

有时候,Data对象的属性需要在一个新的维度上进行拼接(比如经典的封装)

Batch),例如图形级别的属性或预测目标。具体来说,形状 [num_features]

的属性列表应返回为 [num_examples, num_features],而不是

[num_examples * num_features]。火炬几何

cat_dim() 返回一个连接维度 None 来实现这一点。

类我的数据(数据):

def __cat_dim__(自我、钥匙、物品):

如果键 u003du003d 'foo':

返回无

其他:

return super().__cat_dim__(key, item)

图形预测任务练习

手术:

(1) 需要128G的虚拟内存

(2)使用教程的参数,需要运行49个epochs和16个num_workers,每个epoch的运行时间大约3~4分钟,整体运行至少需要5个小时

(3) 试运行开始后,程序会在saves目录下创建一个任务_name参数指定的文件夹,用于记录测试过程。当saves目录下已经有同名文件夹时,程序会在task_name参数的末尾添加一个后缀作为文件夹名。在测试运行的过程中,所有的打印输出都会写入到 test 文件夹下的输出文件中,tensorboard summarywriter 记录的信息也存储在 trial 文件夹下的文件中。

参考资料:

1.创建按需数据集类

2.图形预测任务练习

点击阅读全文
Logo

学AI,认准AI Studio!GPU算力,限时免费领,邀请好友解锁更多惊喜福利 >>>

更多推荐