203 lines
9.0 KiB
Python
203 lines
9.0 KiB
Python
|
import argparse
|
|||
|
import json
|
|||
|
import math
|
|||
|
import random
|
|||
|
|
|||
|
import dgl
|
|||
|
import dgl.function as fn
|
|||
|
import django
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
from dgl.ops import edge_softmax
|
|||
|
from sklearn.metrics import ndcg_score
|
|||
|
from tqdm import tqdm
|
|||
|
|
|||
|
from gnnrec.config import DATA_DIR
|
|||
|
from gnnrec.hge.utils import set_random_seed, add_reverse_edges
|
|||
|
from gnnrec.kgrec.data import OAGCSDataset
|
|||
|
from gnnrec.kgrec.utils import iter_json, precision_at_k, recall_at_k
|
|||
|
|
|||
|
|
|||
|
def build_ground_truth_valid(args):
    """Build the author-ranking ground-truth VALIDATION set.

    Matches the scholar rankings scraped from AI 2000 against author ids in
    the local database: for each scholar, first try (name, institution), then
    fall back to name only, picking the most-cited match; -1 marks a scholar
    that could not be matched at all.

    :param args: argparse namespace; uses ``args.use_field_name`` — when set,
        keep human-readable field names as keys (debugging), otherwise map
        them to mag_fields.txt indices.
    """
    # AI 2000 category name -> canonical field name used in this project
    field_map = {
        'AAAI/IJCAI': 'artificial intelligence',
        'Machine Learning': 'machine learning',
        'Computer Vision': 'computer vision',
        'Natural Language Processing': 'natural language processing',
        'Robotics': 'robotics',
        'Knowledge Engineering': 'knowledge engineering',
        'Speech Recognition': 'speech recognition',
        'Data Mining': 'data mining',
        'Information Retrieval and Recommendation': 'information retrieval',
        'Database': 'database',
        'Human-Computer Interaction': 'human computer interaction',
        'Computer Graphics': 'computer graphics',
        'Multimedia': 'multimedia',
        'Visualization': 'visualization',
        'Security and Privacy': 'security privacy',
        'Computer Networking': 'computer network',
        'Computer Systems': 'operating system',
        'Theory': 'theory',
        'Chip Technology': 'chip',
        'Internet of Things': 'internet of things',
    }
    with open(DATA_DIR / 'rank/ai2000.json', encoding='utf8') as f:
        ai2000_author_rank = json.load(f)

    # Django ORM is only needed here; defer setup until this command runs.
    django.setup()
    from rank.models import Author

    author_rank = {}
    for field, scholars in ai2000_author_rank.items():
        aid = []
        for s in scholars:
            # Prefer an exact (name, institution) match; ties broken by citations.
            qs = Author.objects.filter(name=s['name'], institution__name=s['org']).order_by('-n_citation')
            if qs.exists():
                aid.append(qs[0].id)
            else:
                # Fall back to name-only match; -1 when nothing matches.
                qs = Author.objects.filter(name=s['name']).order_by('-n_citation')
                aid.append(qs[0].id if qs.exists() else -1)
        author_rank[field_map[field]] = aid
    if not args.use_field_name:
        field2id = {f['name']: i for i, f in enumerate(iter_json(DATA_DIR / 'oag/cs/mag_fields.txt'))}
        author_rank = {field2id[f]: aid for f, aid in author_rank.items()}

    # encoding added for consistency with the read above (output must not
    # depend on the platform's default locale encoding)
    with open(DATA_DIR / 'rank/author_rank_val.json', 'w', encoding='utf8') as f:
        json.dump(author_rank, f)
    print('结果已保存到', f.name)
|
|||
|
|
|||
|
|
|||
|
def build_ground_truth_train(args):
    """Build the author-ranking ground-truth TRAINING set.

    For every field with enough papers, rank authors by a weighted sum of the
    (log-scaled) citation counts of their papers in that field, and save the
    top authors per field to rank/author_rank_train.json.

    :param args: argparse namespace; uses ``num_papers`` (field size threshold),
        ``num_authors`` (top-k per field) and ``use_field_name``.
    """
    data = OAGCSDataset()
    g = data[0]
    # log1p compresses the heavy-tailed citation distribution
    g.nodes['paper'].data['citation'] = g.nodes['paper'].data['citation'].float().log1p()
    g.edges['writes'].data['order'] = g.edges['writes'].data['order'].float()
    apg = g['author', 'writes', 'paper']

    # 1. Keep only fields that have at least args.num_papers papers
    field_in_degree, fid = g.in_degrees(g.nodes('field'), etype='has_field').sort(descending=True)
    fid = fid[field_in_degree >= args.num_papers].tolist()

    # 2. For each field: recall its papers, build the author-paper subgraph,
    #    and rank authors by the weighted sum of their papers' citations
    author_rank = {}
    for i in tqdm(fid):
        pid, _ = g.in_edges(i, etype='has_field')
        sg = add_reverse_edges(dgl.in_subgraph(apg, {'paper': pid}, relabel_nodes=True))

        # The k-th author gets weight 1/k; the last author is treated as the
        # corresponding author and gets weight 1/2.
        sg.edges['writes'].data['w'] = 1.0 / sg.edges['writes'].data['order']
        # Per-paper min of 1/order == weight of the highest-order (last) author
        sg.update_all(fn.copy_e('w', 'w'), fn.min('w', 'mw'), etype='writes')
        sg.apply_edges(fn.copy_u('mw', 'mw'), etype='writes_rev')
        # NOTE(review): comparing 'writes' edge data against 'writes_rev' edge
        # data elementwise assumes add_reverse_edges keeps the two edge sets
        # aligned 1:1 in the same order — looks intentional, verify in utils.
        w, mw = sg.edges['writes'].data.pop('w'), sg.edges['writes_rev'].data.pop('mw')
        w[w == mw] = 0.5

        # Normalize author weights within each paper — softmax of log(w) is
        # w / sum(w) — then sum each author's papers' citations with those weights.
        p = edge_softmax(sg['author', 'writes', 'paper'], torch.log(w).unsqueeze(dim=1))
        sg.edges['writes_rev'].data['p'] = p.squeeze(dim=1)
        sg.update_all(fn.u_mul_e('citation', 'p', 'c'), fn.sum('c', 'c'), etype='writes_rev')
        author_citation = sg.nodes['author'].data['c']

        # Map top-k subgraph-local author ids back to ids in the full graph
        _, aid = author_citation.topk(args.num_authors)
        aid = sg.nodes['author'].data[dgl.NID][aid]
        author_rank[i] = aid.tolist()
    if args.use_field_name:
        fields = [f['name'] for f in iter_json(DATA_DIR / 'oag/cs/mag_fields.txt')]
        author_rank = {fields[i]: aid for i, aid in author_rank.items()}

    with open(DATA_DIR / 'rank/author_rank_train.json', 'w') as f:
        json.dump(author_rank, f)
    print('结果已保存到', f.name)
|
|||
|
|
|||
|
|
|||
|
def evaluate_ground_truth(args):
    """Evaluate the quality of the ground-truth training set against the
    AI 2000 validation set, printing nDCG / Precision / Recall at several k.
    """
    with open(DATA_DIR / 'rank/author_rank_val.json') as fp:
        val_rank = json.load(fp)
    with open(DATA_DIR / 'rank/author_rank_train.json') as fp:
        train_rank = json.load(fp)
    # Only fields present in both sets are comparable
    fields = list(set(val_rank) & set(train_rank))
    val_rank = {fld: val_rank[fld] for fld in fields}
    train_rank = {fld: train_rank[fld] for fld in fields}

    num_authors = OAGCSDataset()[0].num_nodes('author')
    true_relevance = np.zeros((len(fields), num_authors), dtype=np.int32)
    scores = np.zeros_like(true_relevance)
    for row, fld in enumerate(fields):
        for r, a in enumerate(val_rank[fld]):
            if a == -1:
                continue  # unmatched scholar, skip
            # Bucket validation ranks into relevance grades 10..1 (10 per bucket)
            true_relevance[row, a] = math.ceil((100 - r) / 10)
        ranked = train_rank[fld]
        for r, a in enumerate(ranked):
            # Higher predicted score for earlier training ranks
            scores[row, a] = len(ranked) - r

    for k in (100, 50, 20, 10, 5):
        ndcg = ndcg_score(true_relevance, scores, k=k, ignore_ties=True)
        prec = sum(precision_at_k(val_rank[fld], train_rank[fld], k) for fld in fields) / len(fields)
        rec = sum(recall_at_k(val_rank[fld], train_rank[fld], k) for fld in fields) / len(fields)
        print('nDGC@{0}={1:.4f}\tPrecision@{0}={2:.4f}\tRecall@{0}={3:.4f}'.format(k, ndcg, prec, rec))
|
|||
|
|
|||
|
|
|||
|
def sample_triplets(args):
    """Sample ranking triplets from the ground-truth training set.

    A triplet (t, ap, an) means: for field t, author ap ranks ahead of author
    an. Easy triplets pair authors ``easy_margin`` positions apart; hard
    triplets use the smaller ``hard_margin`` gap. Results are written one
    triplet per line to rank/author_rank_triplets.txt.

    :param args: argparse namespace; uses ``seed``, ``max_num``,
        ``easy_margin``, ``hard_margin`` (both fractions of the field's list
        length) and ``hard_ratio``.
    """
    set_random_seed(args.seed)
    with open(DATA_DIR / 'rank/author_rank_train.json') as f:
        author_rank = json.load(f)

    triplets = []
    for fid, aid in author_rank.items():
        # JSON keys are strings; assumes the training file was built WITHOUT
        # --use-field-name so keys are numeric field ids — TODO confirm
        fid = int(fid)
        n = len(aid)
        easy_margin, hard_margin = int(n * args.easy_margin), int(n * args.hard_margin)
        num_triplets = min(args.max_num, 2 * n - easy_margin - hard_margin)
        # Clamp each share to its population size: random.sample(range(m), k)
        # raises ValueError when k > m, which the original code hit whenever
        # max_num was large relative to the margins.
        num_hard = min(int(num_triplets * args.hard_ratio), n - hard_margin)
        num_easy = min(num_triplets - num_hard, n - easy_margin)
        triplets.extend(
            (fid, aid[i], aid[i + easy_margin])
            for i in random.sample(range(n - easy_margin), num_easy)
        )
        triplets.extend(
            (fid, aid[i], aid[i + hard_margin])
            for i in random.sample(range(n - hard_margin), num_hard)
        )

    with open(DATA_DIR / 'rank/author_rank_triplets.txt', 'w') as f:
        for t, ap, an in triplets:
            f.write(f'{t} {ap} {an}\n')
    print('结果已保存到', f.name)
|
|||
|
|
|||
|
|
|||
|
def main():
    """CLI entry point: dispatch to the selected dataset-construction subcommand."""
    parser = argparse.ArgumentParser(description='基于oag-cs数据集构造学者排名数据集')
    subparsers = parser.add_subparsers()

    build_val_parser = subparsers.add_parser('build-val', help='构造学者排名验证集')
    build_val_parser.add_argument('--use-field-name', action='store_true', help='使用领域名称(用于调试)')
    build_val_parser.set_defaults(func=build_ground_truth_valid)

    build_train_parser = subparsers.add_parser('build-train', help='构造学者排名训练集')
    build_train_parser.add_argument('--num-papers', type=int, default=5000, help='筛选领域的论文数阈值')
    build_train_parser.add_argument('--num-authors', type=int, default=100, help='每个领域取top k的学者数量')
    build_train_parser.add_argument('--use-field-name', action='store_true', help='使用领域名称(用于调试)')
    build_train_parser.set_defaults(func=build_ground_truth_train)

    evaluate_parser = subparsers.add_parser('eval', help='评估ground truth训练集的质量')
    evaluate_parser.set_defaults(func=evaluate_ground_truth)

    sample_parser = subparsers.add_parser('sample', help='采样三元组')
    sample_parser.add_argument('--seed', type=int, default=0, help='随机数种子')
    sample_parser.add_argument('--max-num', type=int, default=100, help='每个领域采样三元组最大数量')
    sample_parser.add_argument('--easy-margin', type=float, default=0.2, help='简单样本间隔(百分比)')
    sample_parser.add_argument('--hard-margin', type=float, default=0.05, help='困难样本间隔(百分比)')
    sample_parser.add_argument('--hard-ratio', type=float, default=0.5, help='困难样本比例')
    sample_parser.set_defaults(func=sample_triplets)

    args = parser.parse_args()
    # Subcommands created by add_subparsers() are optional in Python 3, so a
    # bare invocation used to crash with AttributeError on args.func; show
    # usage instead.
    if not hasattr(args, 'func'):
        parser.print_help()
        return
    print(args)
    args.func(args)


if __name__ == '__main__':
    main()
|