import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroGraphConv, GraphConv


class RelGraphConv(nn.Module):

    def __init__(self, in_dim, out_dim, ntypes, etypes, activation=None, dropout=0.0):
        """R-GCN layer for heterogeneous graphs.

        :param in_dim: int Input feature dimension
        :param out_dim: int Output feature dimension
        :param ntypes: List[str] Node types
        :param etypes: List[str] Edge types
        :param activation: callable, optional Activation function (default: None)
        :param dropout: float, optional Dropout probability (default: 0)
        """
        super().__init__()
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

        # One GraphConv per relation; per-destination messages from different
        # relations are aggregated by summation.
        rel_convs = {
            etype: GraphConv(in_dim, out_dim, norm='right', bias=False)
            for etype in etypes
        }
        self.conv = HeteroGraphConv(rel_convs, 'sum')
        # Per-node-type linear map implementing the self-loop term of R-GCN.
        self.loop_weight = nn.ModuleDict(
            {ntype: nn.Linear(in_dim, out_dim, bias=False) for ntype in ntypes}
        )

    def forward(self, g, feats):
        """Apply the layer to a heterogeneous graph.

        :param g: DGLGraph Heterogeneous graph (or block)
        :param feats: Dict[str, tensor(N_i, d_in)] Node type -> input features
        :return: Dict[str, tensor(N_i, d_out)] Node type -> output features
        """
        if g.is_block:
            # For a sampled block, the destination nodes are the first
            # num_dst_nodes rows of each feature tensor.
            feats_dst = {
                ntype: h[:g.num_dst_nodes(ntype)] for ntype, h in feats.items()
            }
        else:
            feats_dst = feats
        out = self.conv(g, (feats, feats_dst))  # Dict[ntype, (N_i, d_out)]
        for ntype, h in out.items():
            h = h + self.loop_weight[ntype](feats_dst[ntype])
            if self.activation:
                h = self.activation(h)
            out[ntype] = self.dropout(h)
        return out


class RGCN(nn.Module):

    def __init__(
            self, in_dim, hidden_dim, out_dim, input_ntypes, num_nodes, etypes, predict_ntype,
            num_layers=2, dropout=0.0):
        """R-GCN model.

        :param in_dim: int Input feature dimension
        :param hidden_dim: int Hidden feature dimension
        :param out_dim: int Output feature dimension
        :param input_ntypes: List[str] Node types that come with input features
        :param num_nodes: Dict[str, int] Node type -> number of nodes
        :param etypes: List[str] Edge types
        :param predict_ntype: str Node type to predict
        :param num_layers: int, optional Number of layers (default: 2)
        :param dropout: float, optional Dropout probability (default: 0)
        """
        super().__init__()
        # Learned embeddings for node types that have no input features.
        self.embeds = nn.ModuleDict({
            ntype: nn.Embedding(num_nodes[ntype], in_dim)
            for ntype in num_nodes if ntype not in input_ntypes
        })
        ntypes = list(num_nodes)
        self.layers = nn.ModuleList()
        self.layers.append(RelGraphConv(in_dim, hidden_dim, ntypes, etypes, F.relu, dropout))
        for _ in range(num_layers - 2):
            self.layers.append(RelGraphConv(hidden_dim, hidden_dim, ntypes, etypes, F.relu, dropout))
        # Final layer: no activation, no dropout.
        self.layers.append(RelGraphConv(hidden_dim, out_dim, ntypes, etypes))
        self.predict_ntype = predict_ntype
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize the featureless-node embeddings (Xavier uniform)."""
        gain = nn.init.calculate_gain('relu')
        for k in self.embeds:
            nn.init.xavier_uniform_(self.embeds[k].weight, gain=gain)

    def forward(self, g, feats):
        """Compute embeddings for the prediction node type.

        :param g: DGLGraph Heterogeneous graph
        :param feats: Dict[str, tensor(N_i, d_in_i)] (Partial) node type -> input features
        :return: tensor(N, d_out) Embeddings of predict_ntype nodes
        """
        # Copy so we do not mutate the caller's dict when injecting the
        # learned embeddings for featureless node types (the original code
        # wrote into `feats` in place, a surprising side effect).
        feats = dict(feats)
        for ntype, embed in self.embeds.items():
            feats[ntype] = embed.weight
        for layer in self.layers:
            feats = layer(g, feats)  # Dict[ntype, (N_i, d_hid)]
        return feats[self.predict_ntype]