# GNNRecom/gnnrec/hge/utils/data.py
#
# 139 lines, 5.6 KiB, Python
#
# Note: this file contains Unicode characters that might be confused with
# other, similar-looking characters. If this is intentional, the warning
# can be safely ignored.

import dgl
import dgl.function as fn
import torch
from gensim.models import Word2Vec
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from gnnrec.config import DATA_DIR
from gnnrec.hge.data import ACMDataset, DBLPDataset
from gnnrec.kgrec.data import OAGVenueDataset
def load_data(name, device='cpu', add_reverse_edge=True, reverse_self=True):
    """Load a heterogeneous graph dataset by name.

    :param name: str Dataset name: acm, dblp, ogbn-mag or oag-venue
    :param device: torch.device, optional Device the graph and its data are moved to, defaults to CPU
    :param add_reverse_edge: bool, optional Whether to add reverse edges, defaults to True.
        Has no effect for acm and dblp, which already contain reverse edges
    :param reverse_self: bool, optional Whether to add reverse edges for edge types whose source
        and destination node types are the same, defaults to True
    :return: dataset, g, features, labels, predict_ntype, train_mask, val_mask, test_mask, evaluator.
        The three "masks" are index tensors (positions of the nonzero entries of the stored
        boolean masks); evaluator is None for every dataset handled here (ogbn-mag is delegated
        to load_ogbn_mag, which returns an OGB evaluator)
    :raises ValueError: if name is not one of the known datasets
    """
    # ogbn-mag has its own loader with a different data layout
    if name == 'ogbn-mag':
        return load_ogbn_mag(device, add_reverse_edge, reverse_self)

    if name == 'acm':
        data = ACMDataset()
    elif name == 'dblp':
        data = DBLPDataset()
    elif name == 'oag-venue':
        data = OAGVenueDataset()
    else:
        raise ValueError(f'load_data: 未知数据集{name}')

    predict_ntype = data.predict_ntype
    g = data[0]

    # The ACM and DBLP datasets already ship with reverse edges
    needs_reverse = add_reverse_edge and name not in ('acm', 'dblp')
    if needs_reverse:
        g = add_reverse_edges(g, reverse_self)
    g = g.to(device)

    node_data = g.nodes[predict_ntype].data
    # Convert each stored mask into the indices of its nonzero entries
    train_mask, val_mask, test_mask = (
        node_data[key].nonzero(as_tuple=True)[0]
        for key in ('train_mask', 'val_mask', 'test_mask')
    )
    return (
        data, g, node_data['feat'], node_data['label'], predict_ntype,
        train_mask, val_mask, test_mask, None
    )
def load_ogbn_mag(device, add_reverse_edge, reverse_self):
    """Load the ogbn-mag dataset.

    :param device: torch.device Device the graph, labels and split indices are moved to
    :param add_reverse_edge: bool Whether to add reverse edges
    :param reverse_self: bool Whether to add reverse edges for edge types whose source and
        destination node types are the same
    :return: dataset, g, features, labels, predict_ntype, train_mask, val_mask, test_mask, evaluator.
        predict_ntype is always 'paper'; the three splits are the index tensors returned by
        the dataset's get_idx_split()
    """
    predict_ntype = 'paper'
    data = DglNodePropPredDataset('ogbn-mag', DATA_DIR)
    g, label_dict = data[0]
    if add_reverse_edge:
        g = add_reverse_edges(g, reverse_self)
    g = g.to(device)

    features = g.nodes[predict_ntype].data['feat']
    # Labels are stored with a trailing dimension of size 1; drop it
    labels = label_dict[predict_ntype].squeeze(dim=1).to(device)

    split_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = (
        split_idx[part][predict_ntype].to(device)
        for part in ('train', 'valid', 'test')
    )
    evaluator = Evaluator(data.name)
    return data, g, features, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator
def add_reverse_edges(g, reverse_self=True):
    """Add a reverse edge type for each edge type of a heterogeneous graph and return a new graph.

    The reverse of edge type (stype, etype, dtype) is (dtype, etype + '_rev', stype).
    Node features and the features of the original edge types are copied to the new graph;
    the newly created reverse edge types receive no edge features.

    :param g: DGLGraph Heterogeneous graph
    :param reverse_self: bool, optional Whether to add reverse edges for edge types whose source
        and destination node types are the same, defaults to True
    :return: DGLGraph The heterogeneous graph with reverse edges added
    """
    edge_dict = {}
    for src_type, rel, dst_type in g.canonical_etypes:
        forward_key = (src_type, rel, dst_type)
        src, dst = g.edges(etype=forward_key)
        edge_dict[forward_key] = (src, dst)
        if reverse_self or src_type != dst_type:
            edge_dict[(dst_type, rel + '_rev', src_type)] = (dst, src)

    # Pass explicit node counts so node types keep their original sizes
    num_nodes_dict = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
    new_g = dgl.heterograph(edge_dict, num_nodes_dict)

    for ntype in g.ntypes:
        new_g.nodes[ntype].data.update(g.nodes[ntype].data)
    # Only the original edge types have features to carry over
    for canonical_etype in g.canonical_etypes:
        new_g.edges[canonical_etype].data.update(g.edges[canonical_etype].data)
    return new_g
def one_hot_node_feat(g):
    """Give every node type that has no 'feat' field a one-hot input feature (in place).

    The feature is an identity matrix with one row per node of that type, created on the
    graph's device. Node types that already have 'feat' are left untouched.

    :param g: DGLGraph Heterogeneous graph
    """
    for ntype in g.ntypes:
        node_data = g.nodes[ntype].data
        if 'feat' in node_data:
            continue
        node_data['feat'] = torch.eye(g.num_nodes(ntype), device=g.device)
def average_node_feat(g):
    """For ogbn-mag: give nodes without input features the mean of their neighbors' features.

    Works in place on g. The edge type names used here are specific to ogbn-mag;
    'writes_rev' follows the naming add_reverse_edges uses for reverse edges, so the graph
    is presumably expected to have been loaded with reverse edges - verify against caller.

    :param g: DGLGraph Heterogeneous graph
    """
    # Copy the source node's 'feat' as the message, then average messages into 'feat'
    neighbor_mean = (fn.copy_u('feat', 'm'), fn.mean('m', 'feat'))

    # First pass: aggregate along 'writes_rev' and 'has_topic'
    first_pass = {'writes_rev': neighbor_mean, 'has_topic': neighbor_mean}
    g.multi_update_all(first_pass, 'sum')

    # Second pass is a separate call so it runs after the first pass has written 'feat'
    # (presumably 'affiliated_with' sources get their features in the first pass - TODO confirm)
    second_pass = {'affiliated_with': neighbor_mean}
    g.multi_update_all(second_pass, 'sum')
def load_pretrained_node_embed(g, node_embed_path, concat=False):
    """Load pretrained node embeddings for node types that have no input features (in place).

    For a node type without a 'feat' field, 'feat' is set to the pretrained embeddings.
    For a node type that already has 'feat', the embeddings are concatenated to it along
    dim 1 if concat is True; otherwise the existing features are kept unchanged.

    The embedding of node i of type ntype is looked up in the model under the key
    '{ntype}_{i}'.

    :param g: DGLGraph Heterogeneous graph
    :param node_embed_path: str Path of the pretrained word2vec model
    :param concat: bool, optional If True, concatenate the pretrained features with the
        original input features, defaults to False
    """
    model = Word2Vec.load(node_embed_path)
    for ntype in g.ntypes:
        node_data = g.nodes[ntype].data
        has_feat = 'feat' in node_data
        if has_feat and not concat:
            # Existing features are kept as they are, so the (potentially large)
            # embedding lookup for this node type would be thrown away - skip it
            continue

        keys = [f'{ntype}_{i}' for i in range(g.num_nodes(ntype))]
        embed = torch.from_numpy(model.wv[keys]).to(g.device)
        if has_feat:
            node_data['feat'] = torch.cat([node_data['feat'], embed], dim=1)
        else:
            node_data['feat'] = embed
def add_node_feat(g, method, node_embed_path=None, concat=False):
    """Add input features to node types that have none (in place).

    :param g: DGLGraph Heterogeneous graph
    :param method: str One of one-hot, average (ogbn-mag only) or pretrained
    :param node_embed_path: str Path of the pretrained word2vec model; required when
        method is pretrained, ignored otherwise
    :param concat: bool, optional If True, concatenate the pretrained features with the
        original input features (only used when method is pretrained), defaults to False
    :raises ValueError: if method is unknown, or if method is pretrained and
        node_embed_path is not given
    """
    if method == 'one-hot':
        one_hot_node_feat(g)
    elif method == 'average':
        average_node_feat(g)
    elif method == 'pretrained':
        # Fail early with a clear message instead of an opaque error from the model loader
        if node_embed_path is None:
            raise ValueError('add_node_feat: 使用pretrained方法时必须指定node_embed_path')
        load_pretrained_node_embed(g, node_embed_path, concat)
    else:
        raise ValueError(f'add_node_feat: 未知方法{method}')