54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
|
import torch
|
|||
|
|
|||
|
from gnnrec.config import DATA_DIR, MODEL_DIR
|
|||
|
from gnnrec.kgrec.data import OAGCSContrastDataset
|
|||
|
from gnnrec.kgrec.scibert import ContrastiveSciBERT
|
|||
|
|
|||
|
|
|||
|
class Context:
    """Context object for the paper recall module.

    A plain holder: both constructor arguments are stored unchanged as
    attributes of the same name.
    """

    def __init__(self, paper_embeds, scibert_model):
        """Keep the state needed to answer recall queries.

        :param paper_embeds: tensor(N, d) paper title embeddings
        :param scibert_model: ContrastiveSciBERT fine-tuned SciBERT model
        """
        self.paper_embeds, self.scibert_model = paper_embeds, scibert_model
|
|||
|
|
|||
|
|
|||
|
def get_context():
    """Load the paper embeddings and the SciBERT model onto the CPU.

    :return: Context built from the object read from
        ``DATA_DIR / 'oag/cs/paper_feat.pkl'`` and a
        ``ContrastiveSciBERT(128, 0.07)`` whose weights are restored from
        ``MODEL_DIR / 'scibert.pt'``
    """
    # NOTE(review): torch.load unpickles the file contents; only point these
    # paths at files you trust.
    paper_embeds = torch.load(DATA_DIR / 'oag/cs/paper_feat.pkl', map_location='cpu')

    scibert_model = ContrastiveSciBERT(128, 0.07)
    state_dict = torch.load(MODEL_DIR / 'scibert.pt', map_location='cpu')
    scibert_model.load_state_dict(state_dict)
    # The model is only queried for embeddings here, so put it in evaluation
    # mode to disable training-time behaviour such as dropout.
    # NOTE(review): assumes ContrastiveSciBERT is a torch.nn.Module (it exposes
    # load_state_dict) - confirm.
    scibert_model.eval()

    return Context(paper_embeds, scibert_model)
|
|||
|
|
|||
|
|
|||
|
def recall(ctx, query, k=1000):
    """Recall papers from the oag-cs dataset for the given query.

    The query is embedded with ``ctx.scibert_model``, scaled to unit norm and
    scored against every row of ``ctx.paper_embeds`` by dot product.

    :param ctx: Context context object
    :param query: str query text
    :param k: int, optional number of papers to recall, 1000 by default;
        capped at the number of rows in ``ctx.paper_embeds``
    :return: List[float], List[int] similarities and ids of the top k papers,
        sorted by similarity in descending order
    """
    # Inference only: do not record an autograd graph through the model.
    with torch.no_grad():
        q = ctx.scibert_model.get_embeds(query)  # (1, d)
        q = q / q.norm()
        similarity = torch.mm(ctx.paper_embeds, q.t()).squeeze(dim=1)  # (N,)
        # topk raises when k exceeds the number of candidates, so cap it.
        k = min(k, similarity.shape[0])
        score, pid = similarity.topk(k, dim=0)
    return score.tolist(), pid.tolist()
|
|||
|
|
|||
|
|
|||
|
def main():
    """Run an interactive prompt that prints the top 10 recalled papers per query.

    Each result line is the similarity score followed by element 0 of the
    matching ``OAGCSContrastDataset`` item (presumably the paper title).
    Blank queries are ignored; the loop ends on end-of-input or Ctrl-C at the
    prompt.
    """
    ctx = get_context()
    paper_titles = OAGCSContrastDataset(DATA_DIR / 'oag/cs/mag_papers.txt', 'all')
    while True:
        try:
            query = input('query> ').strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line and exit cleanly instead of with a traceback.
            print()
            return
        if not query:
            # Nothing to search for; ask again.
            continue
        score, pid = recall(ctx, query, 10)
        for s, p in zip(score, pid):
            print('{:.4f}\t{}'.format(s, paper_titles[p][0]))
|
|||
|
|
|||
|
|
|||
|
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|