GNNRecom/gnnrec/hge/cs/train.py

102 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gnnrec.hge.cs.model import CorrectAndSmooth
from gnnrec.hge.utils import set_random_seed, get_device, load_data, calc_metrics, METRICS_STR
def train_base_model(base_model, feats, labels, train_idx, val_idx, test_idx, evaluator, args):
    """Fit the base predictor with Adam and print metrics after every epoch.

    Each epoch runs one forward pass over ``feats``, takes the cross-entropy
    loss on the ``train_idx`` rows only, performs a single optimiser step and
    then reports the result of ``evaluate`` on all three splits.

    Args:
        base_model: Module called as ``base_model(feats)`` to produce logits.
        feats: Input features handed to the model as a whole.
        labels: Targets; only the ``train_idx`` entries contribute to the loss.
        train_idx: Index selecting the training rows.
        val_idx: Index of validation rows (forwarded to ``evaluate``).
        test_idx: Index of test rows (forwarded to ``evaluate``).
        evaluator: Passed through to ``evaluate`` unchanged.
        args: Namespace providing ``lr`` and ``epochs``.
    """
    print('Training base model...')
    adam = optim.Adam(base_model.parameters(), lr=args.lr)
    log_template = 'Epoch {:d} | Loss {:.4f} | ' + METRICS_STR

    for epoch in range(args.epochs):
        # evaluate() leaves the model in eval mode, so re-enable training mode
        # at the start of every epoch.
        base_model.train()
        adam.zero_grad()
        output = base_model(feats)
        train_loss = F.cross_entropy(output[train_idx], labels[train_idx])
        train_loss.backward()
        adam.step()

        metrics = evaluate(base_model, feats, labels, train_idx, val_idx, test_idx, evaluator)
        print(log_template.format(epoch, train_loss.item(), *metrics))
def evaluate(model, feats, labels, train_idx, val_idx, test_idx, evaluator):
    """Score ``model`` on ``feats`` with gradient tracking disabled.

    Switches the model to eval mode (and leaves it there), runs one forward
    pass and returns whatever ``calc_metrics`` produces for the given
    train/val/test splits.
    """
    with torch.no_grad():
        model.eval()
        predictions = model(feats)
        return calc_metrics(predictions, labels, train_idx, val_idx, test_idx, evaluator)
def correct_and_smooth(base_model, g, feats, labels, train_idx, val_idx, test_idx, evaluator, args):
    """Refine the base model's predictions with Correct & Smooth and print test metrics.

    The base model's softmax output is passed, together with one-hot labels and
    the indices whose labels are treated as known (train + validation), to a
    ``CorrectAndSmooth`` instance configured from ``args``. The result is scored
    with ``calc_metrics`` and only the test accuracy / macro-F1 are printed.

    Args:
        base_model: Trained module called as ``base_model(feats)``.
        g: Graph handed to ``CorrectAndSmooth`` for propagation.
        feats: Input features for the base model.
        labels: Integer class labels; must be valid class ids (non-negative and
            smaller than the number of columns the base model outputs).
        train_idx: Index of training nodes.
        val_idx: Index of validation nodes.
        test_idx: Index of test nodes.
        evaluator: Passed through to ``calc_metrics`` unchanged.
        args: Namespace providing ``num_correct_layers``, ``correct_alpha``,
            ``correct_norm``, ``num_smooth_layers``, ``smooth_alpha``,
            ``smooth_norm`` and ``scale``.
    """
    print('Training C&S...')
    base_model.eval()
    cs = CorrectAndSmooth(
        args.num_correct_layers, args.correct_alpha, args.correct_norm,
        args.num_smooth_layers, args.smooth_alpha, args.smooth_norm, args.scale
    )
    # Both training and validation labels are treated as known.
    mask = torch.cat([train_idx, val_idx])

    # Nothing is optimised in this step, so skip autograd bookkeeping instead
    # of retaining the graph through every propagation layer.
    with torch.no_grad():
        # The propagation works on probabilities, so softmax is required here.
        base_pred = base_model(feats).softmax(dim=1)
        # Pin the one-hot width to the model's output width: without
        # num_classes, one_hot() infers it from labels.max() + 1, which is too
        # narrow whenever the highest class id does not occur in `labels`.
        one_hot_labels = F.one_hot(labels, num_classes=base_pred.shape[1]).float()
        logits = cs(g, one_hot_labels, base_pred, mask)

    # calc_metrics yields six values; the 3rd and 6th are the test metrics.
    _, _, test_acc, _, _, test_f1 = calc_metrics(logits, labels, train_idx, val_idx, test_idx, evaluator)
    print('Test Acc {:.4f} | Test Macro-F1 {:.4f}'.format(test_acc, test_f1))
def train(args):
    """Run the full pipeline: load data, train a linear base model, apply C&S.

    Args:
        args: Namespace providing ``seed``, ``device``, ``dataset``,
            ``prop_graph`` plus everything ``train_base_model`` and
            ``correct_and_smooth`` read from it.

    Raises:
        ValueError: If the dataset needs an external label-propagation graph
            (anything other than ``'acm'`` / ``'dblp'``) and ``args.prop_graph``
            is not set.
    """
    set_random_seed(args.seed)
    device = get_device(args.device)
    (data, _, feat, labels, _,
     train_idx, val_idx, test_idx, evaluator) = load_data(args.dataset, device)

    # Column-wise standardisation. A constant column has std == 0 (and a
    # single-row input has std == NaN); dividing by that would fill the column
    # with NaN/inf, so fall back to a divisor of 1 for such columns.
    std = feat.std(dim=0)
    std = torch.where(std > 0, std, torch.ones_like(std))
    feat = (feat - feat.mean(dim=0)) / std

    # Label-propagation graph
    if args.dataset in ('acm', 'dblp'):
        # Built from the pairs stored in data.pos, with edges pos_u -> pos_v.
        pos_v, pos_u = data.pos
        pg = dgl.graph((pos_u, pos_v), device=device)
    else:
        if not args.prop_graph:
            raise ValueError(
                "prop_graph must be set for dataset {!r}".format(args.dataset)
            )
        # The last graph stored in the file is used for propagation.
        pg = dgl.load_graphs(args.prop_graph)[0][-1].to(device)

    if args.dataset == 'oag-venue':
        # -1 labels are rewritten in place to class 0 so they are valid class
        # ids (presumably -1 marks unlabeled nodes - TODO confirm).
        labels[labels == -1] = 0

    base_model = nn.Linear(feat.shape[1], data.num_classes).to(device)
    train_base_model(base_model, feat, labels, train_idx, val_idx, test_idx, evaluator, args)
    correct_and_smooth(base_model, pg, feat, labels, train_idx, val_idx, test_idx, evaluator, args)
def main():
    """Parse command-line options, validate them and launch training.

    Exits with an argparse usage error (status 2) when the selected dataset
    needs ``--prop-graph`` but none was supplied; otherwise prints the parsed
    arguments and calls ``train``.
    """
    parser = argparse.ArgumentParser(description='训练C&S模型')
    parser.add_argument('--seed', type=int, default=0, help='随机数种子')
    parser.add_argument('--device', type=int, default=0, help='GPU设备')
    parser.add_argument('--dataset', choices=['acm', 'dblp', 'ogbn-mag', 'oag-venue'], default='ogbn-mag', help='数据集')
    # Base model
    parser.add_argument('--epochs', type=int, default=300, help='基础模型训练epoch数')
    parser.add_argument('--lr', type=float, default=0.01, help='基础模型学习率')
    # C&S
    parser.add_argument('--prop-graph', help='标签传播图所在路径')
    parser.add_argument('--num-correct-layers', type=int, default=50, help='Correct步骤传播层数')
    parser.add_argument('--correct-alpha', type=float, default=0.5, help='Correct步骤α')
    parser.add_argument(
        '--correct-norm', choices=['left', 'right', 'both'], default='both',
        help='Correct步骤归一化方式'
    )
    parser.add_argument('--num-smooth-layers', type=int, default=50, help='Smooth步骤传播层数')
    parser.add_argument('--smooth-alpha', type=float, default=0.5, help='Smooth步骤α')
    parser.add_argument(
        '--smooth-norm', choices=['left', 'right', 'both'], default='both',
        help='Smooth步骤归一化方式'
    )
    parser.add_argument('--scale', type=float, default=20, help='放缩系数')
    args = parser.parse_args()

    # Only 'acm' and 'dblp' build their propagation graph from the dataset
    # itself; every other dataset loads it from --prop-graph. Fail here with a
    # usage message rather than deep inside train().
    if args.dataset not in ('acm', 'dblp') and not args.prop_graph:
        parser.error(
            "--prop-graph is required when --dataset is not 'acm' or 'dblp'"
        )

    print(args)
    train(args)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()