132 lines
5.2 KiB
Python
132 lines
5.2 KiB
Python
|
import argparse
|
||
|
|
||
|
import torch
|
||
|
import torch.optim as optim
|
||
|
from torch.utils.data import DataLoader
|
||
|
from tqdm import tqdm
|
||
|
from transformers import get_linear_schedule_with_warmup
|
||
|
|
||
|
from gnnrec.config import DATA_DIR, MODEL_DIR
|
||
|
from gnnrec.hge.utils import set_random_seed, get_device, accuracy
|
||
|
from gnnrec.kgrec.data import OAGCSContrastDataset
|
||
|
from gnnrec.kgrec.scibert import ContrastiveSciBERT
|
||
|
from gnnrec.kgrec.utils import iter_json
|
||
|
|
||
|
|
||
|
def collate(samples):
    """Transpose a batch of per-sample tuples into one list per field.

    Used as ``collate_fn`` for the DataLoaders in this script, e.g.
    ``[(t1, k1), (t2, k2)] -> ([t1, t2], [k1, k2])``, so that a batch can be
    unpacked as ``titles, keywords = batch``.

    A tuple of lists is returned (rather than a lazy ``map`` object) so the
    batch can be unpacked, indexed, printed or iterated more than once
    without being silently exhausted.

    :param samples: sequence of equal-length tuples, one per sample
    :return: tuple with one list per tuple position
    """
    return tuple(list(field) for field in zip(*samples))
|
||
|
|
||
|
|
||
|
def train(args):
    """Fine-tune ContrastiveSciBERT on (title, keywords) pairs and save its weights.

    Each epoch trains on the 'train' split of ``oag/cs/mag_papers.txt``, then
    prints the mean training loss, the mean training score and the validation
    score (see ``score``). After the last epoch the model's state dict is
    written to ``MODEL_DIR / 'scibert.pt'``.

    :param args: parsed command-line arguments; reads ``seed``, ``device``,
        ``batch_size``, ``num_hidden``, ``tau``, ``lr`` and ``epochs``.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)

    raw_file = DATA_DIR / 'oag/cs/mag_papers.txt'
    train_dataset = OAGCSContrastDataset(raw_file, split='train')
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate)
    valid_dataset = OAGCSContrastDataset(raw_file, split='valid')
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate)

    model = ContrastiveSciBERT(args.num_hidden, args.tau, device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    # The scheduler is stepped once per batch, so it spans every batch of every epoch.
    total_steps = len(train_loader) * args.epochs
    # Warm up over the first 10% of steps. The scheduler takes an integer
    # step count, so truncate explicitly instead of passing a float.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(total_steps * 0.1), num_training_steps=total_steps
    )
    for epoch in range(args.epochs):
        model.train()
        losses, scores = [], []
        for titles, keywords in tqdm(train_loader):
            logits, loss = model(titles, keywords)
            # Diagonal targets: sample i is expected to match index i
            # (assumes logits is (batch, batch) — confirm in ContrastiveSciBERT).
            labels = torch.arange(len(titles), device=device)
            losses.append(loss.item())
            scores.append(score(logits, labels))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
        # evaluate() switches the model to eval mode; model.train() above switches it back.
        val_score = evaluate(valid_loader, model, device)
        print('Epoch {:d} | Loss {:.4f} | Train Acc {:.4f} | Val Acc {:.4f}'.format(
            epoch, sum(losses) / len(losses), sum(scores) / len(scores), val_score
        ))

    model_save_path = MODEL_DIR / 'scibert.pt'
    # Create the target directory if needed so the result of a long run is not lost at save time.
    model_save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), model_save_path)
    print('模型已保存到', model_save_path)
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def evaluate(loader, model, device):
    """Compute the average per-batch ``score`` of the model on a data loader.

    Puts ``model`` into eval mode (and leaves it there) and runs without
    gradient tracking. For every ``(titles, keywords)`` batch the similarity
    matrix from ``model.calc_sim`` is scored against diagonal targets
    ``0 .. len(titles) - 1``.

    :param loader: iterable of ``(titles, keywords)`` batches
    :param model: model exposing ``calc_sim(titles, keywords)``
    :param device: device on which the target indices are created
    :return: arithmetic mean of the per-batch scores
    :raises ZeroDivisionError: if ``loader`` yields no batches
    """
    model.eval()
    batch_scores = []
    for title_batch, keyword_batch in tqdm(loader):
        similarity = model.calc_sim(title_batch, keyword_batch)
        targets = torch.arange(len(title_batch), device=device)
        batch_scores.append(score(similarity, targets))
    total = sum(batch_scores)
    return total / len(batch_scores)
|
||
|
|
||
|
|
||
|
def score(logits, labels):
    """Symmetric in-batch matching accuracy.

    Averages two accuracies computed with the same ``labels``:
    - picking the best column for every row (argmax over dim 1), and
    - picking the best row for every column (argmax over dim 0).

    :param logits: 2-D similarity matrix
    :param labels: expected index for each row/column
    :return: mean of the two ``accuracy`` values
    """
    row_accuracy = accuracy(logits.argmax(dim=1), labels)
    column_accuracy = accuracy(logits.argmax(dim=0), labels)
    return (row_accuracy + column_accuracy) / 2
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def _embed_batches(model, batches):
    """Embed every text batch with ``model.get_embeds`` and L2-normalize the rows.

    ``torch.nn.functional.normalize`` clamps the norm at a tiny epsilon, so an
    all-zero embedding becomes zeros rather than NaN. A plain ``h / h.norm()``
    would produce NaN in that case.

    :param model: model exposing ``get_embeds(texts)``
    :param batches: iterable of text batches
    :return: CPU tensor with the embeddings of all batches concatenated along dim 0
    """
    h = torch.cat([model.get_embeds(texts).detach().cpu() for texts in batches])
    return torch.nn.functional.normalize(h, dim=1)


@torch.no_grad()
def infer(args):
    """Embed all papers and fields with the fine-tuned model and save the vectors.

    Loads the weights saved by ``train`` from ``MODEL_DIR / 'scibert.pt'``.
    It then writes unit-normalized title embeddings for every paper in
    ``oag/cs/mag_papers.txt`` (split 'all') to ``paper_feat.pkl``. Finally it
    writes unit-normalized embeddings of every field name in
    ``oag/cs/mag_fields.txt`` to ``field_feat.pkl``. Both output files are
    written to ``DATA_DIR / 'oag/cs'``.

    :param args: parsed command-line arguments; reads ``device``,
        ``num_hidden``, ``tau`` and ``batch_size``.
    """
    device = get_device(args.device)
    model = ContrastiveSciBERT(args.num_hidden, args.tau, device).to(device)
    model.load_state_dict(torch.load(MODEL_DIR / 'scibert.pt', map_location=device))
    model.eval()

    raw_path = DATA_DIR / 'oag/cs'
    dataset = OAGCSContrastDataset(raw_path / 'mag_papers.txt', split='all')
    # No shuffling: row i of the saved tensor corresponds to item i of the dataset.
    loader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)
    print('正在推断论文向量...')
    # Only titles are embedded; the keywords of each batch are ignored.
    paper_feat = _embed_batches(model, (titles for titles, _ in tqdm(loader)))  # (N_paper, d_hid)
    paper_feat_path = raw_path / 'paper_feat.pkl'
    torch.save(paper_feat, paper_feat_path)
    print('论文向量已保存到', paper_feat_path)

    field_names = [f['name'] for f in iter_json(raw_path / 'mag_fields.txt')]
    field_loader = DataLoader(field_names, batch_size=args.batch_size)
    print('正在推断领域向量...')
    field_feat = _embed_batches(model, tqdm(field_loader))  # (N_field, d_hid)
    field_feat_path = raw_path / 'field_feat.pkl'
    torch.save(field_feat, field_feat_path)
    print('领域向量已保存到', field_feat_path)
|
||
|
|
||
|
|
||
|
def _add_model_args(parser):
    """Register the device and model hyper-parameter options shared by both sub-commands.

    ``infer`` rebuilds ContrastiveSciBERT from ``--num-hidden``/``--tau``
    before loading the weights that ``train`` saved. Declaring these options
    in one place keeps their defaults from drifting apart between the two
    sub-commands.
    """
    parser.add_argument('--device', type=int, default=0, help='GPU设备')
    parser.add_argument('--num-hidden', type=int, default=128, help='隐藏层维数')
    parser.add_argument('--tau', type=float, default=0.07, help='温度参数')


def main():
    """Parse the command line and dispatch to the ``train`` or ``infer`` sub-command."""
    parser = argparse.ArgumentParser(description='通过论文标题和关键词的对比学习对SciBERT模型进行fine-tune')
    # required=True makes argparse print a usage error when no sub-command is
    # given; without it, args.func is missing and the call below raises
    # AttributeError. A dest is needed for that error message to render.
    subparsers = parser.add_subparsers(dest='command', required=True)

    train_parser = subparsers.add_parser('train', help='训练')
    train_parser.add_argument('--seed', type=int, default=42, help='随机数种子')
    _add_model_args(train_parser)
    train_parser.add_argument('--epochs', type=int, default=5, help='训练epoch数')
    train_parser.add_argument('--batch-size', type=int, default=64, help='批大小')
    train_parser.add_argument('--lr', type=float, default=5e-5, help='学习率')
    train_parser.set_defaults(func=train)

    infer_parser = subparsers.add_parser('infer', help='推断')
    _add_model_args(infer_parser)
    infer_parser.add_argument('--batch-size', type=int, default=64, help='批大小')
    infer_parser.set_defaults(func=infer)

    args = parser.parse_args()
    print(args)
    args.func(args)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not on import.
    main()
|