GNNRecom/gnnrec/hge/rhco/smooth.py

76 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import dgl
import torch
import torch.nn.functional as F
from gnnrec.hge.cs.model import LabelPropagation
from gnnrec.hge.rhco.model import RHCO
from gnnrec.hge.utils import get_device, load_data, add_node_feat, calc_metrics
def smooth(base_pred, g, labels, mask, args):
    """Apply the Smooth step of Correct & Smooth to the base predictions.

    The rows of ``base_pred`` selected by ``mask`` are overwritten with the
    one-hot encoding of their labels, and the result is then passed through
    a ``LabelPropagation`` module over ``g``.

    NOTE: ``base_pred`` is modified in place (the masked rows are replaced).

    :param base_pred: 2-D tensor of per-node class scores; its second
        dimension is used as the number of classes
    :param g: graph handed to the propagation module
    :param labels: integer class labels (presumably one per node, aligned
        with the rows of ``base_pred`` -- confirm against the caller)
    :param mask: index used to select the nodes whose labels are known
    :param args: namespace providing ``num_smooth_layers``, ``smooth_alpha``
        and ``smooth_norm``
    :return: the output of ``LabelPropagation`` called on ``(g, base_pred)``
    """
    propagation = LabelPropagation(
        args.num_smooth_layers, args.smooth_alpha, args.smooth_norm
    )
    # Pin the one-hot width to the prediction width. Without ``num_classes``
    # the width is inferred as ``labels.max() + 1``, which is too narrow when
    # the highest class id does not occur in ``labels`` and makes the masked
    # assignment below fail with a shape mismatch.
    one_hot_labels = F.one_hot(labels, num_classes=base_pred.shape[1]).float()
    base_pred[mask] = one_hot_labels[mask]
    return propagation(g, base_pred)
def main():
    """Load a trained RHCO model and evaluate it after the Smooth step.

    Parses the command-line arguments, loads the dataset and the saved
    graphs, restores the model weights, computes the model output for the
    graph, smooths it over the positive-sample graph using the train and
    validation labels, and prints the test accuracy and test Macro-F1.
    """
    args = parse_args()
    print(args)
    device = get_device(args.device)

    loaded = load_data(args.dataset, device)
    data, g, _, labels, predict_ntype, train_idx, val_idx, test_idx, evaluator = loaded
    add_node_feat(g, 'pretrained', args.node_embed_path, True)
    if args.dataset == 'oag-venue':
        # Map label -1 to class 0 so every label is a valid class index
        # (presumably -1 marks unlabeled nodes -- confirm against load_data).
        labels[labels == -1] = 0

    # The last saved graph is the positive-sample graph; only the count of
    # the graphs stored before it is needed here.
    saved_graphs, _ = dgl.load_graphs(args.pos_graph_path)
    *mgs, pos_g = saved_graphs
    pos_g = pos_g.to(device)

    # Input feature dimension of each node type.
    in_dims = {ntype: g.nodes[ntype].data['feat'].shape[1] for ntype in g.ntypes}
    model = RHCO(
        in_dims,
        args.num_hidden,
        data.num_classes,
        args.num_rel_hidden,
        args.num_heads,
        g.ntypes,
        g.canonical_etypes,
        predict_ntype,
        args.num_layers,
        args.dropout,
        len(mgs),
        args.tau,
        args.lambda_,
    ).to(device)
    state_dict = torch.load(args.model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    base_pred = model.get_embeds(g, args.neighbor_size, args.batch_size, device)
    # Train and validation nodes supply the known labels for smoothing.
    labeled_idx = torch.cat([train_idx, val_idx])
    logits = smooth(base_pred, pos_g, labels, labeled_idx, args)

    metrics = calc_metrics(logits, labels, train_idx, val_idx, test_idx, evaluator)
    _, _, test_acc, _, _, test_f1 = metrics
    print('After smoothing: Test Acc {:.4f} | Test Macro-F1 {:.4f}'.format(test_acc, test_f1))
def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Defines the device/dataset options, the RHCO model hyper-parameters, the
    three positional paths (node embeddings, positive-sample graph, model
    weights) and the options of the Smooth step.

    :return: the ``argparse.Namespace`` produced by ``parse_args()``
    """
    ap = argparse.ArgumentParser(description='RHCO+C&S仅Smooth步骤')
    add = ap.add_argument

    add('--device', type=int, default=0, help='GPU设备')
    add('--dataset', choices=['ogbn-mag', 'oag-venue'], default='ogbn-mag', help='数据集')

    # RHCO
    add('--num-hidden', type=int, default=64, help='隐藏层维数')
    add('--num-rel-hidden', type=int, default=8, help='关系表示的隐藏层维数')
    add('--num-heads', type=int, default=8, help='注意力头数')
    add('--num-layers', type=int, default=2, help='层数')
    add('--dropout', type=float, default=0.5, help='Dropout概率')
    add('--tau', type=float, default=0.8, help='温度参数')
    # ``lambda`` is a Python keyword, so the value is stored as ``lambda_``.
    add('--lambda', type=float, default=0.5, dest='lambda_', help='对比损失的平衡系数')
    add('--batch-size', type=int, default=1024, help='批大小')
    add('--neighbor-size', type=int, default=10, help='邻居采样数')
    add('node_embed_path', help='预训练顶点嵌入路径')
    add('pos_graph_path', help='正样本图保存路径')
    add('model_path', help='预训练的模型保存路径')

    # C&S
    add('--num-smooth-layers', type=int, default=50, help='Smooth步骤传播层数')
    add('--smooth-alpha', type=float, default=0.5, help='Smooth步骤α')
    add(
        '--smooth-norm',
        choices=['left', 'right', 'both'],
        default='right',
        help='Smooth步骤归一化方式',
    )

    return ap.parse_args()
# Run the smoothing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()