GNNRecom/gnnrec/hge/cs/model.py

114 lines
4.0 KiB
Python
Raw Normal View History

2021-11-16 07:04:52 +00:00
import dgl.function as fn
import torch
import torch.nn as nn
class LabelPropagation(nn.Module):
    """Label propagation over a graph.

    .. math::
        Y^{(t+1)} = \\alpha S Y^{(t)} + (1 - \\alpha) Y, \\quad Y^{(0)} = Y
    """

    def __init__(self, num_layers, alpha, norm):
        """Store the propagation hyper-parameters.

        :param num_layers: int, number of propagation steps
        :param alpha: float, weight of the propagated term (``1 - alpha`` is
            the weight of the initial labels)
        :param norm: str, how the adjacency matrix is normalized:
            'left': S=D^{-1}A, 'right': S=AD^{-1}, 'both': S=D^{-1/2}AD^{-1/2}
        """
        super().__init__()
        self.num_layers = num_layers
        self.alpha = alpha
        self.norm = norm

    @torch.no_grad()
    def forward(self, g, labels, mask=None, post_step=None):
        """Propagate ``labels`` over ``g`` for ``num_layers`` steps.

        The caller's ``labels`` tensor is never modified.

        :param g: DGLGraph, expected to be undirected (in-degrees are used
            for both the left and the right normalization)
        :param labels: tensor(N, C), one-hot labels
        :param mask: tensor(N), optional, mask of the labelled nodes; rows
            outside the mask start from zero
        :param post_step: callable, optional, f: tensor(N, C) -> tensor(N, C),
            applied to the result of every propagation step
        :return: tensor(N, C), predicted label probabilities
        """
        with g.local_scope():
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]
            else:
                # y aliases the caller's tensor here, so it must not be
                # updated in place before it has been replaced below.
                y = labels
            residual = (1 - self.alpha) * y

            # Clamp so that isolated nodes do not produce a division by zero.
            degs = g.in_degrees().float().clamp(min=1)
            exponent = -0.5 if self.norm == 'both' else -1
            norm = torch.pow(degs, exponent).unsqueeze(1)  # (N, 1)

            scale_before = self.norm in ('both', 'right')
            scale_after = self.norm in ('both', 'left')
            for _ in range(self.num_layers):
                if scale_before:
                    # Out-of-place on purpose: in the first iteration y may
                    # still be the caller's labels tensor.
                    y = y * norm
                g.ndata['h'] = y
                g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                # From here on y is a freshly allocated tensor, so in-place
                # updates are safe.
                y = self.alpha * g.ndata.pop('h')
                if scale_after:
                    y *= norm
                y += residual
                if post_step is not None:
                    y = post_step(y)
            return y
class CorrectAndSmooth(nn.Module):
    """Correct and Smooth (C&S) post-processing of base predictions."""

    def __init__(
            self, num_correct_layers, correct_alpha, correct_norm,
            num_smooth_layers, smooth_alpha, smooth_norm, scale=1.0):
        """Build the two label-propagation modules used by C&S.

        :param num_correct_layers: int, propagation steps of the correct step
        :param correct_alpha: float, alpha of the correct step
        :param correct_norm: str, adjacency normalization of the correct step
        :param num_smooth_layers: int, propagation steps of the smooth step
        :param smooth_alpha: float, alpha of the smooth step
        :param smooth_norm: str, adjacency normalization of the smooth step
        :param scale: float, factor applied to the smoothed error before it
            is added to the base prediction
        """
        super().__init__()
        self.correct_prop = LabelPropagation(num_correct_layers, correct_alpha, correct_norm)
        self.smooth_prop = LabelPropagation(num_smooth_layers, smooth_alpha, smooth_norm)
        self.scale = scale

    def correct(self, g, labels, base_pred, mask):
        """Correct step: fix the error of the base predictions.

        ``base_pred`` is not modified.

        :param g: DGLGraph, undirected graph
        :param labels: tensor(N, C), one-hot labels
        :param base_pred: tensor(N, C), base predictions
        :param mask: tensor(N), training-set mask
        :return: tensor(N, C), corrected predictions
        """
        err = torch.zeros_like(base_pred)  # (N, C)
        err[mask] = labels[mask] - base_pred[mask]
        # Snapshot the known training errors so that the reset below does not
        # depend on whether the propagation touches ``err``.
        train_err = err[mask].clone()

        # FDiff-scale: keep the error on the training set fixed after each step.
        def fix_input(y):
            y[mask] = train_err
            return y

        smoothed_err = self.correct_prop(g, err, post_step=fix_input)  # \hat{E}
        corrected_pred = base_pred + self.scale * smoothed_err  # Z^{(r)}
        # Fall back to the base prediction wherever the correction is NaN.
        nan_mask = corrected_pred.isnan()
        corrected_pred[nan_mask] = base_pred[nan_mask]
        return corrected_pred

    def smooth(self, g, labels, corrected_pred, mask):
        """Smooth step: smooth the final predictions.

        ``corrected_pred`` is not modified.

        :param g: DGLGraph, undirected graph
        :param labels: tensor(N, C), one-hot labels
        :param corrected_pred: tensor(N, C), corrected predictions
        :param mask: tensor(N), training-set mask
        :return: tensor(N, C), final predictions
        """
        # Work on a copy: the caller's predictions must stay intact.
        guess = corrected_pred.clone()
        guess[mask] = labels[mask]
        return self.smooth_prop(g, guess)

    def forward(self, g, labels, base_pred, mask):
        """Run C&S on the base predictions.

        :param g: DGLGraph, undirected graph
        :param labels: tensor(N, C), one-hot labels
        :param base_pred: tensor(N, C), base predictions
        :param mask: tensor(N), training-set mask
        :return: tensor(N, C), final predictions
        """
        # NOTE(review): the correct step is currently skipped and only the
        # smooth step runs; to re-enable it use
        # ``self.correct(g, labels, base_pred, mask)`` here. Confirm with the
        # author whether skipping it is intentional.
        corrected_pred = base_pred
        return self.smooth(g, labels, corrected_pred, mask)