import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
||
class ContrastiveSciBERT(nn.Module):
    """SciBERT encoder with a linear projection head, for contrastive learning.

    :param out_dim: int, output feature dimension
    :param tau: float, temperature parameter τ
    :param device: torch.device or str, optional, defaults to 'cpu'
    """

    def __init__(self, out_dim, tau, device='cpu'):
        super().__init__()
        self.tau = tau
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        self.model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased').to(device)
        # Bug fix: the projection head must live on the same device as the
        # backbone; without .to(device) get_embeds raises a device-mismatch
        # error whenever device != 'cpu'.
        self.linear = nn.Linear(self.model.config.hidden_size, out_dim).to(device)

    def get_embeds(self, texts, max_length=64):
        """Encode texts into vectors.

        :param texts: List[str], N input texts
        :param max_length: int, optional, max padding length, defaults to 64
        :return: tensor(N, d_out)
        """
        encoded = self.tokenizer(
            texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt'
        ).to(self.device)
        # pooler_output is BERT's pooled [CLS] representation; project it to d_out.
        return self.linear(self.model(**encoded).pooler_output)

    def calc_sim(self, texts_a, texts_b):
        """Compute the pairwise similarity of two lists of texts.

        :param texts_a: List[str], N input texts (group A)
        :param texts_b: List[str], N input texts (group B)
        :return: tensor(N, N) similarity matrix, S[i, j] = cos(a[i], b[j]) / τ
        """
        embeds_a = self.get_embeds(texts_a)  # (N, d_out)
        embeds_b = self.get_embeds(texts_b)  # (N, d_out)
        # F.normalize is eps-guarded, so a zero-norm embedding yields zeros
        # instead of NaN (the bare division by .norm() would produce NaN).
        embeds_a = F.normalize(embeds_a, dim=1)
        embeds_b = F.normalize(embeds_b, dim=1)
        return embeds_a @ embeds_b.t() / self.tau

    def forward(self, texts_a, texts_b):
        """Compute the symmetric contrastive loss for two lists of texts.

        :param texts_a: List[str], N input texts (group A)
        :param texts_b: List[str], N input texts (group B)
        :return: (tensor(N, N), float) similarity matrix of A vs B, contrastive loss
        """
        # logits_ab acts as class logits with the diagonal as positive pairs,
        # so the contrastive loss reduces to cross entropy (averaged over both
        # directions, A→B and B→A).
        logits_ab = self.calc_sim(texts_a, texts_b)
        logits_ba = logits_ab.t()
        labels = torch.arange(len(texts_a), device=self.device)
        loss_ab = F.cross_entropy(logits_ab, labels)
        loss_ba = F.cross_entropy(logits_ba, labels)
        return logits_ab, (loss_ab + loss_ba) / 2
|