GNNRecom/gnnrec/kgrec/random_walk.py

29 lines
1.2 KiB
Python
Raw Normal View History

2021-11-16 07:04:52 +00:00
import argparse
from gnnrec.hge.metapath2vec.random_walk import random_walk
from gnnrec.hge.utils import add_reverse_edges
from gnnrec.kgrec.data import OAGCSDataset
def main():
parser = argparse.ArgumentParser(description='oag-cs数据集 metapath2vec基于元路径的随机游走')
parser.add_argument('--num-walks', type=int, default=4, help='每个顶点游走次数')
parser.add_argument('--walk-length', type=int, default=10, help='元路径重复次数')
parser.add_argument('output_file', help='输出文件名')
args = parser.parse_args()
data = OAGCSDataset()
g = add_reverse_edges(data[0])
metapaths = {
'author': ['writes', 'published_at', 'published_at_rev', 'writes_rev'], # APVPA
'paper': ['writes_rev', 'writes', 'published_at', 'published_at_rev', 'has_field', 'has_field_rev'], # PAPVPFP
'venue': ['published_at_rev', 'writes_rev', 'writes', 'published_at'], # VPAPV
'field': ['has_field_rev', 'writes_rev', 'writes', 'has_field'], # FPAPF
'institution': ['affiliated_with_rev', 'writes', 'writes_rev', 'affiliated_with'] # IAPAI
}
random_walk(g, metapaths, args.num_walks, args.walk_length, args.output_file)
if __name__ == '__main__':
main()