57 lines
2.3 KiB
Python
57 lines
2.3 KiB
Python
import argparse
|
|
|
|
import dgl
|
|
import torch
|
|
from ogb.nodeproppred import DglNodePropPredDataset
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
|
|
from gnnrec.config import DATA_DIR
|
|
from gnnrec.hge.utils import add_reverse_edges
|
|
|
|
|
|
def random_walk(g, metapaths, num_walks, walk_length, output_file):
    """Run metapath-guided random walks on a heterogeneous graph and save the traces.

    For every start node type in ``metapaths``, each node of that type is used
    as a seed ``num_walks`` times. Each walk follows the type's metapath
    repeated ``walk_length`` times, and every resulting trace is written to
    ``output_file`` as one line of space-separated node names (see
    ``trace2name``).

    :param g: DGLGraph heterogeneous graph
    :param metapaths: Dict[str, List[str]] mapping from start node type to its
        metapath; a metapath is a list of edge types and should start and end
        at the same node type
    :param num_walks: int number of walks started from every node
    :param walk_length: int number of times the metapath is repeated in a walk
    :param output_file: str name of the output file
    :return: None
    """
    with open(output_file, 'w') as out:
        for start_type, pattern in metapaths.items():
            print(start_type)
            # The full walk is the base metapath repeated back to back.
            repeated_path = pattern * walk_length
            seed_ids = torch.arange(g.num_nodes(start_type))
            # Seeds are processed in fixed-size batches of node IDs.
            for batch in tqdm(DataLoader(seed_ids, batch_size=200)):
                # Duplicate each seed so it starts num_walks independent walks.
                seeds = torch.repeat_interleave(batch, num_walks)
                traces, types = dgl.sampling.random_walk(g, seeds, metapath=repeated_path)
                out.write(''.join(trace2name(g, trace, types) + '\n' for trace in traces))
|
|
|
|
|
|
def trace2name(g, trace, types):
    """Render one walk trace as a space-separated string of node names.

    Each kept node is named ``<node type>_<node id>``, where the node type is
    looked up in ``g.ntypes`` by the type ID at the same position in ``types``.
    Positions whose node ID is negative are skipped (presumably the padding
    DGL emits when a walk stops early -- confirm against the DGL docs).

    :param g: graph object exposing ``ntypes``, a sequence of node type names
    :param trace: sequence of node IDs visited by one walk
    :param types: sequence of node type IDs, aligned with ``trace``
    :return: str the node names joined by single spaces
    """
    names = []
    for raw_id, type_id in zip(trace, types):
        node_id = int(raw_id)
        if node_id < 0:
            continue
        names.append('_'.join((g.ntypes[type_id], str(node_id))))
    return ' '.join(names)
|
|
|
|
|
|
def main():
    """Command-line entry point: write metapath random walks for ogbn-mag.

    Parses the command line, loads the ogbn-mag dataset from ``DATA_DIR``,
    passes its graph through ``add_reverse_edges`` and hands the result to
    ``random_walk`` together with one metapath per node type.
    """
    parser = argparse.ArgumentParser(description='ogbn-mag数据集 metapath2vec基于元路径的随机游走')
    # Both optional arguments are integers; register them from one table.
    int_options = (
        ('--num-walks', 5, '每个顶点游走次数'),
        ('--walk-length', 16, '元路径重复次数'),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument('output_file', help='输出文件名')
    args = parser.parse_args()

    # One metapath per start node type; each returns to its start type.
    # The '*_rev' edge types are presumably added by add_reverse_edges -- confirm.
    metapaths = {
        'author': ['writes', 'has_topic', 'has_topic_rev', 'writes_rev'],  # APFPA
        'paper': ['writes_rev', 'writes', 'has_topic', 'has_topic_rev'],  # PAPFP
        'field_of_study': ['has_topic_rev', 'writes_rev', 'writes', 'has_topic'],  # FPAPF
        'institution': ['affiliated_with_rev', 'writes', 'writes_rev', 'affiliated_with'],  # IAPAI
    }

    dataset = DglNodePropPredDataset('ogbn-mag', DATA_DIR)
    first_item = dataset[0]
    # The graph is the first element of the dataset's first item.
    g = add_reverse_edges(first_item[0])
    random_walk(g, metapaths, args.num_walks, args.walk_length, args.output_file)
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|