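"""DGL implementation of the Heterogeneous Graph Transformer (HGT).

Each HGT layer performs heterogeneous mutual attention, heterogeneous message
passing, and target-specific aggregation, using relation-specific attention and
message matrices, a relation priority mu, and a gated residual connection per
vertex type.
"""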
import math

import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroGraphConv
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair


class HGTAttention(nn.Module):

    def __init__(self, out_dim, num_heads, k_linear, q_linear, v_linear, w_att, w_msg, mu):
        """HGT attention module for a single relation.

        :param out_dim: int, output feature dimension
        :param num_heads: int, number of attention heads K
        :param k_linear: nn.Linear(d_in, d_out), key projection for the source vertex type
        :param q_linear: nn.Linear(d_in, d_out), query projection for the destination vertex type
        :param v_linear: nn.Linear(d_in, d_out), value (message) projection for the source vertex type
        :param w_att: tensor(K, d_out/K, d_out/K), relation-specific attention matrices
        :param w_msg: tensor(K, d_out/K, d_out/K), relation-specific message matrices
        :param mu: tensor(K), relation priority μ
        """
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.d_k = out_dim // num_heads
        self.k_linear = k_linear
        self.q_linear = q_linear
        self.v_linear = v_linear
        self.w_att = w_att
        self.w_msg = w_msg
        self.mu = mu

    def forward(self, g, feat):
        """
        :param g: DGLGraph, bipartite graph containing a single relation
        :param feat: tensor(N_src, d_in) or (tensor(N_src, d_in), tensor(N_dst, d_in)), input features
        :return: tensor(N_dst, d_out), representations of the destination vertices under this relation
        """
        with g.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, g)
            # (N_src, d_in) -> (N_src, d_out) -> (N_src, K, d_out/K)
            k = self.k_linear(feat_src).view(-1, self.num_heads, self.d_k)
            v = self.v_linear(feat_src).view(-1, self.num_heads, self.d_k)
            q = self.q_linear(feat_dst).view(-1, self.num_heads, self.d_k)

            # k[:, h] @= w_att[h] => k[n, h, j] = ∑(i) k[n, h, i] * w_att[h, i, j]
            k = torch.einsum('nhi,hij->nhj', k, self.w_att)
            v = torch.einsum('nhi,hij->nhj', v, self.w_msg)

            g.srcdata.update({'k': k, 'v': v})
            g.dstdata['q'] = q
            # Heterogeneous mutual attention: scaled dot product of query and key on each edge
            g.apply_edges(fn.v_dot_u('q', 'k', 't'))  # g.edata['t']: (E, K, 1)
            attn = g.edata.pop('t').squeeze(dim=-1) * self.mu / math.sqrt(self.d_k)
            attn = edge_softmax(g, attn)  # (E, K)
            self.attn = attn.detach()  # cache attention weights for inspection
            g.edata['t'] = attn.unsqueeze(dim=-1)  # (E, K, 1)

            # Heterogeneous message passing: weight messages by attention and sum over in-edges
            g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
            out = g.dstdata['h'].view(-1, self.out_dim)  # (N_dst, d_out)
            return out


class HGTLayer(nn.Module):

    def __init__(self, in_dim, out_dim, num_heads, ntypes, etypes, dropout=0.2, use_norm=True):
        """HGT layer.

        :param in_dim: int, input feature dimension
        :param out_dim: int, output feature dimension
        :param num_heads: int, number of attention heads K
        :param ntypes: List[str], list of vertex types
        :param etypes: List[(str, str, str)], list of canonical edge types
        :param dropout: float, optional, dropout probability, default 0.2
        :param use_norm: bool, optional, whether to use layer normalization, default True
        """
        super().__init__()
        d_k = out_dim // num_heads
        # Projections are shared per vertex type; w_att, w_msg and mu are relation-specific
        k_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        q_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        v_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        w_att = {r[1]: nn.Parameter(torch.Tensor(num_heads, d_k, d_k)) for r in etypes}
        w_msg = {r[1]: nn.Parameter(torch.Tensor(num_heads, d_k, d_k)) for r in etypes}
        mu = {r[1]: nn.Parameter(torch.ones(num_heads)) for r in etypes}
        self.reset_parameters(w_att, w_msg)
        self.conv = HeteroGraphConv({
            etype: HGTAttention(
                out_dim, num_heads, k_linear[stype], q_linear[dtype], v_linear[stype],
                w_att[etype], w_msg[etype], mu[etype]
            ) for stype, etype, dtype in etypes
        }, 'mean')

        self.a_linear = nn.ModuleDict({ntype: nn.Linear(out_dim, out_dim) for ntype in ntypes})
        self.skip = nn.ParameterDict({ntype: nn.Parameter(torch.ones(1)) for ntype in ntypes})
        self.drop = nn.Dropout(dropout)

        self.use_norm = use_norm
        if use_norm:
            self.norms = nn.ModuleDict({ntype: nn.LayerNorm(out_dim) for ntype in ntypes})

    def reset_parameters(self, w_att, w_msg):
        for etype in w_att:
            nn.init.xavier_uniform_(w_att[etype])
            nn.init.xavier_uniform_(w_msg[etype])

    def forward(self, g, feats):
        """
        :param g: DGLGraph, heterogeneous graph
        :param feats: Dict[str, tensor(N_i, d_in)], mapping from vertex type to input vertex features
        :return: Dict[str, tensor(N_i, d_out)], mapping from vertex type to output features
        """
        if g.is_block:
            feats_dst = {ntype: feats[ntype][:g.num_dst_nodes(ntype)] for ntype in feats}
        else:
            feats_dst = feats
        with g.local_scope():
            # Step 1: heterogeneous mutual attention + heterogeneous message passing
            # + target-specific aggregation
            hs = self.conv(g, feats)  # {ntype: tensor(N_i, d_out)}

            # Step 2: residual connection, gated by a per-type skip weight
            out_feats = {}
            for ntype in g.dsttypes:
                if g.num_dst_nodes(ntype) == 0:
                    continue
                alpha = torch.sigmoid(self.skip[ntype])
                trans_out = self.drop(self.a_linear[ntype](hs[ntype]))
                out = alpha * trans_out + (1 - alpha) * feats_dst[ntype]
                out_feats[ntype] = self.norms[ntype](out) if self.use_norm else out
            return out_feats


class HGT(nn.Module):

    def __init__(
            self, in_dims, hidden_dim, out_dim, num_heads, ntypes, etypes,
            predict_ntype, num_layers, dropout=0.2, use_norm=True):
        """HGT model.

        :param in_dims: Dict[str, int], mapping from vertex type to input feature dimension
        :param hidden_dim: int, hidden feature dimension
        :param out_dim: int, output feature dimension
        :param num_heads: int, number of attention heads K
        :param ntypes: List[str], list of vertex types
        :param etypes: List[(str, str, str)], list of canonical edge types
        :param predict_ntype: str, vertex type to predict
        :param num_layers: int, number of layers
        :param dropout: float, optional, dropout probability, default 0.2
        :param use_norm: bool, optional, whether to use layer normalization, default True
        """
        super().__init__()
        self.predict_ntype = predict_ntype
        # Adapt each vertex type's input features to the common hidden dimension
        self.adapt_fcs = nn.ModuleDict({
            ntype: nn.Linear(in_dim, hidden_dim) for ntype, in_dim in in_dims.items()
        })
        self.layers = nn.ModuleList([
            HGTLayer(hidden_dim, hidden_dim, num_heads, ntypes, etypes, dropout, use_norm)
            for _ in range(num_layers)
        ])
        self.predict = nn.Linear(hidden_dim, out_dim)

    def forward(self, blocks, feats):
        """
        :param blocks: List[DGLBlock]
        :param feats: Dict[str, tensor(N_i, d_in)], mapping from vertex type to input vertex features
        :return: tensor(N_i, d_out), final embeddings of the vertices to predict
        """
        hs = {ntype: F.gelu(self.adapt_fcs[ntype](feats[ntype])) for ntype in feats}
        for i in range(len(self.layers)):
            hs = self.layers[i](blocks[i], hs)  # {ntype: tensor(N_i, d_hid)}
        out = self.predict(hs[self.predict_ntype])  # tensor(N_i, d_out)
        return out


class HGTFull(HGT):
    """Full-graph variant of HGT: every layer runs on the same whole graph instead of per-layer blocks."""

    def forward(self, g, feats):
        return super().forward([g] * len(self.layers), feats)
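

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test assuming a toy author/paper heterograph with random
# features; the vertex types, edges, dimensions, and counts below are made up
# for illustration only.
if __name__ == '__main__':
    import dgl

    g = dgl.heterograph({
        ('author', 'writes', 'paper'): (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),
        ('paper', 'written-by', 'author'): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),
    })
    in_dims = {'author': 16, 'paper': 16}
    feats = {ntype: torch.randn(g.num_nodes(ntype), in_dims[ntype]) for ntype in g.ntypes}
    model = HGTFull(
        in_dims, hidden_dim=32, out_dim=4, num_heads=4,
        ntypes=g.ntypes, etypes=g.canonical_etypes,
        predict_ntype='paper', num_layers=2
    )
    logits = model(g, feats)
    print(logits.shape)  # expected: torch.Size([2, 4]), one row per 'paper' vertex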