GNNRecom/gnnrec/kgrec/scibert.py

62 lines
2.5 KiB
Python
Raw Normal View History

2021-11-16 07:04:52 +00:00
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
class ContrastiveSciBERT(nn.Module):
    """SciBERT encoder with a linear projection head for contrastive learning.

    Texts are tokenized, passed through the pretrained encoder, and the
    encoder's ``pooler_output`` is projected to ``out_dim`` features.
    ``forward`` treats the i-th text of each input list as a positive pair
    and every other combination in the batch as a negative.
    """

    # Hugging Face checkpoint used for both the tokenizer and the encoder.
    PRETRAINED_NAME = 'allenai/scibert_scivocab_uncased'

    def __init__(self, out_dim, tau, device='cpu'):
        """Load the tokenizer and encoder and create the projection layer.

        :param out_dim: int, dimension of the output embeddings
        :param tau: float, temperature τ the cosine similarities are divided by
        :param device: torch.device or str, optional, defaults to CPU
        """
        super().__init__()
        self.tau = tau
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(self.PRETRAINED_NAME)
        self.model = AutoModel.from_pretrained(self.PRETRAINED_NAME).to(device)
        # The projection head has to live on the same device as the encoder;
        # otherwise get_embeds() fails with a device mismatch off-CPU.
        self.linear = nn.Linear(self.model.config.hidden_size, out_dim).to(device)

    def get_embeds(self, texts, max_length=64):
        """Encode texts into embedding vectors.

        :param texts: List[str], input texts, length N
        :param max_length: int, optional, every text is padded/truncated to
            this many tokens, defaults to 64
        :return: tensor(N, d_out)
        """
        # Follow the parameters' actual device rather than the value stored at
        # construction time, so the module still works after a later .to().
        target_device = self.linear.weight.device
        encoded = self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        ).to(target_device)
        pooled = self.model(**encoded).pooler_output
        return self.linear(pooled)

    def calc_sim(self, texts_a, texts_b):
        """Compute the temperature-scaled cosine similarity of two text groups.

        :param texts_a: List[str], input texts A, length N
        :param texts_b: List[str], input texts B, length N
        :return: tensor(N, N), similarity matrix S[i, j] = cos(a[i], b[j]) / τ
        """
        # F.normalize clamps the norm with a small epsilon, so an all-zero
        # embedding yields zero similarity instead of NaN (0 / 0).
        embeds_a = F.normalize(self.get_embeds(texts_a), dim=1)  # (N, d_out)
        embeds_b = F.normalize(self.get_embeds(texts_b), dim=1)  # (N, d_out)
        return embeds_a @ embeds_b.t() / self.tau

    def forward(self, texts_a, texts_b):
        """Compute the symmetric contrastive loss between two text groups.

        :param texts_a: List[str], input texts A, length N
        :param texts_b: List[str], input texts B, length N
        :return: (tensor(N, N), 0-dim tensor), the A-to-B similarity matrix
            and the contrastive loss averaged over both directions
        :raises ValueError: if the two lists differ in length
        """
        # Fail before running the encoder: the i-th texts form the positive
        # pairs, so the two lists must line up one-to-one.
        if len(texts_a) != len(texts_b):
            raise ValueError(
                'texts_a and texts_b must have the same length, got '
                f'{len(texts_a)} and {len(texts_b)}'
            )
        # The similarities act as logits whose correct class for row i is
        # column i, so the contrastive loss reduces to a cross-entropy loss.
        logits_ab = self.calc_sim(texts_a, texts_b)
        logits_ba = logits_ab.t()
        labels = torch.arange(logits_ab.size(0), device=logits_ab.device)
        loss_ab = F.cross_entropy(logits_ab, labels)
        loss_ba = F.cross_entropy(logits_ba, labels)
        return logits_ab, (loss_ab + loss_ba) / 2