GNNRecom/gnnrec/hge/hgt/model.py


import math
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroGraphConv
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
class HGTAttention(nn.Module):

    def __init__(self, out_dim, num_heads, k_linear, q_linear, v_linear, w_att, w_msg, mu):
        """HGT attention module

        :param out_dim: int Output feature dimension
        :param num_heads: int Number of attention heads K
        :param k_linear: nn.Linear(d_in, d_out)
        :param q_linear: nn.Linear(d_in, d_out)
        :param v_linear: nn.Linear(d_in, d_out)
        :param w_att: tensor(K, d_out/K, d_out/K)
        :param w_msg: tensor(K, d_out/K, d_out/K)
        :param mu: tensor(K)
        """
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.d_k = out_dim // num_heads
        self.k_linear = k_linear
        self.q_linear = q_linear
        self.v_linear = v_linear
        self.w_att = w_att
        self.w_msg = w_msg
        self.mu = mu

    def forward(self, g, feat):
        """
        :param g: DGLGraph Bipartite graph (containing a single relation)
        :param feat: tensor(N_src, d_in) or (tensor(N_src, d_in), tensor(N_dst, d_in)) Input features
        :return: tensor(N_dst, d_out) Representations of the destination vertices for this relation
        """
        with g.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, g)
            # (N_src, d_in) -> (N_src, d_out) -> (N_src, K, d_out/K)
            k = self.k_linear(feat_src).view(-1, self.num_heads, self.d_k)
            v = self.v_linear(feat_src).view(-1, self.num_heads, self.d_k)
            q = self.q_linear(feat_dst).view(-1, self.num_heads, self.d_k)
            # k[:, h] @= w_att[h] => k[n, h, j] = ∑(i) k[n, h, i] * w_att[h, i, j]
            k = torch.einsum('nhi,hij->nhj', k, self.w_att)
            v = torch.einsum('nhi,hij->nhj', v, self.w_msg)
            g.srcdata.update({'k': k, 'v': v})
            g.dstdata['q'] = q
            g.apply_edges(fn.v_dot_u('q', 'k', 't'))  # g.edata['t']: (E, K, 1)
            attn = g.edata.pop('t').squeeze(dim=-1) * self.mu / math.sqrt(self.d_k)
            attn = edge_softmax(g, attn)  # (E, K)
            self.attn = attn.detach()
            g.edata['t'] = attn.unsqueeze(dim=-1)  # (E, K, 1)
            g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
            out = g.dstdata['h'].view(-1, self.out_dim)  # (N_dst, d_out)
            return out
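
# A minimal usage sketch (assumed example, not from the original repo): running
# HGTAttention on a toy single-relation bipartite graph. The node/edge type names,
# dimensions and identity-initialized relation matrices below are illustrative only;
# in practice HGTLayer constructs and shares these projections and parameters.
#
#   import dgl
#   d_in, d_out, K = 8, 16, 4
#   g = dgl.heterograph({('author', 'writes', 'paper'): ([0, 1, 1], [0, 0, 1])})
#   attn = HGTAttention(
#       d_out, K,
#       nn.Linear(d_in, d_out), nn.Linear(d_in, d_out), nn.Linear(d_in, d_out),
#       nn.Parameter(torch.eye(d_out // K).repeat(K, 1, 1)),
#       nn.Parameter(torch.eye(d_out // K).repeat(K, 1, 1)),
#       nn.Parameter(torch.ones(K))
#   )
#   feat = (torch.randn(2, d_in), torch.randn(2, d_in))  # (source, destination) features
#   out = attn(g, feat)  # tensor(2, d_out): one vector per destination 'paper' vertex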

class HGTLayer(nn.Module):

    def __init__(self, in_dim, out_dim, num_heads, ntypes, etypes, dropout=0.2, use_norm=True):
        """HGT layer

        :param in_dim: int Input feature dimension
        :param out_dim: int Output feature dimension
        :param num_heads: int Number of attention heads K
        :param ntypes: List[str] List of vertex types
        :param etypes: List[(str, str, str)] List of canonical edge types
        :param dropout: float, optional Dropout probability, default 0.2
        :param use_norm: bool, optional Whether to use layer normalization, default True
        """
        super().__init__()
        d_k = out_dim // num_heads
        # Type-specific K/Q/V projections, shared by all relations that involve the type
        k_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        q_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        v_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        # Relation-specific attention/message matrices and prior μ
        w_att = {r[1]: nn.Parameter(torch.Tensor(num_heads, d_k, d_k)) for r in etypes}
        w_msg = {r[1]: nn.Parameter(torch.Tensor(num_heads, d_k, d_k)) for r in etypes}
        mu = {r[1]: nn.Parameter(torch.ones(num_heads)) for r in etypes}
        self.reset_parameters(w_att, w_msg)
        self.conv = HeteroGraphConv({
            etype: HGTAttention(
                out_dim, num_heads, k_linear[stype], q_linear[dtype], v_linear[stype],
                w_att[etype], w_msg[etype], mu[etype]
            ) for stype, etype, dtype in etypes
        }, 'mean')
        self.a_linear = nn.ModuleDict({ntype: nn.Linear(out_dim, out_dim) for ntype in ntypes})
        self.skip = nn.ParameterDict({ntype: nn.Parameter(torch.ones(1)) for ntype in ntypes})
        self.drop = nn.Dropout(dropout)
        self.use_norm = use_norm
        if use_norm:
            self.norms = nn.ModuleDict({ntype: nn.LayerNorm(out_dim) for ntype in ntypes})

    def reset_parameters(self, w_att, w_msg):
        for etype in w_att:
            nn.init.xavier_uniform_(w_att[etype])
            nn.init.xavier_uniform_(w_msg[etype])

    def forward(self, g, feats):
        """
        :param g: DGLGraph Heterogeneous graph
        :param feats: Dict[str, tensor(N_i, d_in)] Mapping from vertex type to input vertex features
        :return: Dict[str, tensor(N_i, d_out)] Mapping from vertex type to output features
        """
        if g.is_block:
            feats_dst = {ntype: feats[ntype][:g.num_dst_nodes(ntype)] for ntype in feats}
        else:
            feats_dst = feats
        with g.local_scope():
            # Step 1: heterogeneous mutual attention + heterogeneous message passing + target-specific aggregation
            hs = self.conv(g, feats)  # {ntype: tensor(N_i, d_out)}
            # Step 2: residual connection
            out_feats = {}
            for ntype in g.dsttypes:
                if g.num_dst_nodes(ntype) == 0:
                    continue
                alpha = torch.sigmoid(self.skip[ntype])
                trans_out = self.drop(self.a_linear[ntype](hs[ntype]))
                out = alpha * trans_out + (1 - alpha) * feats_dst[ntype]
                out_feats[ntype] = self.norms[ntype](out) if self.use_norm else out
            return out_feats
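
# A minimal usage sketch (assumed example, not from the original repo): applying one
# HGTLayer to a toy heterogeneous graph. Because of the residual connection,
# in_dim must equal out_dim; the type names and sizes here are illustrative only.
#
#   import dgl
#   g = dgl.heterograph({
#       ('author', 'writes', 'paper'): ([0, 1], [0, 0]),
#       ('paper', 'written_by', 'author'): ([0, 0], [0, 1]),
#   })
#   layer = HGTLayer(16, 16, 4, list(g.ntypes), list(g.canonical_etypes))
#   feats = {'author': torch.randn(2, 16), 'paper': torch.randn(1, 16)}
#   out = layer(g, feats)  # {'author': tensor(2, 16), 'paper': tensor(1, 16)}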

class HGT(nn.Module):

    def __init__(
            self, in_dims, hidden_dim, out_dim, num_heads, ntypes, etypes,
            predict_ntype, num_layers, dropout=0.2, use_norm=True):
        """HGT model

        :param in_dims: Dict[str, int] Mapping from vertex type to input feature dimension
        :param hidden_dim: int Hidden feature dimension
        :param out_dim: int Output feature dimension
        :param num_heads: int Number of attention heads K
        :param ntypes: List[str] List of vertex types
        :param etypes: List[(str, str, str)] List of canonical edge types
        :param predict_ntype: str Vertex type to be predicted
        :param num_layers: int Number of layers
        :param dropout: float, optional Dropout probability, default 0.2
        :param use_norm: bool, optional Whether to use layer normalization, default True
        """
        super().__init__()
        self.predict_ntype = predict_ntype
        self.adapt_fcs = nn.ModuleDict({
            ntype: nn.Linear(in_dim, hidden_dim) for ntype, in_dim in in_dims.items()
        })
        self.layers = nn.ModuleList([
            HGTLayer(hidden_dim, hidden_dim, num_heads, ntypes, etypes, dropout, use_norm)
            for _ in range(num_layers)
        ])
        self.predict = nn.Linear(hidden_dim, out_dim)

    def forward(self, blocks, feats):
        """
        :param blocks: List[DGLBlock]
        :param feats: Dict[str, tensor(N_i, d_in)] Mapping from vertex type to input vertex features
        :return: tensor(N_i, d_out) Final embeddings of the vertices to be predicted
        """
        hs = {ntype: F.gelu(self.adapt_fcs[ntype](feats[ntype])) for ntype in feats}
        for i in range(len(self.layers)):
            hs = self.layers[i](blocks[i], hs)  # {ntype: tensor(N_i, d_hid)}
        out = self.predict(hs[self.predict_ntype])  # tensor(N_i, d_out)
        return out
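
# A minibatch usage sketch (assumed example, not from the original repo's training code):
# feeding MFG blocks from a neighbor sampler into HGT.forward. It assumes the older
# dgl.dataloading.NodeDataLoader API (renamed dgl.dataloading.DataLoader in newer DGL
# versions); `g`, `feats` and `model` are assumed to exist as in the sketches above.
#
#   sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)  # 2 = num_layers
#   loader = dgl.dataloading.NodeDataLoader(
#       g, {'paper': torch.arange(g.num_nodes('paper'))}, sampler, batch_size=64)
#   for input_nodes, output_nodes, blocks in loader:
#       batch_feats = {ntype: feats[ntype][nid] for ntype, nid in input_nodes.items()}
#       logits = model(blocks, batch_feats)  # tensor(num_output_papers, d_out)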

class HGTFull(HGT):

    def forward(self, g, feats):
        """Full-graph forward pass: reuse the minibatch implementation with the same graph at every layer."""
        return super().forward([g] * len(self.layers), feats)
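
# Minimal full-graph smoke test (assumed example, not part of the original repo's
# training pipeline): the toy graph, type names and dimensions below are made up
# to show how HGTFull is constructed and called.
if __name__ == '__main__':
    import dgl

    g = dgl.heterograph({
        ('author', 'writes', 'paper'): ([0, 1, 2], [0, 0, 1]),
        ('paper', 'written_by', 'author'): ([0, 0, 1], [0, 1, 2]),
    })
    in_dims = {'author': 8, 'paper': 12}
    model = HGTFull(
        in_dims, hidden_dim=16, out_dim=3, num_heads=4,
        ntypes=list(g.ntypes), etypes=list(g.canonical_etypes),
        predict_ntype='paper', num_layers=2
    )
    feats = {ntype: torch.randn(g.num_nodes(ntype), d) for ntype, d in in_dims.items()}
    logits = model(g, feats)
    print(logits.shape)  # torch.Size([2, 3]): one row per 'paper' vertex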