GNNRecom/gnnrec/hge/cs/model.py

import dgl.function as fn
import torch
import torch.nn as nn


class LabelPropagation(nn.Module):
    def __init__(self, num_layers, alpha, norm):
        """Label propagation model

        .. math::
            Y^{(t+1)} = \\alpha S Y^{(t)} + (1 - \\alpha) Y, Y^{(0)} = Y

        :param num_layers: int, number of propagation layers
        :param alpha: float, the α parameter
        :param norm: str, adjacency matrix normalization scheme,
            'left': S=D^{-1}A, 'right': S=AD^{-1}, 'both': S=D^{-1/2}AD^{-1/2}
        """
        super().__init__()
        self.num_layers = num_layers
        self.alpha = alpha
        self.norm = norm
    @torch.no_grad()
    def forward(self, g, labels, mask=None, post_step=None):
        """
        :param g: DGLGraph, undirected graph
        :param labels: tensor(N, C), one-hot labels
        :param mask: tensor(N), optional, mask of labeled vertices
        :param post_step: callable, optional, f: tensor(N, C) -> tensor(N, C)
        :return: tensor(N, C), predicted label probabilities
        """
        with g.local_scope():
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]
            else:
                y = labels

            residual = (1 - self.alpha) * y
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5 if self.norm == 'both' else -1).unsqueeze(1)  # (N, 1)
            for _ in range(self.num_layers):
                # Avoid in-place multiplication so the caller's labels tensor is never modified
                if self.norm in ('both', 'right'):
                    y = y * norm
                g.ndata['h'] = y
                g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                y = self.alpha * g.ndata.pop('h')
                if self.norm in ('both', 'left'):
                    y = y * norm
                y += residual
                if post_step is not None:
                    y = post_step(y)
            return y
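
# Illustrative usage sketch (an assumption, not part of the original module): propagate
# one-hot labels from a few labeled vertices over a small undirected DGL graph.
#
#     import dgl
#     g = dgl.add_reverse_edges(dgl.graph(([0, 1, 2], [1, 2, 3])))  # 4-vertex undirected path
#     labels = torch.eye(2)[torch.tensor([0, 1, 0, 1])]             # (N, C) one-hot labels
#     mask = torch.tensor([True, True, False, False])               # labeled vertices
#     lp = LabelPropagation(num_layers=10, alpha=0.9, norm='both')
#     pred = lp(g, labels, mask)                                    # (N, C) label probabilities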


class CorrectAndSmooth(nn.Module):
    def __init__(
            self, num_correct_layers, correct_alpha, correct_norm,
            num_smooth_layers, smooth_alpha, smooth_norm, scale=1.0):
        """Correct and Smooth (C&S) model"""
        super().__init__()
        self.correct_prop = LabelPropagation(num_correct_layers, correct_alpha, correct_norm)
        self.smooth_prop = LabelPropagation(num_smooth_layers, smooth_alpha, smooth_norm)
        self.scale = scale
    def correct(self, g, labels, base_pred, mask):
        """Correct step: correct errors in the base prediction

        :param g: DGLGraph, undirected graph
        :param labels: tensor(N, C), one-hot labels
        :param base_pred: tensor(N, C), base prediction
        :param mask: tensor(N), training set mask
        :return: tensor(N, C), corrected prediction
        """
        err = torch.zeros_like(base_pred)  # (N, C)
        err[mask] = labels[mask] - base_pred[mask]

        # FDiff-scale: keep the training-set error fixed at every propagation step
        def fix_input(y):
            y[mask] = err[mask]
            return y

        smoothed_err = self.correct_prop(g, err, post_step=fix_input)  # \hat{E}
        corrected_pred = base_pred + self.scale * smoothed_err  # Z^{(r)}
        corrected_pred[corrected_pred.isnan()] = base_pred[corrected_pred.isnan()]
        return corrected_pred
    def smooth(self, g, labels, corrected_pred, mask):
        """Smooth step: smooth the final prediction

        :param g: DGLGraph, undirected graph
        :param labels: tensor(N, C), one-hot labels
        :param corrected_pred: tensor(N, C), corrected prediction
        :param mask: tensor(N), training set mask
        :return: tensor(N, C), final prediction
        """
        # Clone so the caller's corrected_pred is not overwritten in place
        guess = corrected_pred.clone()
        guess[mask] = labels[mask]
        return self.smooth_prop(g, guess)
    def forward(self, g, labels, base_pred, mask):
        """
        :param g: DGLGraph, undirected graph
        :param labels: tensor(N, C), one-hot labels
        :param base_pred: tensor(N, C), base prediction
        :param mask: tensor(N), training set mask
        :return: tensor(N, C), final prediction
        """
        # The Correct step is disabled here; only the Smooth step is applied to the base prediction
        # corrected_pred = self.correct(g, labels, base_pred, mask)
        corrected_pred = base_pred
        return self.smooth(g, labels, corrected_pred, mask)
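

if __name__ == '__main__':
    # Minimal smoke test, an illustrative sketch rather than part of the original module:
    # the toy graph, random base predictions and hyperparameters below are assumptions
    # chosen only to exercise CorrectAndSmooth end to end.
    import dgl

    src = torch.tensor([0, 1, 2, 3, 4])
    dst = torch.tensor([1, 2, 3, 4, 0])
    g = dgl.add_reverse_edges(dgl.graph((src, dst)))  # 5-vertex undirected cycle
    num_classes = 3
    labels = torch.eye(num_classes)[torch.tensor([0, 1, 2, 0, 1])]  # (N, C) one-hot labels
    base_pred = torch.softmax(torch.randn(5, num_classes), dim=1)   # (N, C) base prediction
    train_mask = torch.tensor([True, True, True, False, False])
    cs = CorrectAndSmooth(
        num_correct_layers=10, correct_alpha=0.8, correct_norm='both',
        num_smooth_layers=10, smooth_alpha=0.8, smooth_norm='both', scale=10.0
    )
    final_pred = cs(g, labels, base_pred, train_mask)  # (N, C) smoothed prediction
    print(final_pred)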