import os
import shutil
import zipfile

import dgl
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from dgl.data import DGLDataset
from dgl.data.utils import (
    download, save_graphs, save_info, load_graphs, load_info,
    generate_mask_tensor, idx2mask
)


class HeCoDataset(DGLDataset):
    """Base class for the datasets used by the HeCo model.

    Paper: https://arxiv.org/pdf/2105.09111

    Class attributes
    ----------------
    * num_classes: number of classes
    * metapaths: the metapaths to use
    * predict_ntype: target vertex type
    * pos: (tensor(E_pos), tensor(E_pos)) positive sample pairs of target
      vertices; pos[1][i] is a positive sample of pos[0][i]
    """

    def __init__(self, name, ntypes):
        url = 'https://api.github.com/repos/liun-online/HeCo/zipball/main'
        # Map each vertex type's initial to its full name, e.g. 'p' -> 'paper';
        # the raw edge files are named with these initials (see _read_edges).
        self._ntypes = {ntype[0]: ntype for ntype in ntypes}
        super().__init__(name + '-heco', url)
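
    # DGLDataset drives the pipeline from __init__: if has_cache() is true it
    # calls load(), otherwise download() (when the raw data is missing), then
    # process() and save().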

    def download(self):
        file_path = os.path.join(self.raw_dir, 'HeCo-main.zip')
        if not os.path.exists(file_path):
            download(self.url, path=file_path)
        with zipfile.ZipFile(file_path, 'r') as f:
            f.extractall(self.raw_dir)
            # The GitHub API zipball unpacks to a directory named
            # <owner>-<repo>-<sha>, so take the actual top-level name from the
            # archive instead of assuming 'HeCo-main'.
            root = f.namelist()[0].split('/')[0]
        shutil.copytree(
            os.path.join(self.raw_dir, root, 'data', self.name.split('-')[0]),
            self.raw_path
        )

    def save(self):
        save_graphs(os.path.join(self.save_path, self.name + '_dgl_graph.bin'), [self.g])
        # Cache the positive pairs alongside the cached graph in save_path.
        save_info(os.path.join(self.save_path, self.name + '_pos.pkl'), {'pos_i': self.pos_i, 'pos_j': self.pos_j})

    def load(self):
        graphs, _ = load_graphs(os.path.join(self.save_path, self.name + '_dgl_graph.bin'))
        self.g = graphs[0]
        ntype = self.predict_ntype
        self._num_classes = self.g.nodes[ntype].data['label'].max().item() + 1
        # save_graphs() stores bool masks as uint8, so convert them back.
        for k in ('train_mask', 'val_mask', 'test_mask'):
            self.g.nodes[ntype].data[k] = self.g.nodes[ntype].data[k].bool()
        info = load_info(os.path.join(self.save_path, self.name + '_pos.pkl'))
        self.pos_i, self.pos_j = info['pos_i'], info['pos_j']

    def process(self):
        self.g = dgl.heterograph(self._read_edges())

        feats = self._read_feats()
        for ntype, feat in feats.items():
            self.g.nodes[ntype].data['feat'] = feat

        labels = torch.from_numpy(np.load(os.path.join(self.raw_path, 'labels.npy'))).long()
        self._num_classes = labels.max().item() + 1
        self.g.nodes[self.predict_ntype].data['label'] = labels

        n = self.g.num_nodes(self.predict_ntype)
        for split in ('train', 'val', 'test'):
            idx = np.load(os.path.join(self.raw_path, f'{split}_60.npy'))
            mask = generate_mask_tensor(idx2mask(idx, n))
            self.g.nodes[self.predict_ntype].data[f'{split}_mask'] = mask

        pos_i, pos_j = sp.load_npz(os.path.join(self.raw_path, 'pos.npz')).nonzero()
        self.pos_i, self.pos_j = torch.from_numpy(pos_i).long(), torch.from_numpy(pos_j).long()

    def _read_edges(self):
        edges = {}
        for file in os.listdir(self.raw_path):
            name, ext = os.path.splitext(file)
            if ext == '.txt':
                # Edge files are named by the initials of the two endpoint
                # types, e.g. pa.txt holds paper-author edges.
                u, v = name
                e = pd.read_csv(os.path.join(self.raw_path, f'{u}{v}.txt'), sep='\t', names=[u, v])
                src = e[u].to_list()
                dst = e[v].to_list()
                edges[(self._ntypes[u], f'{u}{v}', self._ntypes[v])] = (src, dst)
                edges[(self._ntypes[v], f'{v}{u}', self._ntypes[u])] = (dst, src)
        return edges

    def _read_feats(self):
        feats = {}
        for u in self._ntypes:
            file = os.path.join(self.raw_path, f'{u}_feat.npz')
            if os.path.exists(file):
                feats[self._ntypes[u]] = torch.from_numpy(sp.load_npz(file).toarray()).float()
        return feats

    def has_cache(self):
        return os.path.exists(os.path.join(self.save_path, self.name + '_dgl_graph.bin'))

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError('This dataset has only one graph')
        return self.g

    def __len__(self):
        return 1

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def metapaths(self):
        raise NotImplementedError

    @property
    def predict_ntype(self):
        raise NotImplementedError

    @property
    def pos(self):
        return self.pos_i, self.pos_j
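

# A minimal sketch, not part of HeCo itself, of how the positive pairs exposed
# by `pos` are typically consumed: pos_i[k] and pos_j[k] index an anchor node
# and one of its positives, whose embedding similarity a contrastive loss
# pulls up.
def _positive_pair_similarity(z, pos_i, pos_j):
    """Cosine similarity of each (anchor, positive) pair; illustration only."""
    z = torch.nn.functional.normalize(z, dim=1)  # tensor(N, d), row-normalized
    return (z[pos_i] * z[pos_j]).sum(dim=1)      # tensor(E_pos)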


class ACMDataset(HeCoDataset):
    """ACM dataset

    Statistics
    ----------
    * Vertices: 4019 paper, 7167 author, 60 subject
    * Edges: 13407 paper-author, 4019 paper-subject
    * Target vertex type: paper
    * Number of classes: 3
    * Vertex split: 180 train, 1000 valid, 1000 test

    paper vertex features
    ---------------------
    * feat: tensor(N_paper, 1902)
    * label: tensor(N_paper), 0~2
    * train_mask, val_mask, test_mask: tensor(N_paper)

    author vertex features
    ----------------------
    * feat: tensor(7167, 1902)
    """

    def __init__(self):
        super().__init__('acm', ['paper', 'author', 'subject'])

    @property
    def metapaths(self):
        return [['pa', 'ap'], ['ps', 'sp']]

    @property
    def predict_ntype(self):
        return 'paper'
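

# A minimal usage sketch (an assumption about downstream use, not original
# code): each entry of `metapaths` is in the form accepted by
# dgl.metapath_reachable_graph, which collapses a metapath into a homogeneous
# graph over the target vertices, e.g.
#
#   data = ACMDataset()
#   g = data[0]
#   pap = dgl.metapath_reachable_graph(g, ['pa', 'ap'])  # paper-author-paper
#   psp = dgl.metapath_reachable_graph(g, ['ps', 'sp'])  # paper-subject-paper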


class DBLPDataset(HeCoDataset):
    """DBLP dataset

    Statistics
    ----------
    * Vertices: 4057 author, 14328 paper, 20 conference, 7723 term
    * Edges: 19645 paper-author, 14328 paper-conference, 85810 paper-term
    * Target vertex type: author
    * Number of classes: 4
    * Vertex split: 240 train, 1000 valid, 1000 test

    author vertex features
    ----------------------
    * feat: tensor(N_author, 334)
    * label: tensor(N_author), 0~3
    * train_mask, val_mask, test_mask: tensor(N_author)

    paper vertex features
    ---------------------
    * feat: tensor(14328, 4231)

    term vertex features
    --------------------
    * feat: tensor(7723, 50)
    """

    def __init__(self):
        super().__init__('dblp', ['author', 'paper', 'conference', 'term'])

    def _read_feats(self):
        feats = {}
        # Only author and paper have feature files of the f'{u}_feat.npz' form.
        for u in 'ap':
            file = os.path.join(self.raw_path, f'{u}_feat.npz')
            feats[self._ntypes[u]] = torch.from_numpy(sp.load_npz(file).toarray()).float()
        # t_feat.npz is a scipy sparse matrix like the others, so it must be
        # loaded with sp.load_npz; np.load would return an NpzFile archive,
        # which torch.from_numpy cannot consume.
        feats['term'] = torch.from_numpy(sp.load_npz(os.path.join(self.raw_path, 't_feat.npz')).toarray()).float()
        return feats

    @property
    def metapaths(self):
        return [['ap', 'pa'], ['ap', 'pc', 'cp', 'pa'], ['ap', 'pt', 'tp', 'pa']]

    @property
    def predict_ntype(self):
        return 'author'
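

# A minimal usage sketch; assumes network access to download the HeCo data on
# the first run (later runs hit the cache written by save()).
if __name__ == '__main__':
    data = ACMDataset()
    g = data[0]
    print(g)
    print('classes:', data.num_classes)
    masks = g.nodes[data.predict_ntype].data
    print('train/val/test sizes:',
          masks['train_mask'].sum().item(),
          masks['val_mask'].sum().item(),
          masks['test_mask'].sum().item())
    pos_i, pos_j = data.pos
    print('positive pairs:', pos_i.shape[0])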