181 lines
7.4 KiB
Python
181 lines
7.4 KiB
Python
import math
|
||
|
||
import dgl.function as fn
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from dgl.nn import HeteroGraphConv
|
||
from dgl.ops import edge_softmax
|
||
from dgl.utils import expand_as_pair
|
||
|
||
|
||
class HGTAttention(nn.Module):
    """HGT attention for a single relation.

    Computes multi-head attention between the source and destination nodes of
    one relation and aggregates the attention-weighted source messages on the
    destination nodes. The projection layers and relation parameters are
    created by the caller and only stored here, so they can be shared between
    several instances.
    """

    def __init__(self, out_dim, num_heads, k_linear, q_linear, v_linear, w_att, w_msg, mu):
        """Store the (externally created) projections and relation parameters.

        :param out_dim: int, output feature dimension d_out
        :param num_heads: int, number of attention heads K
        :param k_linear: nn.Linear(d_in, d_out), key projection applied to source features
        :param q_linear: nn.Linear(d_in, d_out), query projection applied to destination features
        :param v_linear: nn.Linear(d_in, d_out), value projection applied to source features
        :param w_att: tensor(K, d_out/K, d_out/K), per-head transform applied to the keys
        :param w_msg: tensor(K, d_out/K, d_out/K), per-head transform applied to the values
        :param mu: tensor(K) (any shape broadcastable against (E, K)),
            multiplicative factor applied to the raw attention scores
        :raises ValueError: if ``out_dim`` is not divisible by ``num_heads``
        """
        super().__init__()
        # Fail early: with a truncated d_k the reshape in forward() would
        # raise a much less helpful shape error.
        if out_dim % num_heads != 0:
            raise ValueError(
                f'out_dim ({out_dim}) must be divisible by num_heads ({num_heads})'
            )
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.d_k = out_dim // num_heads
        self.k_linear = k_linear
        self.q_linear = q_linear
        self.v_linear = v_linear
        self.w_att = w_att
        self.w_msg = w_msg
        self.mu = mu
        # Detached attention weights of the most recent forward() call,
        # shape (E, K); None until forward() has been called.
        self.attn = None

    def forward(self, g, feat):
        """Run attention and message passing over one relation.

        :param g: DGLGraph, bipartite graph containing a single relation
        :param feat: tensor(N_src, d_in) or (tensor(N_src, d_in), tensor(N_dst, d_in)),
            input features
        :return: tensor(N_dst, d_out), representation of the destination nodes
            with respect to this relation
        """
        with g.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, g)
            # (N_src, d_in) -> (N_src, d_out) -> (N_src, K, d_out/K)
            k = self.k_linear(feat_src).view(-1, self.num_heads, self.d_k)
            v = self.v_linear(feat_src).view(-1, self.num_heads, self.d_k)
            q = self.q_linear(feat_dst).view(-1, self.num_heads, self.d_k)

            # k[:, h] @= w_att[h]  =>  k[n, h, j] = sum_i k[n, h, i] * w_att[h, i, j]
            k = torch.einsum('nhi,hij->nhj', k, self.w_att)
            v = torch.einsum('nhi,hij->nhj', v, self.w_msg)

            g.srcdata.update({'k': k, 'v': v})
            g.dstdata['q'] = q
            # Per-edge, per-head dot product of destination query and source key.
            g.apply_edges(fn.v_dot_u('q', 'k', 't'))  # g.edata['t']: (E, K, 1)
            attn = g.edata.pop('t').squeeze(dim=-1) * self.mu / math.sqrt(self.d_k)
            attn = edge_softmax(g, attn)  # (E, K)
            # Keep a detached copy so callers can inspect the attention weights.
            self.attn = attn.detach()
            g.edata['t'] = attn.unsqueeze(dim=-1)  # (E, K, 1)

            # Weight every source message by its attention and sum on the destination.
            g.update_all(fn.u_mul_e('v', 't', 'm'), fn.sum('m', 'h'))
            out = g.dstdata['h'].view(-1, self.out_dim)  # (N_dst, d_out)
            return out
|
||
|
||
|
||
class HGTLayer(nn.Module):
    """One HGT layer.

    Runs one HGTAttention per relation through HeteroGraphConv, then applies a
    per-node-type linear map, dropout, a gated residual connection and an
    optional layer normalisation.
    """

    def __init__(self, in_dim, out_dim, num_heads, ntypes, etypes, dropout=0.2, use_norm=True):
        """Build the per-type projections and per-relation attention modules.

        :param in_dim: int, input feature dimension
        :param out_dim: int, output feature dimension
        :param num_heads: int, number of attention heads K
        :param ntypes: List[str], node types
        :param etypes: List[(str, str, str)], canonical edge types
        :param dropout: float, optional, dropout probability, default 0.2
        :param use_norm: bool, optional, whether to apply layer normalisation, default True
        :raises ValueError: if ``out_dim`` is not divisible by ``num_heads``
        """
        super().__init__()
        # Fail early instead of hitting an opaque reshape error in forward().
        if out_dim % num_heads != 0:
            raise ValueError(
                f'out_dim ({out_dim}) must be divisible by num_heads ({num_heads})'
            )
        d_k = out_dim // num_heads
        # Projections are shared per node type, relation parameters per edge
        # type name. They are held in plain dicts here and handed to the
        # HGTAttention modules below, which keep them as attributes.
        k_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        q_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        v_linear = {ntype: nn.Linear(in_dim, out_dim) for ntype in ntypes}
        # torch.empty follows the default dtype, like the nn.Linear layers do;
        # the values are filled in by reset_parameters().
        w_att = {r[1]: nn.Parameter(torch.empty(num_heads, d_k, d_k)) for r in etypes}
        w_msg = {r[1]: nn.Parameter(torch.empty(num_heads, d_k, d_k)) for r in etypes}
        mu = {r[1]: nn.Parameter(torch.ones(num_heads)) for r in etypes}
        self.reset_parameters(w_att, w_msg)
        # NOTE: modules and parameters are keyed by the edge type name only, so
        # two canonical edge types sharing a name would share/overwrite entries.
        self.conv = HeteroGraphConv({
            etype: HGTAttention(
                out_dim, num_heads, k_linear[stype], q_linear[dtype], v_linear[stype],
                w_att[etype], w_msg[etype], mu[etype]
            ) for stype, etype, dtype in etypes
        }, 'mean')

        self.a_linear = nn.ModuleDict({ntype: nn.Linear(out_dim, out_dim) for ntype in ntypes})
        # Raw gate per node type; squashed with a sigmoid in forward().
        self.skip = nn.ParameterDict({ntype: nn.Parameter(torch.ones(1)) for ntype in ntypes})
        self.drop = nn.Dropout(dropout)

        self.use_norm = use_norm
        if use_norm:
            self.norms = nn.ModuleDict({ntype: nn.LayerNorm(out_dim) for ntype in ntypes})

    def reset_parameters(self, w_att, w_msg):
        """Initialise the relation matrices in place with Xavier-uniform values.

        :param w_att: Dict[str, nn.Parameter], attention matrices keyed by edge type
        :param w_msg: Dict[str, nn.Parameter], message matrices with the same keys
        """
        for etype in w_att:
            nn.init.xavier_uniform_(w_att[etype])
            nn.init.xavier_uniform_(w_msg[etype])

    def forward(self, g, feats):
        """Apply the layer to a heterogeneous graph or block.

        :param g: DGLGraph, heterogeneous graph (or block)
        :param feats: Dict[str, tensor(N_i, d_in)], node type -> input node features
        :return: Dict[str, tensor(N_i, d_out)], node type -> output features.
            Node types without destination nodes, or for which the convolution
            produced no output, are omitted.
        """
        if g.is_block:
            # In a block only the leading rows of the inputs are destination nodes.
            feats_dst = {ntype: feats[ntype][:g.num_dst_nodes(ntype)] for ntype in feats}
        else:
            feats_dst = feats
        with g.local_scope():
            # Step 1: heterogeneous mutual attention + heterogeneous message
            # passing + target-specific aggregation
            hs = self.conv(g, feats)  # {ntype: tensor(N_i, d_out)}

            # Step 2: residual connection
            out_feats = {}
            for ntype in g.dsttypes:
                # HeteroGraphConv leaves out node types that received no
                # messages; indexing hs for them would raise KeyError.
                if g.num_dst_nodes(ntype) == 0 or ntype not in hs:
                    continue
                alpha = torch.sigmoid(self.skip[ntype])
                trans_out = self.drop(self.a_linear[ntype](hs[ntype]))
                # NOTE: adding the inputs here needs in_dim == out_dim.
                out = alpha * trans_out + (1 - alpha) * feats_dst[ntype]
                out_feats[ntype] = self.norms[ntype](out) if self.use_norm else out
            return out_feats
|
||
|
||
|
||
class HGT(nn.Module):
    """HGT model: per-type input projection, a stack of HGT layers, and a
    linear prediction head for one node type."""

    def __init__(
            self, in_dims, hidden_dim, out_dim, num_heads, ntypes, etypes,
            predict_ntype, num_layers, dropout=0.2, use_norm=True):
        """Build the model.

        :param in_dims: Dict[str, int], node type -> input feature dimension
        :param hidden_dim: int, hidden feature dimension
        :param out_dim: int, output feature dimension
        :param num_heads: int, number of attention heads K
        :param ntypes: List[str], node types
        :param etypes: List[(str, str, str)], canonical edge types
        :param predict_ntype: str, node type to predict
        :param num_layers: int, number of layers
        :param dropout: float, optional, dropout probability, default 0.2
        :param use_norm: bool, optional, whether to apply layer normalisation, default True
        """
        super().__init__()
        self.predict_ntype = predict_ntype

        # One linear map per node type brings all inputs to hidden_dim.
        input_maps = {}
        for ntype, in_dim in in_dims.items():
            input_maps[ntype] = nn.Linear(in_dim, hidden_dim)
        self.adapt_fcs = nn.ModuleDict(input_maps)

        # Every layer maps hidden_dim -> hidden_dim.
        stack = []
        for _ in range(num_layers):
            stack.append(
                HGTLayer(hidden_dim, hidden_dim, num_heads, ntypes, etypes, dropout, use_norm)
            )
        self.layers = nn.ModuleList(stack)

        self.predict = nn.Linear(hidden_dim, out_dim)

    def forward(self, blocks, feats):
        """Compute the output for the nodes of the predicted type.

        :param blocks: List[DGLBlock], one block per layer
        :param feats: Dict[str, tensor(N_i, d_in)], node type -> input node features
        :return: tensor(N_i, d_out), final embeddings of the nodes to predict
        """
        hs = {}
        for ntype, x in feats.items():
            hs[ntype] = F.gelu(self.adapt_fcs[ntype](x))

        # Layer i consumes block i.
        for i, layer in enumerate(self.layers):
            hs = layer(blocks[i], hs)  # {ntype: tensor(N_i, d_hid)}

        target = hs[self.predict_ntype]
        return self.predict(target)  # tensor(N_i, d_out)
|
||
|
||
|
||
class HGTFull(HGT):
    """HGT variant that runs every layer on the same full graph."""

    def forward(self, g, feats):
        """Compute the output for the nodes of the predicted type.

        :param g: DGLGraph, graph used by all layers
        :param feats: Dict[str, tensor(N_i, d_in)], node type -> input node features
        :return: tensor(N_i, d_out), final embeddings of the nodes to predict
        """
        # The parent expects one graph per layer, so hand it the same graph
        # once for each layer.
        per_layer_graphs = [g for _ in self.layers]
        return super().forward(per_layer_graphs, feats)
|