GNNRecom/gnnrec/hge/rgcn/model.py

96 lines
3.8 KiB
Python
Raw Normal View History

2021-11-16 07:04:52 +00:00
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroGraphConv, GraphConv
class RelGraphConv(nn.Module):

    def __init__(self, in_dim, out_dim, ntypes, etypes, activation=None, dropout=0.0):
        """R-GCN layer for heterogeneous graphs.

        Each edge type gets its own ``GraphConv`` (its own weight matrix, right
        normalization, no bias); the messages of all edge types are summed per
        destination node type, and a per-node-type self-loop linear transform of the
        destination node's own feature is added before activation and dropout.

        :param in_dim: int input feature dimension
        :param out_dim: int output feature dimension
        :param ntypes: List[str] node types (one self-loop weight is created per type)
        :param etypes: List[str] edge types (one GraphConv is created per type)
        :param activation: callable, optional activation function, default None
        :param dropout: float, optional dropout probability, default 0
        """
        super().__init__()
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        # allow_zero_in_degree=True: in a heterograph most destination nodes have no
        # in-edge under at least one relation. Without it GraphConv raises DGLError;
        # with it such nodes just receive no message from that relation, and the
        # self-loop term below still carries their own features.
        self.conv = HeteroGraphConv({
            etype: GraphConv(in_dim, out_dim, norm='right', bias=False, allow_zero_in_degree=True)
            for etype in etypes
        }, 'sum')
        self.loop_weight = nn.ModuleDict({
            ntype: nn.Linear(in_dim, out_dim, bias=False) for ntype in ntypes
        })

    def forward(self, g, feats):
        """Apply the layer.

        :param g: DGLGraph heterogeneous graph, or a block (message-flow graph)
        :param feats: Dict[str, tensor(N_i, d_in)] mapping from node type to input
            features (features of the source nodes when ``g`` is a block)
        :return: Dict[str, tensor(N_i, d_out)] mapping from node type to output
            features, one entry per node type in ``feats`` that this layer has a
            self-loop weight for (destination nodes only when ``g`` is a block)
        """
        if g.is_block:
            # Relies on DGL's block convention that destination nodes are the first
            # num_dst_nodes(ntype) source nodes of each type.
            feats_dst = {ntype: feat[:g.num_dst_nodes(ntype)] for ntype, feat in feats.items()}
        else:
            feats_dst = feats
        msgs = self.conv(g, (feats, feats_dst))  # Dict[ntype, (N_i, d_out)]
        out = {}
        for ntype, h_dst in feats_dst.items():
            if ntype not in self.loop_weight and ntype not in msgs:
                # Node type unknown to this layer and nothing flowed into it: skip.
                continue
            # The self-loop term is applied to every node type, including types that
            # received no messages (no incoming edge type, or no edges in this block),
            # so they are not silently dropped from the output.
            h = self.loop_weight[ntype](h_dst)
            if ntype in msgs:
                h = msgs[ntype] + h  # out-of-place: don't mutate the aggregated conv output
            if self.activation:
                h = self.activation(h)
            out[ntype] = self.dropout(h)
        return out
class RGCN(nn.Module):

    def __init__(
            self, in_dim, hidden_dim, out_dim, input_ntypes, num_nodes, etypes, predict_ntype,
            num_layers=2, dropout=0.0):
        """R-GCN model.

        Node types not listed in ``input_ntypes`` get a learnable embedding table of
        shape (num_nodes[ntype], in_dim) that is used as their input feature.

        :param in_dim: int input feature dimension (shared by all node types, including
            the learnable embeddings)
        :param hidden_dim: int hidden feature dimension
        :param out_dim: int output feature dimension
        :param input_ntypes: List[str] node types that have input features
        :param num_nodes: Dict[str, int] mapping from node type to number of nodes
        :param etypes: List[str] edge types
        :param predict_ntype: str node type whose embeddings are returned
        :param num_layers: int, optional number of layers, default 2; must be >= 1
        :param dropout: float, optional dropout probability, default 0
        :raises ValueError: if num_layers < 1
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.embeds = nn.ModuleDict({
            ntype: nn.Embedding(num_nodes[ntype], in_dim)
            for ntype in num_nodes if ntype not in input_ntypes
        })
        ntypes = list(num_nodes)
        # Layer widths: in_dim -> hidden_dim -> ... -> hidden_dim -> out_dim.
        # With num_layers == 1 this is a single in_dim -> out_dim layer.
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        self.layers = nn.ModuleList()
        for i in range(num_layers - 1):
            self.layers.append(RelGraphConv(dims[i], dims[i + 1], ntypes, etypes, F.relu, dropout))
        # Output layer: no activation, no dropout.
        self.layers.append(RelGraphConv(dims[-2], dims[-1], ntypes, etypes))
        self.predict_ntype = predict_ntype
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the node embedding tables with Xavier-uniform (ReLU gain).

        The RelGraphConv layers keep the initialization done at construction time.
        """
        gain = nn.init.calculate_gain('relu')
        for embed in self.embeds.values():
            nn.init.xavier_uniform_(embed.weight, gain=gain)

    def forward(self, g, feats):
        """Run all layers on ``g`` and return the embeddings of ``predict_ntype``.

        Embedded node types contribute their whole embedding table
        (num_nodes[ntype] rows), so ``g`` is presumably the full graph rather than a
        sampled block — confirm with callers.

        :param g: DGLGraph heterogeneous graph (the same graph is passed to every layer)
        :param feats: Dict[str, tensor(N_i, in_dim)] input features for the node types
            that have them; the dict is not modified
        :return: tensor(N_p, d_out) final-layer embeddings of the predict_ntype nodes
        """
        h = dict(feats)  # copy so the caller's dict is not mutated
        for ntype, embed in self.embeds.items():
            h[ntype] = embed.weight
        for layer in self.layers:
            h = layer(g, h)  # Dict[ntype, (N_i, d_hid)]
        return h[self.predict_ntype]