# GNNRecom/gnnrec/hge/metapath2vec/random_walk.py
#
# 57 lines
# 2.3 KiB
# Python

import argparse
import dgl
import torch
from ogb.nodeproppred import DglNodePropPredDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from gnnrec.config import DATA_DIR
from gnnrec.hge.utils import add_reverse_edges
def random_walk(g, metapaths, num_walks, walk_length, output_file):
    """Run metapath-guided random walks on a heterogeneous graph and save the traces.

    For each start node type in ``metapaths``, every node of that type is used
    as the seed of ``num_walks`` walks. Each walk follows the associated
    metapath repeated ``walk_length`` times. One line is written to
    ``output_file`` per walk, formatted by ``trace2name``.

    :param g: DGLGraph, the heterogeneous graph
    :param metapaths: Dict[str, List[str]], mapping from start node type to a
        metapath; a metapath is a list of edge types whose start and end node
        types should be the same (so that it can be repeated)
    :param num_walks: int, number of walks started from each node
    :param walk_length: int, number of times the metapath is repeated
    :param output_file: str, name of the output file
    :return: None
    """
    with open(output_file, 'w') as out:
        for start_type, edge_types in metapaths.items():
            print(start_type)
            # The full walk is the metapath concatenated with itself walk_length times.
            repeated_path = edge_types * walk_length
            node_ids = torch.arange(g.num_nodes(start_type))
            # Seeds are processed in batches of 200 node IDs.
            for batch in tqdm(DataLoader(node_ids, batch_size=200)):
                # Repeat every seed num_walks times so each node starts num_walks walks.
                seeds = torch.repeat_interleave(batch, num_walks)
                traces, types = dgl.sampling.random_walk(g, seeds, metapath=repeated_path)
                # Format the whole batch first, then emit it with a single write.
                names = [trace2name(g, trace, types) for trace in traces]
                out.write(''.join(name + '\n' for name in names))
def trace2name(g, trace, types):
    """Format one walk trace as a space-separated string of ``<ntype>_<id>`` tokens.

    :param g: graph object; only ``g.ntypes`` is used, indexed by the entries of ``types``
    :param trace: sequence of node IDs visited by the walk
    :param types: sequence of node type indices, aligned position-by-position with ``trace``
    :return: str, the tokens joined by single spaces; entries with a negative
        node ID are skipped (presumably padding for walks that stopped early)
    """
    tokens = []
    for node, type_idx in zip(trace, types):
        node_id = int(node)
        if node_id < 0:
            continue
        tokens.append(g.ntypes[type_idx] + '_' + str(node_id))
    return ' '.join(tokens)
def main():
    """Command-line entry point: generate metapath random walks on ogbn-mag.

    Parses the command-line options, loads the ogbn-mag dataset from
    ``DATA_DIR``, passes its graph through ``add_reverse_edges``, and writes
    the walks for four fixed metapaths to the requested output file.
    """
    # One metapath per start node type, given as a list of edge types.
    # The trailing comments give the node-type sequence abbreviation
    # (A / P / F / I, presumably author / paper / field_of_study / institution).
    metapaths = {
        'author': ['writes', 'has_topic', 'has_topic_rev', 'writes_rev'],  # APFPA
        'paper': ['writes_rev', 'writes', 'has_topic', 'has_topic_rev'],  # PAPFP
        'field_of_study': ['has_topic_rev', 'writes_rev', 'writes', 'has_topic'],  # FPAPF
        'institution': ['affiliated_with_rev', 'writes', 'writes_rev', 'affiliated_with'],  # IAPAI
    }

    cli = argparse.ArgumentParser(description='ogbn-mag数据集 metapath2vec基于元路径的随机游走')
    # Integer options: (flag, default, help text).
    int_options = (
        ('--num-walks', 5, '每个顶点游走次数'),  # number of walks per node
        ('--walk-length', 16, '元路径重复次数'),  # number of metapath repetitions
    )
    for flag, default, help_text in int_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)
    cli.add_argument('output_file', help='输出文件名')  # output file name
    options = cli.parse_args()

    dataset = DglNodePropPredDataset('ogbn-mag', DATA_DIR)
    # dataset[0][0] is the graph from the dataset's first item.
    graph = add_reverse_edges(dataset[0][0])

    random_walk(
        graph,
        metapaths,
        options.num_walks,
        options.walk_length,
        options.output_file,
    )
# Run the walk generation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()