65 lines
2.1 KiB
Python
65 lines
2.1 KiB
Python
import json
|
|
|
|
import dgl
|
|
import torch
|
|
from dgl.dataloading import Collator
|
|
from dgl.utils import to_dgl_context
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
def iter_json(filename):
    """Iterate over a file that stores one JSON document per line.

    Blank and whitespace-only lines are skipped instead of being passed to
    ``json.loads`` (which rejects empty input), so stray empty lines in the
    middle or at the end of the file do not abort the iteration.

    :param filename: path of a UTF-8 encoded text file, one JSON value per line
    :return: generator yielding the decoded value of every non-blank line
    :raises json.JSONDecodeError: if a non-blank line is not valid JSON
    """
    with open(filename, encoding='utf8') as f:
        for line in f:
            # An empty string is not a valid JSON document; ignore filler lines.
            if line.strip():
                yield json.loads(line)
|
|
|
|
|
|
class TripletNodeCollator(Collator):
    """Collator that turns a batch of triplets into sampled blocks."""

    def __init__(self, g, triplets, block_sampler, ntype):
        """NodeCollator for the OAGCSAuthorRankDataset dataset.

        :param g: DGLGraph, heterogeneous graph
        :param triplets: tensor(N, 3), (t, ap, an) triplets
        :param block_sampler: BlockSampler, neighbor sampler
        :param ntype: str, target node type
        """
        self.ntype = ntype
        self.block_sampler = block_sampler
        self.triplets = triplets
        self.g = g

    def collate(self, items):
        """Build the sampled subgraph for the author ids found in the triplets.

        :param items: List[tensor(3)], one batch of triplets
        :return: (triplets, output_nodes, blocks) -- the batch stacked into a
            single tensor with one row per triplet, the original ids of the
            destination nodes of the last block, and the List[DGLBlock]
            produced by the sampler
        """
        batch = torch.stack(items, dim=0)
        # Columns 1 and 2 (ap, an) hold the ids used as seeds; deduplicate them.
        seeds = torch.unique(batch[:, 1:].flatten())
        blocks = self.block_sampler.sample_blocks(self.g, {self.ntype: seeds})
        last_block = blocks[-1]
        output_nodes = last_block.dstnodes[self.ntype].data[dgl.NID]
        return batch, output_nodes, blocks

    @property
    def dataset(self):
        """The full triplet tensor that batches are drawn from."""
        return self.triplets
|
|
|
|
|
|
class TripletNodeDataLoader(DataLoader):
    """DataLoader that batches triplets through a TripletNodeCollator."""

    def __init__(self, g, triplets, block_sampler, device=None, **kwargs):
        """NodeDataLoader for the OAGCSAuthorRankDataset dataset.

        :param g: DGLGraph, heterogeneous graph
        :param triplets: tensor(N, 3), (t, ap, an) triplets
        :param block_sampler: BlockSampler, neighbor sampler
        :param device: torch.device; falls back to ``g.device`` when None
        :param kwargs: remaining arguments forwarded to DataLoader
        """
        target_device = g.device if device is None else device
        # Tell the sampler which context its sampled blocks should be placed on.
        block_sampler.set_output_context(to_dgl_context(target_device))
        # The target node type is fixed to 'author' for this dataset.
        self.collator = TripletNodeCollator(g, triplets, block_sampler, 'author')
        super().__init__(triplets, collate_fn=self.collator.collate, **kwargs)
|