GNNRecom/gnnrec/kgrec/rank.py

33 lines
883 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from gnnrec.config import DATA_DIR
class Context:
    """Context object for the scholar (author) ranking module.

    :param recall_ctx: gnnrec.kgrec.recall.Context
    :param author_rank: {field_id: [author_id]} per-field scholar ranking
    """

    def __init__(self, recall_ctx, author_rank):
        self.author_rank = author_rank
        # TODO: author_embeds will also be needed on this context later.
        self.recall_ctx = recall_ctx
def get_context(recall_ctx):
    """Build the context object for the scholar ranking module.

    Loads the precomputed ranking from
    ``DATA_DIR / 'rank/author_rank_train.json'`` and wraps it, together with
    the recall context, in a ``Context``.

    :param recall_ctx: gnnrec.kgrec.recall.Context
    :return: Context
    """
    # JSON text is UTF-8 by specification; without an explicit encoding,
    # open() uses the platform locale encoding (e.g. GBK on Chinese Windows),
    # which can mis-decode or reject non-ASCII content in the file.
    with open(DATA_DIR / 'rank/author_rank_train.json', encoding='utf-8') as f:
        author_rank = json.load(f)
    return Context(recall_ctx, author_rank)
def rank(ctx, query):
    """Compute the scholar ranking for a query word on the oag-cs dataset.

    This implementation is a lookup into the precomputed ranking stored in
    ``ctx.author_rank``. No scores are computed yet, so the score list is
    always empty.

    :param ctx: Context, context object (only ``ctx.author_rank`` is read)
    :param query: str, query word, used as a key into ``ctx.author_rank``
    :return: List[float], List[int] -- scholar scores and ids, intended to be
        sorted by score in descending order. Scores are currently always
        ``[]``. The ids are a fresh list in the order stored for ``query``,
        or ``[]`` if the query is not present.
    """
    # Return a copy so callers that mutate the result (sort, append, pop, ...)
    # cannot corrupt the ranking shared by every later request on this ctx.
    return [], list(ctx.author_rank.get(query, ()))