GNNRecom/gnnrec/kgrec/scibert.py

62 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
class ContrastiveSciBERT(nn.Module):
    """SciBERT encoder with a linear projection head, trained with a contrastive loss."""

    def __init__(self, out_dim, tau, device='cpu'):
        """SciBERT model for contrastive learning.

        :param out_dim: int, dimension of the output features
        :param tau: float, temperature parameter τ dividing the cosine similarities
        :param device: torch.device, optional, device for the encoder and projection head; defaults to CPU
        """
        super().__init__()
        self.tau = tau
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        self.model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased').to(device)
        # The projection head consumes the encoder's output, so it must live on the
        # same device; otherwise get_embeds raises a device-mismatch error on GPU
        # unless the caller separately moves the whole module.
        self.linear = nn.Linear(self.model.config.hidden_size, out_dim).to(device)

    def get_embeds(self, texts, max_length=64):
        """Encode texts into vectors.

        :param texts: List[str], input texts, length N
        :param max_length: int, optional, length every text is padded/truncated to; defaults to 64
        :return: tensor(N, d_out), projected pooler output of the encoder
        """
        encoded = self.tokenizer(
            texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt'
        ).to(self.device)
        return self.linear(self.model(**encoded).pooler_output)

    def calc_sim(self, texts_a, texts_b):
        """Compute the similarity between two groups of texts.

        :param texts_a: List[str], input texts A, length N
        :param texts_b: List[str], input texts B, length N
        :return: tensor(N, N), similarity matrix with S[i, j] = cos(a[i], b[j]) / τ
        """
        # F.normalize divides by max(||v||, eps), so an all-zero embedding yields a
        # zero row instead of NaNs (plain division by the norm would give 0/0).
        embeds_a = F.normalize(self.get_embeds(texts_a), dim=1)  # (N, d_out), unit rows
        embeds_b = F.normalize(self.get_embeds(texts_b), dim=1)  # (N, d_out), unit rows
        return embeds_a @ embeds_b.t() / self.tau

    def forward(self, texts_a, texts_b):
        """Compute the symmetric contrastive loss between two groups of paired texts.

        texts_a[i] and texts_b[i] are treated as a positive pair; every other
        combination in the batch acts as a negative.

        :param texts_a: List[str], input texts A, length N
        :param texts_b: List[str], input texts B, length N
        :return: (tensor(N, N), scalar tensor) similarity matrix of A against B,
            and the contrastive loss averaged over both directions
        :raises ValueError: if texts_a and texts_b have different lengths
        """
        # Checked up front so the mismatch is reported before two encoder passes,
        # and with a clear message instead of a cross_entropy batch-size error.
        if len(texts_a) != len(texts_b):
            raise ValueError(
                f'texts_a and texts_b must have the same length, got {len(texts_a)} and {len(texts_b)}'
            )
        # Each row of logits_ab acts as unnormalized class scores whose correct class
        # is the diagonal entry, so the contrastive loss reduces to cross entropy.
        logits_ab = self.calc_sim(texts_a, texts_b)
        labels = torch.arange(logits_ab.shape[0], device=logits_ab.device)
        loss_ab = F.cross_entropy(logits_ab, labels)
        loss_ba = F.cross_entropy(logits_ab.t(), labels)  # B -> A direction
        return logits_ab, (loss_ab + loss_ba) / 2