import dgl.function as fn
import torch
import torch.nn as nn
class LabelPropagation(nn.Module):
    r"""Label propagation model.

    .. math::
        Y^{(t+1)} = \alpha SY^{(t)} + (1-\alpha)Y, \quad Y^{(0)} = Y
    """

    def __init__(self, num_layers, alpha, norm):
        """
        :param num_layers: int number of propagation steps
        :param alpha: float the α coefficient in the update rule
        :param norm: str adjacency-matrix normalization scheme:
            'left': S=D^{-1}A, 'right': S=AD^{-1}, 'both': S=D^{-1/2}AD^{-1/2}
        """
        super().__init__()
        self.num_layers = num_layers
        self.alpha = alpha
        self.norm = norm

    @torch.no_grad()
    def forward(self, g, labels, mask=None, post_step=None):
        """Propagate labels over the graph.

        :param g: DGLGraph undirected graph
        :param labels: tensor(N, C) one-hot labels
        :param mask: tensor(N), optional mask of labeled vertices
        :param post_step: callable, optional f: tensor(N, C) -> tensor(N, C),
            applied after every propagation step
        :return: tensor(N, C) predicted label probabilities
        """
        with g.local_scope():
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]
            else:
                # Clone so the updates below can never mutate the caller's
                # ``labels`` tensor (the original aliased it and multiplied
                # in place, corrupting the input on the first iteration).
                y = labels.clone()

            residual = (1 - self.alpha) * y
            degs = g.in_degrees().float().clamp(min=1)  # clamp avoids div-by-zero on isolated nodes
            norm = torch.pow(degs, -0.5 if self.norm == 'both' else -1).unsqueeze(1)  # (N, 1)
            for _ in range(self.num_layers):
                if self.norm in ('both', 'right'):
                    # Out-of-place multiply: ``y`` may be (or alias) an integer
                    # one-hot tensor, where in-place float multiply would fail.
                    y = y * norm
                g.ndata['h'] = y
                g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                y = self.alpha * g.ndata.pop('h')
                if self.norm in ('both', 'left'):
                    y = y * norm
                y += residual
                if post_step is not None:
                    y = post_step(y)
            return y
||
class CorrectAndSmooth(nn.Module):
    """Correct and Smooth (C&S) post-processing model.

    Runs a "Correct" label-propagation pass over the base predictor's
    residual errors, then a "Smooth" pass over the corrected predictions.
    """

    def __init__(
            self, num_correct_layers, correct_alpha, correct_norm,
            num_smooth_layers, smooth_alpha, smooth_norm, scale=1.0):
        """
        :param num_correct_layers: int propagation steps for the Correct pass
        :param correct_alpha: float α for the Correct pass
        :param correct_norm: str adjacency normalization for the Correct pass
        :param num_smooth_layers: int propagation steps for the Smooth pass
        :param smooth_alpha: float α for the Smooth pass
        :param smooth_norm: str adjacency normalization for the Smooth pass
        :param scale: float scaling factor applied to the smoothed error
        """
        super().__init__()
        self.correct_prop = LabelPropagation(num_correct_layers, correct_alpha, correct_norm)
        self.smooth_prop = LabelPropagation(num_smooth_layers, smooth_alpha, smooth_norm)
        self.scale = scale

    def correct(self, g, labels, base_pred, mask):
        """Correct step: propagate and subtract the base predictor's errors.

        :param g: DGLGraph undirected graph
        :param labels: tensor(N, C) one-hot labels
        :param base_pred: tensor(N, C) base predictions
        :param mask: tensor(N) training-set mask
        :return: tensor(N, C) corrected predictions
        """
        err = torch.zeros_like(base_pred)  # (N, C)
        err[mask] = labels[mask] - base_pred[mask]

        # FDiff-scale: pin the error on training vertices after every step
        def fix_input(y):
            y[mask] = err[mask]
            return y

        smoothed_err = self.correct_prop(g, err, post_step=fix_input)  # \hat{E}
        corrected_pred = base_pred + self.scale * smoothed_err  # Z^{(r)}
        # Fall back to the base prediction wherever the correction produced NaN.
        corrected_pred[corrected_pred.isnan()] = base_pred[corrected_pred.isnan()]
        return corrected_pred

    def smooth(self, g, labels, corrected_pred, mask):
        """Smooth step: propagate the corrected predictions.

        :param g: DGLGraph undirected graph
        :param labels: tensor(N, C) one-hot labels
        :param corrected_pred: tensor(N, C) corrected predictions
        :param mask: tensor(N) training-set mask
        :return: tensor(N, C) final predictions
        """
        # Clone: the original assigned into ``corrected_pred`` directly,
        # mutating the caller's tensor in place.
        guess = corrected_pred.clone()
        guess[mask] = labels[mask]
        return self.smooth_prop(g, guess)

    def forward(self, g, labels, base_pred, mask):
        """Run the full Correct-then-Smooth pipeline.

        :param g: DGLGraph undirected graph
        :param labels: tensor(N, C) one-hot labels
        :param base_pred: tensor(N, C) base predictions
        :param mask: tensor(N) training-set mask
        :return: tensor(N, C) final predictions
        """
        # The original had this call commented out and passed ``base_pred``
        # straight through — a debug leftover that silently skipped the
        # Correct step. Restored to match the class's contract.
        corrected_pred = self.correct(g, labels, base_pred, mask)
        return self.smooth(g, labels, corrected_pred, mask)