GNNRecom/gnnrec/kgrec/data/preprocess/build_author_rank.py
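"""Build the scholar ranking dataset from the oag-cs dataset.

Subcommands (see main() below):
    build-val    match the AI 2000 scholar rankings to author ids (validation set)
    build-train  rank authors per field by weighted citation sums (training set)
    eval         evaluate the training set against the validation set
    sample       sample (field, higher-ranked author, lower-ranked author) triplets
"""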

import argparse
import json
import math
import random

import dgl
import dgl.function as fn
import django
import numpy as np
import torch
from dgl.ops import edge_softmax
from sklearn.metrics import ndcg_score
from tqdm import tqdm

from gnnrec.config import DATA_DIR
from gnnrec.hge.utils import set_random_seed, add_reverse_edges
from gnnrec.kgrec.data import OAGCSDataset
from gnnrec.kgrec.utils import iter_json, precision_at_k, recall_at_k


def build_ground_truth_valid(args):
    """Match the scholar rankings scraped from AI 2000 to author ids, forming the ground-truth validation set for scholar ranking."""
field_map = {
'AAAI/IJCAI': 'artificial intelligence',
'Machine Learning': 'machine learning',
'Computer Vision': 'computer vision',
'Natural Language Processing': 'natural language processing',
'Robotics': 'robotics',
'Knowledge Engineering': 'knowledge engineering',
'Speech Recognition': 'speech recognition',
'Data Mining': 'data mining',
'Information Retrieval and Recommendation': 'information retrieval',
'Database': 'database',
'Human-Computer Interaction': 'human computer interaction',
'Computer Graphics': 'computer graphics',
'Multimedia': 'multimedia',
'Visualization': 'visualization',
'Security and Privacy': 'security privacy',
'Computer Networking': 'computer network',
'Computer Systems': 'operating system',
'Theory': 'theory',
'Chip Technology': 'chip',
'Internet of Things': 'internet of things',
}
with open(DATA_DIR / 'rank/ai2000.json', encoding='utf8') as f:
ai2000_author_rank = json.load(f)
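    # django.setup() requires DJANGO_SETTINGS_MODULE to point at the project settings;
    # the ORM models (rank.models) can only be imported after the app registry is ready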
django.setup()
from rank.models import Author
author_rank = {}
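    # Match each AI 2000 scholar to an author id: prefer an exact (name, institution)
    # match, fall back to a name-only match, always taking the most-cited candidate;
    # -1 marks scholars that could not be matched at all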
for field, scholars in ai2000_author_rank.items():
aid = []
for s in scholars:
qs = Author.objects.filter(name=s['name'], institution__name=s['org']).order_by('-n_citation')
if qs.exists():
aid.append(qs[0].id)
else:
qs = Author.objects.filter(name=s['name']).order_by('-n_citation')
aid.append(qs[0].id if qs.exists() else -1)
author_rank[field_map[field]] = aid
if not args.use_field_name:
field2id = {f['name']: i for i, f in enumerate(iter_json(DATA_DIR / 'oag/cs/mag_fields.txt'))}
author_rank = {field2id[f]: aid for f, aid in author_rank.items()}
with open(DATA_DIR / 'rank/author_rank_val.json', 'w') as f:
json.dump(author_rank, f)
    print('Results saved to', f.name)


def build_ground_truth_train(args):
    """Build the ground-truth training set: rank the authors in each field by a weighted sum of their papers' citation counts."""
data = OAGCSDataset()
g = data[0]
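    # log1p compresses the long-tailed citation counts so that a few highly cited
    # papers do not dominate the weighted sums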
g.nodes['paper'].data['citation'] = g.nodes['paper'].data['citation'].float().log1p()
g.edges['writes'].data['order'] = g.edges['writes'].data['order'].float()
apg = g['author', 'writes', 'paper']
    # 1. Keep only fields with at least num_papers papers
field_in_degree, fid = g.in_degrees(g.nodes('field'), etype='has_field').sort(descending=True)
fid = fid[field_in_degree >= args.num_papers].tolist()
    # 2. For each field, collect its papers, build the author-paper subgraph,
    # and rank authors by the weighted sum of their papers' citations
author_rank = {}
for i in tqdm(fid):
pid, _ = g.in_edges(i, etype='has_field')
sg = add_reverse_edges(dgl.in_subgraph(apg, {'paper': pid}, relabel_nodes=True))
        # The k-th author gets weight 1/k; the last author is treated as the corresponding author and gets weight 1/2
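        # e.g. a 3-author paper gets raw weights (1, 1/2, 1/3); the min-weight
        # (i.e. last) author's edges are then bumped to 1/2 below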
sg.edges['writes'].data['w'] = 1.0 / sg.edges['writes'].data['order']
sg.update_all(fn.copy_e('w', 'w'), fn.min('w', 'mw'), etype='writes')
sg.apply_edges(fn.copy_u('mw', 'mw'), etype='writes_rev')
w, mw = sg.edges['writes'].data.pop('w'), sg.edges['writes_rev'].data.pop('mw')
w[w == mw] = 0.5
        # Normalize the author weights within each paper, then compute each author's weighted sum of paper citations
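        # edge_softmax over log(w) gives exp(log w) / Σ exp(log w) = w / Σw,
        # i.e. the raw weights normalized over each paper's authors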
p = edge_softmax(sg['author', 'writes', 'paper'], torch.log(w).unsqueeze(dim=1))
sg.edges['writes_rev'].data['p'] = p.squeeze(dim=1)
sg.update_all(fn.u_mul_e('citation', 'p', 'c'), fn.sum('c', 'c'), etype='writes_rev')
author_citation = sg.nodes['author'].data['c']
_, aid = author_citation.topk(args.num_authors)
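        # dgl.NID stores each subgraph node's id in the original graph;
        # map the top-k local ids back to global author ids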
aid = sg.nodes['author'].data[dgl.NID][aid]
author_rank[i] = aid.tolist()
if args.use_field_name:
fields = [f['name'] for f in iter_json(DATA_DIR / 'oag/cs/mag_fields.txt')]
author_rank = {fields[i]: aid for i, aid in author_rank.items()}
with open(DATA_DIR / 'rank/author_rank_train.json', 'w') as f:
json.dump(author_rank, f)
    print('Results saved to', f.name)


def evaluate_ground_truth(args):
    """Evaluate the quality of the ground-truth training set against the validation set."""
with open(DATA_DIR / 'rank/author_rank_val.json') as f:
author_rank_val = json.load(f)
with open(DATA_DIR / 'rank/author_rank_train.json') as f:
author_rank_train = json.load(f)
fields = list(set(author_rank_val) & set(author_rank_train))
author_rank_val = {k: v for k, v in author_rank_val.items() if k in fields}
author_rank_train = {k: v for k, v in author_rank_train.items() if k in fields}
num_authors = OAGCSDataset()[0].num_nodes('author')
true_relevance = np.zeros((len(fields), num_authors), dtype=np.int32)
scores = np.zeros_like(true_relevance)
for i, f in enumerate(fields):
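        # Graded relevance from the AI 2000 rank: positions 0-9 map to relevance 10,
        # 10-19 to 9, ..., 90-99 to 1; unmatched authors (id -1) are skipped, and the
        # predicted score decreases linearly with the training-set rank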
for r, a in enumerate(author_rank_val[f]):
if a != -1:
true_relevance[i, a] = math.ceil((100 - r) / 10)
for r, a in enumerate(author_rank_train[f]):
scores[i, a] = len(author_rank_train[f]) - r
for k in (100, 50, 20, 10, 5):
        print('nDCG@{0}={1:.4f}\tPrecision@{0}={2:.4f}\tRecall@{0}={3:.4f}'.format(
k, ndcg_score(true_relevance, scores, k=k, ignore_ties=True),
sum(precision_at_k(author_rank_val[f], author_rank_train[f], k) for f in fields) / len(fields),
sum(recall_at_k(author_rank_val[f], author_rank_train[f], k) for f in fields) / len(fields)
))


def sample_triplets(args):
set_random_seed(args.seed)
with open(DATA_DIR / 'rank/author_rank_train.json') as f:
author_rank = json.load(f)
    # A triplet (t, ap, an) means that in field t, author ap is ranked above author an
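    # An "easy" pair is easy_margin positions apart in the ranking (clearly
    # distinguishable); a "hard" pair is only hard_margin positions apart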
triplets = []
for fid, aid in author_rank.items():
fid = int(fid)
n = len(aid)
easy_margin, hard_margin = int(n * args.easy_margin), int(n * args.hard_margin)
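        # (n - easy_margin) + (n - hard_margin) anchor positions exist in total,
        # so cap the number of triplets at that and at max_num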
num_triplets = min(args.max_num, 2 * n - easy_margin - hard_margin)
num_hard = int(num_triplets * args.hard_ratio)
num_easy = num_triplets - num_hard
triplets.extend(
(fid, aid[i], aid[i + easy_margin])
for i in random.sample(range(n - easy_margin), num_easy)
)
triplets.extend(
(fid, aid[i], aid[i + hard_margin])
for i in random.sample(range(n - hard_margin), num_hard)
)
with open(DATA_DIR / 'rank/author_rank_triplets.txt', 'w') as f:
for t, ap, an in triplets:
f.write(f'{t} {ap} {an}\n')
    print('Results saved to', f.name)


def main():
    parser = argparse.ArgumentParser(description='Build the scholar ranking dataset from the oag-cs dataset')
    subparsers = parser.add_subparsers()
    build_val_parser = subparsers.add_parser('build-val', help='build the scholar ranking validation set')
    build_val_parser.add_argument('--use-field-name', action='store_true', help='use field names instead of ids (for debugging)')
    build_val_parser.set_defaults(func=build_ground_truth_valid)
    build_train_parser = subparsers.add_parser('build-train', help='build the scholar ranking training set')
    build_train_parser.add_argument('--num-papers', type=int, default=5000, help='minimum number of papers for a field to be kept')
    build_train_parser.add_argument('--num-authors', type=int, default=100, help='number of top authors to keep per field')
    build_train_parser.add_argument('--use-field-name', action='store_true', help='use field names instead of ids (for debugging)')
    build_train_parser.set_defaults(func=build_ground_truth_train)
    evaluate_parser = subparsers.add_parser('eval', help='evaluate the quality of the ground-truth training set')
    evaluate_parser.set_defaults(func=evaluate_ground_truth)
    sample_parser = subparsers.add_parser('sample', help='sample triplets')
    sample_parser.add_argument('--seed', type=int, default=0, help='random seed')
    sample_parser.add_argument('--max-num', type=int, default=100, help='maximum number of triplets sampled per field')
    sample_parser.add_argument('--easy-margin', type=float, default=0.2, help='easy-sample margin (fraction of the ranking length)')
    sample_parser.add_argument('--hard-margin', type=float, default=0.05, help='hard-sample margin (fraction of the ranking length)')
    sample_parser.add_argument('--hard-ratio', type=float, default=0.5, help='fraction of hard samples')
    sample_parser.set_defaults(func=sample_triplets)
args = parser.parse_args()
print(args)
args.func(args)


if __name__ == '__main__':
main()