GNNRecom/gnnrec/hge/hgt/model.py

181 lines
7.4 KiB
Python
Raw Permalink Normal View History

2021-11-16 07:04:52 +00:00
import math
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroGraphConv
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
class HGTAttention(nn.Module):

    def __init__(self, out_dim, num_heads, k_linear, q_linear, v_linear, w_att, w_msg, mu):
        """HGT attention module for a single relation.

        :param out_dim: int output feature dimension d_out; must be divisible by num_heads
        :param num_heads: int number of attention heads K
        :param k_linear: nn.Linear(d_in, d_out) key projection, applied to source node features
        :param q_linear: nn.Linear(d_in, d_out) query projection, applied to destination node features
        :param v_linear: nn.Linear(d_in, d_out) value projection, applied to source node features
        :param w_att: tensor(K, d_out/K, d_out/K) per-head attention matrix applied to the keys
        :param w_msg: tensor(K, d_out/K, d_out/K) per-head message matrix applied to the values
        :param mu: tensor(K) (or tensor(1)) attention scale, broadcast over the K heads
        :raises ValueError: if out_dim is not divisible by num_heads
        """
        super().__init__()
        # Features are reshaped to (N, K, out_dim // K) and back to (N, out_dim) in forward;
        # that round trip is only valid when the heads split out_dim evenly.
        if out_dim % num_heads != 0:
            raise ValueError(
                f'out_dim ({out_dim}) must be divisible by num_heads ({num_heads})'
            )
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.d_k = out_dim // num_heads
        self.k_linear = k_linear
        self.q_linear = q_linear
        self.v_linear = v_linear
        self.w_att = w_att
        self.w_msg = w_msg
        self.mu = mu

    def forward(self, g, feat):
        """Compute destination-node representations for one relation.

        :param g: DGLGraph bipartite graph containing a single relation
        :param feat: tensor(N_src, d_in) or (tensor(N_src, d_in), tensor(N_dst, d_in)) input features
        :return: tensor(N_dst, d_out) representation of the destination nodes w.r.t. this relation
        """
        with g.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, g)
            # (N_src, d_in) -> (N_src, d_out) -> (N_src, K, d_out/K)
            k = self.k_linear(feat_src).view(-1, self.num_heads, self.d_k)
            v = self.v_linear(feat_src).view(-1, self.num_heads, self.d_k)
            q = self.q_linear(feat_dst).view(-1, self.num_heads, self.d_k)
            # k[:, h] @= w_att[h] => k[n, h, j] = sum_i k[n, h, i] * w_att[h, i, j]
            k = torch.einsum('nhi,hij->nhj', k, self.w_att)
            v = torch.einsum('nhi,hij->nhj', v, self.w_msg)
            g.srcdata.update({'k': k, 'v': v})
            g.dstdata['q'] = q
            # Per-edge, per-head dot product q_dst . k_src -> g.edata['t']: (E, K, 1)
            g.apply_edges(fn.v_dot_u('q', 'k', 't'))
            # Scale by mu and 1/sqrt(d_k), then normalize over each destination's incoming edges
            attn = g.edata.pop('t').squeeze(dim=-1) * self.mu / math.sqrt(self.d_k)
            attn = edge_softmax(g, attn)  # (E, K)
            # Detached copy of the normalized weights kept on the module
            # (presumably for inspection/visualization — not used elsewhere in this class)
            self.attn = attn.detach()
            g.edata['t'] = attn.unsqueeze(dim=-1)  # (E, K, 1)
            # Weighted sum of value messages into each destination node: (N_dst, K, d_out/K)
            g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
            out = g.dstdata['h'].view(-1, self.out_dim)  # (N_dst, d_out)
            return out
class HGTLayer(nn.Module):

    def __init__(self, in_dim, out_dim, num_heads, ntypes, etypes, dropout=0.2, use_norm=True):
        """HGT layer.

        :param in_dim: int input feature dimension
        :param out_dim: int output feature dimension
        :param num_heads: int number of attention heads K
        :param ntypes: List[str] list of node types
        :param etypes: List[(str, str, str)] list of canonical edge types
        :param dropout: float, optional dropout probability, default 0.2
        :param use_norm: bool, optional whether to apply layer normalization, default True
        """
        super().__init__()
        d_k = out_dim // num_heads
        # K/Q/V projections are per node type and shared by every relation touching that type
        k_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        q_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        v_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        # Attention/message matrices and scale are per relation, keyed by edge type name
        w_att = {r[1]: nn.Parameter(torch.Tensor(num_heads, d_k, d_k)) for r in etypes}
        w_msg = {r[1]: nn.Parameter(torch.Tensor(num_heads, d_k, d_k)) for r in etypes}
        mu = {r[1]: nn.Parameter(torch.ones(num_heads)) for r in etypes}
        self.reset_parameters(w_att, w_msg)
        self.conv = HeteroGraphConv({
            etype: HGTAttention(
                out_dim, num_heads, k_linear[stype], q_linear[dtype], v_linear[stype],
                w_att[etype], w_msg[etype], mu[etype]
            ) for stype, etype, dtype in etypes
        }, 'mean')
        self.a_linear = nn.ModuleDict({ntype: nn.Linear(out_dim, out_dim) for ntype in ntypes})
        # Learnable residual gate per node type; sigmoid(skip) weights the new representation
        self.skip = nn.ParameterDict({ntype: nn.Parameter(torch.ones(1)) for ntype in ntypes})
        self.drop = nn.Dropout(dropout)
        self.use_norm = use_norm
        if use_norm:
            self.norms = nn.ModuleDict({ntype: nn.LayerNorm(out_dim) for ntype in ntypes})

    def reset_parameters(self, w_att, w_msg):
        """Xavier-uniform initialize the per-relation attention and message matrices in place.

        :param w_att: Dict[str, nn.Parameter] edge type -> attention matrix
        :param w_msg: Dict[str, nn.Parameter] edge type -> message matrix (same keys as w_att)
        """
        for etype in w_att:
            nn.init.xavier_uniform_(w_att[etype])
            nn.init.xavier_uniform_(w_msg[etype])

    def forward(self, g, feats):
        """
        :param g: DGLGraph heterogeneous graph (or block)
        :param feats: Dict[str, tensor(N_i, d_in)] mapping node type -> input node features
        :return: Dict[str, tensor(N_i, d_out)] mapping node type -> output features
        """
        if g.is_block:
            # Destination nodes of a block are a prefix of its source nodes
            feats_dst = {ntype: feats[ntype][:g.num_dst_nodes(ntype)] for ntype in feats}
        else:
            feats_dst = feats
        with g.local_scope():
            # Step 1: heterogeneous mutual attention + heterogeneous message passing +
            # target-specific aggregation ('mean' across relations)
            hs = self.conv(g, feats)  # {ntype: tensor(N_i, d_out)}
            # Step 2: gated residual connection
            out_feats = {}
            for ntype in g.dsttypes:
                num_dst = g.num_dst_nodes(ntype)
                if num_dst == 0:
                    continue
                residual = feats_dst[ntype]
                h = hs.get(ntype)
                if h is None:
                    # HeteroGraphConv omits destination types that received no messages
                    # (all incoming relations empty or missing source features). Treat the
                    # aggregate as zero — matching DGL's zero-in-degree behaviour — so the
                    # type keeps its features for later layers instead of raising KeyError.
                    h = residual.new_zeros(num_dst, self.a_linear[ntype].in_features)
                alpha = torch.sigmoid(self.skip[ntype])
                trans_out = self.drop(self.a_linear[ntype](h))
                out = alpha * trans_out + (1 - alpha) * residual
                out_feats[ntype] = self.norms[ntype](out) if self.use_norm else out
            return out_feats
class HGT(nn.Module):

    def __init__(
            self, in_dims, hidden_dim, out_dim, num_heads, ntypes, etypes,
            predict_ntype, num_layers, dropout=0.2, use_norm=True):
        """HGT model: per-type input adapters, a stack of HGT layers, and a linear predictor.

        :param in_dims: Dict[str, int] mapping node type -> input feature dimension
        :param hidden_dim: int hidden feature dimension
        :param out_dim: int output feature dimension
        :param num_heads: int number of attention heads K
        :param ntypes: List[str] list of node types
        :param etypes: List[(str, str, str)] list of canonical edge types
        :param predict_ntype: str node type whose representations are predicted
        :param num_layers: int number of HGT layers
        :param dropout: float, optional dropout probability, default 0.2
        :param use_norm: bool, optional whether to apply layer normalization, default True
        """
        super().__init__()
        self.predict_ntype = predict_ntype
        # Project each node type's raw features into the shared hidden space
        adapters = {ntype: nn.Linear(dim, hidden_dim) for ntype, dim in in_dims.items()}
        self.adapt_fcs = nn.ModuleDict(adapters)
        stack = [
            HGTLayer(hidden_dim, hidden_dim, num_heads, ntypes, etypes, dropout, use_norm)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)
        self.predict = nn.Linear(hidden_dim, out_dim)

    def forward(self, blocks, feats):
        """
        :param blocks: List[DGLBlock] one block (or graph) per layer, consumed in order
        :param feats: Dict[str, tensor(N_i, d_in)] mapping node type -> input node features
        :return: tensor(N_i, d_out) final embeddings of the nodes of type predict_ntype
        """
        hidden = {}
        for ntype, x in feats.items():
            hidden[ntype] = F.gelu(self.adapt_fcs[ntype](x))
        for depth, layer in enumerate(self.layers):
            hidden = layer(blocks[depth], hidden)  # {ntype: tensor(N_i, d_hid)}
        target = hidden[self.predict_ntype]
        return self.predict(target)  # tensor(N_i, d_out)
class HGTFull(HGT):
    """HGT applied to a full graph: every layer operates on the same graph."""

    def forward(self, g, feats):
        """
        :param g: DGLGraph full heterogeneous graph, reused by every layer
        :param feats: Dict[str, tensor(N_i, d_in)] mapping node type -> input node features
        :return: tensor(N_i, d_out) final embeddings of the nodes of type predict_ntype
        """
        per_layer_graphs = [g for _ in self.layers]
        return super().forward(per_layer_graphs, feats)