GNNRecom/gnnrec/kgrec/recall.py

54 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from gnnrec.config import DATA_DIR, MODEL_DIR
from gnnrec.kgrec.data import OAGCSContrastDataset
from gnnrec.kgrec.scibert import ContrastiveSciBERT
class Context:
    """Context object for the paper-recall module.

    Bundles the precomputed paper embeddings with the fine-tuned SciBERT
    encoder so that :func:`recall` can be called repeatedly with one object.
    """

    def __init__(self, paper_embeds, scibert_model):
        """Keep references to the recall resources.

        :param paper_embeds: tensor(N, d), title embedding of every paper
        :param scibert_model: ContrastiveSciBERT, the fine-tuned SciBERT model
        """
        self.scibert_model = scibert_model
        self.paper_embeds = paper_embeds
def get_context():
    """Load the resources needed by ``recall`` and wrap them in a Context.

    Loads the paper title embeddings from ``DATA_DIR/oag/cs/paper_feat.pkl`` and
    the fine-tuned SciBERT weights from ``MODEL_DIR/scibert.pt``, both mapped
    onto the CPU, and puts the model into evaluation mode for inference.

    :return: Context
    """
    # NOTE: torch.load unpickles its input; only point it at trusted, locally
    # produced files.
    paper_embeds = torch.load(DATA_DIR / 'oag/cs/paper_feat.pkl', map_location='cpu')
    # NOTE(review): 128 / 0.07 presumably are the output dim and temperature used
    # when the checkpoint was trained; they must match it for load_state_dict.
    scibert_model = ContrastiveSciBERT(128, 0.07)
    scibert_model.load_state_dict(torch.load(MODEL_DIR / 'scibert.pt', map_location='cpu'))
    # A freshly constructed nn.Module is in training mode; switch to eval so
    # layers such as dropout (if present) behave deterministically at inference.
    scibert_model.eval()
    return Context(paper_embeds, scibert_model)
def recall(ctx, query, k=1000):
    """Recall papers from the oag-cs dataset for a query string.

    The query is embedded with the context's SciBERT model and L2-normalised.
    Every paper embedding is then scored against it by dot product, and the
    k highest-scoring papers are returned.

    :param ctx: Context, recall context object
    :param query: str, query text
    :param k: int, optional, number of papers to recall (default 1000); values
        larger than the number of papers are clamped to that number
    :return: List[float], List[int] similarity scores and ids of the top-k
        papers, sorted by descending similarity
    """
    # Tensor.topk raises if k exceeds the number of candidates, so cap it.
    k = min(k, ctx.paper_embeds.shape[0])
    # Pure inference: don't build an autograd graph through the model.
    with torch.no_grad():
        q = ctx.scibert_model.get_embeds(query)  # (1, d)
        q = q / q.norm()
        # Dot product with the unit query vector; this equals cosine similarity
        # only if the paper_embeds rows are unit-norm (NOTE(review): confirm how
        # paper_feat.pkl was produced).
        similarity = torch.mm(ctx.paper_embeds, q.t()).squeeze(dim=1)  # (N,)
        score, pid = similarity.topk(k, dim=0)
    return score.tolist(), pid.tolist()
def main():
    """Interactive demo: read queries from stdin and print the top 10 papers.

    Each result is printed as ``<score>\\t<title>``. Blank queries are skipped.
    End of input (Ctrl-D) or Ctrl-C at the prompt exits the loop cleanly.
    """
    ctx = get_context()
    # NOTE(review): assumes item [0] of each dataset entry is the paper title —
    # confirm against OAGCSContrastDataset.
    paper_titles = OAGCSContrastDataset(DATA_DIR / 'oag/cs/mag_papers.txt', 'all')
    while True:
        try:
            query = input('query> ').strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line so the shell starts on a fresh line.
            print()
            break
        if not query:
            continue
        score, pid = recall(ctx, query, 10)
        for s, p in zip(score, pid):
            print('{:.4f}\t{}'.format(s, paper_titles[p][0]))


if __name__ == '__main__':
    main()