33 lines
883 B
Python
33 lines
883 B
Python
import json
|
||
|
||
from gnnrec.config import DATA_DIR
|
||
|
||
|
||
class Context:
|
||
|
||
def __init__(self, recall_ctx, author_rank):
|
||
"""学者排名模块上下文对象
|
||
|
||
:param recall_ctx: gnnrec.kgrec.recall.Context
|
||
:param author_rank: {field_id: [author_id]} 领域学者排名
|
||
"""
|
||
self.recall_ctx = recall_ctx
|
||
# 之后需要:author_embeds
|
||
self.author_rank = author_rank
|
||
|
||
|
||
def get_context(recall_ctx):
|
||
with open(DATA_DIR / 'rank/author_rank_train.json') as f:
|
||
author_rank = json.load(f)
|
||
return Context(recall_ctx, author_rank)
|
||
|
||
|
||
def rank(ctx, query):
|
||
"""根据输入的查询词在oag-cs数据集计算学者排名
|
||
|
||
:param ctx: Context 上下文对象
|
||
:param query: str 查询词
|
||
:return: List[float], List[int] 学者得分和id,按得分降序排序
|
||
"""
|
||
return [], ctx.author_rank.get(query, [])
|