import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.dataloading import MultiLayerFullNeighborSampler, NodeDataLoader
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
from tqdm import tqdm


class MicroConv(nn.Module):

    def __init__(
            self, out_dim, num_heads, fc_src, fc_dst, attn_src,
            feat_drop=0.0, negative_slope=0.2, activation=None):
        """Micro-level (intra-relation) convolution.

        For one relation (edge type) R = <stype, etype, dtype>, aggregates the neighbor
        information under R to obtain the representation of R with respect to the dtype
        nodes. The feature transforms and the attention vector are tied to node types;
        apart from that, the computation is identical to GAT.

        :param out_dim: int output feature dimension per head (d_out)
        :param num_heads: int number of attention heads K
        :param fc_src: nn.Linear(d_in, K*d_out) feature transform for source nodes
        :param fc_dst: nn.Linear(d_in, K*d_out) feature transform for destination nodes
        :param attn_src: nn.Parameter(K, 2d_out) attention vector of the source node type
        :param feat_drop: float, optional dropout probability on input features, default 0
        :param negative_slope: float, optional LeakyReLU negative slope, default 0.2
        :param activation: callable, optional activation applied to the output, default None
        """
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        # These modules/parameters are shared with other MicroConv instances (they are
        # created per node type by HGConvLayer), so assigning them here registers the
        # same objects in several submodules.
        self.fc_src = fc_src
        self.fc_dst = fc_dst
        self.attn_src = attn_src
        self.feat_drop = nn.Dropout(feat_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        self.activation = activation

    def forward(self, g, feat):
        """
        :param g: DGLGraph bipartite graph containing a single relation
        :param feat: tensor(N_src, d_in) or (tensor(N_src, d_in), tensor(N_dst, d_in)) input features
        :return: tensor(N_dst, K*d_out) representation of this relation for the destination nodes
        """
        with g.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, g)
            feat_src = self.fc_src(self.feat_drop(feat_src)).view(-1, self.num_heads, self.out_dim)
            feat_dst = self.fc_dst(self.feat_drop(feat_dst)).view(-1, self.num_heads, self.out_dim)

            # a^T (z_u || z_v) = (a_l^T || a_r^T) (z_u || z_v) = a_l^T z_u + a_r^T z_v = el + er,
            # so the logits are computed once per node and only summed per edge.
            el = (feat_src * self.attn_src[:, :self.out_dim]).sum(dim=-1, keepdim=True)  # (N_src, K, 1)
            er = (feat_dst * self.attn_src[:, self.out_dim:]).sum(dim=-1, keepdim=True)  # (N_dst, K, 1)
            g.srcdata.update({'ft': feat_src, 'el': el})
            g.dstdata['er'] = er
            g.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(g.edata.pop('e'))
            g.edata['a'] = edge_softmax(g, e)  # (E, K, 1)

            # Message passing: attention-weighted sum of source features per destination node.
            g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
            ret = g.dstdata['ft'].view(-1, self.num_heads * self.out_dim)
            if self.activation:
                ret = self.activation(ret)
            return ret


class MacroConv(nn.Module):

    def __init__(self, out_dim, num_heads, fc_node, fc_rel, attn, dropout=0.0, negative_slope=0.2):
        """Macro-level (inter-relation) convolution.

        Over all relations (edge types), combines, for each node type, the representations
        produced by every relation that ends at that node type.

        :param out_dim: int output feature dimension per head (d_out)
        :param num_heads: int number of attention heads K
        :param fc_node: Dict[str, nn.Linear(d_in, K*d_out)] node type -> node feature transform
        :param fc_rel: Dict[str, nn.Linear(K*d_out, K*d_out)] relation name -> relation representation transform
        :param attn: nn.Parameter(K, 2d_out) attention vector scoring (node || relation) features
        :param dropout: float, optional dropout probability on the output, default 0
        :param negative_slope: float, optional LeakyReLU negative slope, default 0.2
        """
        super().__init__()
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.fc_node = fc_node
        self.fc_rel = fc_rel
        self.attn = attn
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, node_feats, rel_feats):
        """
        :param node_feats: Dict[str, tensor(N_i, d_in)] node type -> input node features
        :param rel_feats: Dict[(str, str, str), tensor(N_i, K*d_out)]
         relation (stype, etype, dtype) -> representation of that relation for its dtype nodes
        :return: Dict[str, tensor(N_i, K*d_out)] node type -> final node embeddings; node types
         that are the destination of no relation in rel_feats are omitted
        """
        rel_feats = {
            r: self.fc_rel[r[1]](feat).view(-1, self.num_heads, self.out_dim)
            for r, feat in rel_feats.items()
        }
        out_feats = {}
        for ntype, feat in node_feats.items():
            rel_node_feats = [rf for rel, rf in rel_feats.items() if rel[2] == ntype]
            if not rel_node_feats:
                continue
            if len(rel_node_feats) == 1:
                # A single relation needs no attention fusion.
                out_feat = rel_node_feats[0]  # (N_i, K, d_out)
            else:
                # The node-feature transform is only needed to score relations, so it is
                # computed here instead of for every node type.
                node_feat = self.fc_node[ntype](feat).view(-1, self.num_heads, self.out_dim)
                rel_node_feats = torch.stack(rel_node_feats, dim=0)  # (R, N_i, K, d_out)
                # expand gives a broadcast view (no copy) of node_feat for every relation.
                cat_feats = torch.cat(
                    (node_feat.expand(rel_node_feats.shape[0], -1, -1, -1), rel_node_feats), dim=-1
                )  # (R, N_i, K, 2d_out)
                attn_scores = self.leaky_relu((self.attn * cat_feats).sum(dim=-1, keepdim=True))
                attn_scores = F.softmax(attn_scores, dim=0)  # (R, N_i, K, 1)
                out_feat = (attn_scores * rel_node_feats).sum(dim=0)  # (N_i, K, d_out)
            # Dropout is applied to every node type; previously single-relation node types
            # skipped it, which made the regularization depend on the graph schema.
            out_feats[ntype] = self.dropout(out_feat.reshape(-1, self.num_heads * self.out_dim))
        return out_feats


class HGConvLayer(nn.Module):

    def __init__(self, in_dim, out_dim, num_heads, ntypes, etypes, dropout=0.0, residual=True):
        """HGConv layer (micro-level convolution followed by macro-level convolution).

        :param in_dim: int input feature dimension
        :param out_dim: int output feature dimension per head
        :param num_heads: int number of attention heads K
        :param ntypes: List[str] node types
        :param etypes: List[(str, str, str)] canonical edge types
        :param dropout: float, optional dropout probability, default 0
        :param residual: bool, optional whether to use a gated residual connection, default True
        """
        super().__init__()
        # Micro-level parameters, one set per node type, shared across relations.
        micro_fc = {ntype: nn.Linear(in_dim, num_heads * out_dim, bias=False) for ntype in ntypes}
        micro_attn = {
            ntype: nn.Parameter(torch.empty((num_heads, 2 * out_dim))) for ntype in ntypes
        }

        # Macro-level parameters.
        macro_fc_node = nn.ModuleDict({
            ntype: nn.Linear(in_dim, num_heads * out_dim, bias=False) for ntype in ntypes
        })
        macro_fc_rel = nn.ModuleDict({
            r[1]: nn.Linear(num_heads * out_dim, num_heads * out_dim, bias=False)
            for r in etypes
        })
        macro_attn = nn.Parameter(torch.empty((num_heads, 2 * out_dim)))

        # Keyed by the edge-type name alone, so edge-type names are assumed to be unique.
        self.micro_conv = nn.ModuleDict({
            etype: MicroConv(
                out_dim, num_heads, micro_fc[stype],
                micro_fc[dtype], micro_attn[stype], dropout, activation=F.relu
            ) for stype, etype, dtype in etypes
        })
        self.macro_conv = MacroConv(
            out_dim, num_heads, macro_fc_node, macro_fc_rel, macro_attn, dropout
        )

        self.residual = residual
        if residual:
            self.res_fc = nn.ModuleDict({
                ntype: nn.Linear(in_dim, num_heads * out_dim) for ntype in ntypes
            })
            # Learned gate per node type; sigmoid(res_weight) mixes output and residual.
            self.res_weight = nn.ParameterDict({
                ntype: nn.Parameter(torch.rand(1)) for ntype in ntypes
            })
        self.reset_parameters(micro_fc, micro_attn, macro_fc_node, macro_fc_rel, macro_attn)

    def reset_parameters(self, micro_fc, micro_attn, macro_fc_node, macro_fc_rel, macro_attn):
        """Xavier-normal initialization (ReLU gain) of all weight matrices and attention vectors."""
        gain = nn.init.calculate_gain('relu')
        for ntype in micro_fc:
            nn.init.xavier_normal_(micro_fc[ntype].weight, gain=gain)
            nn.init.xavier_normal_(micro_attn[ntype], gain=gain)
            nn.init.xavier_normal_(macro_fc_node[ntype].weight, gain=gain)
            if self.residual:
                nn.init.xavier_normal_(self.res_fc[ntype].weight, gain=gain)
        for etype in macro_fc_rel:
            nn.init.xavier_normal_(macro_fc_rel[etype].weight, gain=gain)
        nn.init.xavier_normal_(macro_attn, gain=gain)

    def forward(self, g, feats):
        """
        :param g: DGLGraph heterogeneous graph (or block)
        :param feats: Dict[str, tensor(N_i, d_in)] node type -> input node features
        :return: Dict[str, tensor(N_i, K*d_out)] node type -> final node embeddings
        """
        if g.is_block:
            # In a block the destination nodes are the first num_dst_nodes source nodes.
            feats_dst = {ntype: feats[ntype][:g.num_dst_nodes(ntype)] for ntype in feats}
        else:
            feats_dst = feats
        rel_feats = {
            (stype, etype, dtype): self.micro_conv[etype](
                g[stype, etype, dtype], (feats[stype], feats_dst[dtype])
            )
            for stype, etype, dtype in g.canonical_etypes
            if g.num_edges((stype, etype, dtype)) > 0
        }  # {rel: tensor(N_i, K*d_out)}
        out_feats = self.macro_conv(feats_dst, rel_feats)  # {ntype: tensor(N_i, K*d_out)}
        if self.residual:
            for ntype in out_feats:
                alpha = torch.sigmoid(self.res_weight[ntype])
                inherit_feat = self.res_fc[ntype](feats_dst[ntype])
                out_feats[ntype] = alpha * out_feats[ntype] + (1 - alpha) * inherit_feat
        return out_feats


class HGConv(nn.Module):

    def __init__(
            self, in_dims, hidden_dim, out_dim, num_heads, ntypes, etypes, predict_ntype,
            num_layers, dropout=0.0, residual=True):
        """HGConv model.

        :param in_dims: Dict[str, int] node type -> input feature dimension
        :param hidden_dim: int hidden feature dimension per head
        :param out_dim: int output feature dimension
        :param num_heads: int number of attention heads K
        :param ntypes: List[str] node types
        :param etypes: List[(str, str, str)] canonical edge types
        :param predict_ntype: str node type to predict
        :param num_layers: int number of layers
        :param dropout: float, optional dropout probability, default 0
        :param residual: bool, optional whether to use residual connections, default True
        """
        super().__init__()
        self.d = num_heads * hidden_dim
        self.predict_ntype = predict_ntype
        # Project every node type's input features to a common dimension.
        self.fc_in = nn.ModuleDict({
            ntype: nn.Linear(in_dim, num_heads * hidden_dim) for ntype, in_dim in in_dims.items()
        })
        self.layers = nn.ModuleList([
            HGConvLayer(
                num_heads * hidden_dim, hidden_dim, num_heads, ntypes, etypes, dropout, residual
            ) for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(num_heads * hidden_dim, out_dim)

    def forward(self, blocks, feats):
        """
        :param blocks: List[DGLBlock] one block per layer (at least num_layers of them)
        :param feats: Dict[str, tensor(N_i, d_in_i)] node type -> input node features
        :return: tensor(N_i, d_out) outputs for the nodes of the predicted type
        """
        feats = {ntype: self.fc_in[ntype](feat) for ntype, feat in feats.items()}
        for i, layer in enumerate(self.layers):
            feats = layer(blocks[i], feats)  # {ntype: tensor(N_i, K*d_hid)}
        return self.classifier(feats[self.predict_ntype])

    @torch.no_grad()
    def inference(self, g, feats, device, batch_size):
        """Offline, layer-by-layer inference for all nodes using full neighborhoods (no sampling).

        Dropout layers are still active unless the model is in eval mode, so call
        ``model.eval()`` first. Side effect: the final-layer embeddings are left in
        ``g.ndata['emb']``. NOTE(review): DGL requires node features to live on the same
        device as ``g``; presumably ``g`` is already on ``device`` — confirm with callers.

        :param g: DGLGraph heterogeneous graph
        :param feats: Dict[str, tensor(N_i, d_in_i)] node type -> input node features
        :param device: torch.device
        :param batch_size: int batch size
        :return: tensor(N_i, d_out) outputs for all nodes of the predicted type
        """
        g.ndata['emb'] = {ntype: self.fc_in[ntype](feat) for ntype, feat in feats.items()}
        # The sampler is stateless, so one instance serves every layer.
        sampler = MultiLayerFullNeighborSampler(1)
        for layer in self.layers:
            # Node types that receive no output from this layer keep all-zero embeddings.
            embeds = {
                ntype: torch.zeros(g.num_nodes(ntype), self.d, device=device)
                for ntype in g.ntypes
            }
            # Results are written back by node id, so shuffling is unnecessary.
            loader = NodeDataLoader(
                g, {ntype: g.nodes(ntype) for ntype in g.ntypes}, sampler, device=device,
                batch_size=batch_size, shuffle=False
            )
            for _, output_nodes, blocks in tqdm(loader):
                block = blocks[0]
                h = layer(block, block.srcdata['emb'])
                for ntype in h:
                    embeds[ntype][output_nodes[ntype]] = h[ntype]
            g.ndata['emb'] = embeds
        return self.classifier(g.nodes[self.predict_ntype].data['emb'])


class HGConvFull(HGConv):
    """HGConv trained on the full graph: the same graph is used as every layer's block."""

    def forward(self, g, feats):
        return super().forward([g] * len(self.layers), feats)