GNNRecom/gnnrec/hge/rgcn/model.py

96 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import HeteroGraphConv, GraphConv
class RelGraphConv(nn.Module):
    """A single R-GCN layer for heterogeneous graphs.

    Every edge type owns its own ``GraphConv`` (``norm='right'``, no bias).
    ``HeteroGraphConv`` sums the per-relation results. A per-node-type linear
    self-loop projection of the destination features is then added. Finally
    the optional activation and dropout are applied.
    """

    def __init__(self, in_dim, out_dim, ntypes, etypes, activation=None, dropout=0.0):
        """
        :param in_dim: int, input feature dimension
        :param out_dim: int, output feature dimension
        :param ntypes: List[str], node types; one self-loop ``nn.Linear`` is created per type
        :param etypes: List[str], edge types; one ``GraphConv`` is created per type
        :param activation: callable, optional; activation function, default None (no activation)
        :param dropout: float, optional; dropout probability, default 0
        """
        super().__init__()
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

        # One relation-specific convolution per edge type, summed across relations.
        per_relation = {}
        for etype in etypes:
            per_relation[etype] = GraphConv(in_dim, out_dim, norm='right', bias=False)
        self.conv = HeteroGraphConv(per_relation, 'sum')

        # Self-loop weight W_0, kept separately for each node type.
        self_loops = {}
        for ntype in ntypes:
            self_loops[ntype] = nn.Linear(in_dim, out_dim, bias=False)
        self.loop_weight = nn.ModuleDict(self_loops)

    def forward(self, g, feats):
        """Run one R-GCN propagation step.

        :param g: DGLGraph, heterogeneous graph or block
        :param feats: Dict[str, tensor(N_i, d_in)], node type -> input features
        :return: Dict[str, tensor(N_i, d_out)], node type -> output features;
            the keys are exactly the node types ``self.conv`` returned output for
        """
        # For a block, the self-loop input is taken as the first
        # num_dst_nodes(ntype) rows of each source feature tensor
        # (assumes DGL's convention that destination nodes come first).
        feats_dst = (
            {ntype: x[:g.num_dst_nodes(ntype)] for ntype, x in feats.items()}
            if g.is_block
            else feats
        )
        aggregated = self.conv(g, (feats, feats_dst))  # Dict[ntype, (N_i, d_out)]

        result = {}
        for ntype, h in aggregated.items():
            h = h + self.loop_weight[ntype](feats_dst[ntype])
            if self.activation:
                h = self.activation(h)
            result[ntype] = self.dropout(h)
        return result
class RGCN(nn.Module):
    """R-GCN model for heterogeneous graphs.

    Node types that have no input features get a learnable ``nn.Embedding``
    table. All node types are then passed through ``num_layers`` stacked
    ``RelGraphConv`` layers on the same graph. The model returns the output
    for ``predict_ntype``.
    """

    def __init__(
            self, in_dim, hidden_dim, out_dim, input_ntypes, num_nodes, etypes, predict_ntype,
            num_layers=2, dropout=0.0):
        """
        :param in_dim: int, input feature dimension (also the embedding size for
            node types without input features)
        :param hidden_dim: int, hidden feature dimension
        :param out_dim: int, output feature dimension
        :param input_ntypes: List[str], node types that are given input features
        :param num_nodes: Dict[str, int], node type -> number of nodes
        :param etypes: List[str], edge types
        :param predict_ntype: str, node type whose output is returned by ``forward``
        :param num_layers: int, optional; number of R-GCN layers, default 2
        :param dropout: float, optional; dropout probability for the hidden layers, default 0
        :raises ValueError: if ``num_layers`` is less than 1
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be at least 1, got {num_layers}')
        # Learnable embeddings stand in for features of types without inputs.
        self.embeds = nn.ModuleDict({
            ntype: nn.Embedding(num_nodes[ntype], in_dim)
            for ntype in num_nodes if ntype not in input_ntypes
        })
        ntypes = list(num_nodes)
        self.layers = nn.ModuleList()
        if num_layers == 1:
            # A single layer maps straight from input to output space. Like the
            # usual last layer, it has no activation or dropout.
            self.layers.append(RelGraphConv(in_dim, out_dim, ntypes, etypes))
        else:
            self.layers.append(RelGraphConv(in_dim, hidden_dim, ntypes, etypes, F.relu, dropout))
            for _ in range(num_layers - 2):
                self.layers.append(RelGraphConv(hidden_dim, hidden_dim, ntypes, etypes, F.relu, dropout))
            # The output layer has no activation and no dropout.
            self.layers.append(RelGraphConv(hidden_dim, out_dim, ntypes, etypes))
        self.predict_ntype = predict_ntype
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding tables with Xavier-uniform (ReLU gain).

        The conv layers are not touched here; they keep the initialisation
        from their own constructors.
        """
        gain = nn.init.calculate_gain('relu')
        for embed in self.embeds.values():
            nn.init.xavier_uniform_(embed.weight, gain=gain)

    def forward(self, g, feats):
        """
        :param g: DGLGraph, heterogeneous graph (the same graph is passed to every layer)
        :param feats: Dict[str, tensor(N_i, in_dim)], input features for (some) node
            types; the dict is not modified
        :return: tensor(N_p, d_out), output embeddings of the ``predict_ntype`` nodes
        """
        # Copy the dict so the caller's ``feats`` does not gain embedding entries.
        h = dict(feats)
        for ntype, embed in self.embeds.items():
            h[ntype] = embed.weight
        for layer in self.layers:
            h = layer(g, h)  # Dict[ntype, (N_i, d_hid)]
        return h[self.predict_ntype]