"""RHCO: Relation-aware Heterogeneous graph neural network with COntrastive learning."""
import torch
import torch.nn as nn
from dgl.dataloading import MultiLayerFullNeighborSampler, NodeDataLoader

from ..heco.model import PositiveGraphEncoder, Contrast
from ..rhgnn.model import RHGNN


class RHCO(nn.Module):
|
||
|
||
def __init__(
|
||
self, in_dims, hidden_dim, out_dim, rel_hidden_dim, num_heads,
|
||
ntypes, etypes, predict_ntype, num_layers, dropout, num_pos_graphs, tau, lambda_):
|
||
"""基于对比学习的关系感知异构图神经网络RHCO
|
||
|
||
:param in_dims: Dict[str, int] 顶点类型到输入特征维数的映射
|
||
:param hidden_dim: int 隐含特征维数
|
||
:param out_dim: int 输出特征维数
|
||
:param rel_hidden_dim: int 关系隐含特征维数
|
||
:param num_heads: int 注意力头数K
|
||
:param ntypes: List[str] 顶点类型列表
|
||
:param etypes – List[(str, str, str)] 规范边类型列表
|
||
:param predict_ntype: str 目标顶点类型
|
||
:param num_layers: int 网络结构编码器层数
|
||
:param dropout: float 输入特征dropout
|
||
:param num_pos_graphs: int 正样本图个数M
|
||
:param tau: float 温度参数τ
|
||
:param lambda_: float 0~1之间,网络结构视图损失的系数λ(元路径视图损失的系数为1-λ)
|
||
"""
|
||
super().__init__()
|
||
self.hidden_dim = hidden_dim
|
||
self.predict_ntype = predict_ntype
|
||
self.sc_encoder = RHGNN(
|
||
in_dims, hidden_dim, hidden_dim, rel_hidden_dim, rel_hidden_dim, num_heads,
|
||
ntypes, etypes, predict_ntype, num_layers, dropout
|
||
)
|
||
self.pg_encoder = PositiveGraphEncoder(
|
||
num_pos_graphs, in_dims[predict_ntype], hidden_dim, dropout
|
||
)
|
||
self.contrast = Contrast(hidden_dim, tau, lambda_)
|
||
self.predict = nn.Linear(hidden_dim, out_dim)
|
||
self.reset_parameters()
|
||
|
||
def reset_parameters(self):
|
||
gain = nn.init.calculate_gain('relu')
|
||
nn.init.xavier_normal_(self.predict.weight, gain)
|
||
|
||
def forward(self, blocks, feats, mgs, mg_feats, pos):
|
||
"""
|
||
:param blocks: List[DGLBlock]
|
||
:param feats: Dict[str, tensor(N_i, d_in)] 顶点类型到输入特征的映射
|
||
:param mgs: List[DGLBlock] 正样本图,len(mgs)=元路径数量=目标顶点邻居类型数S≠模型层数
|
||
:param mg_feats: List[tensor(N_pos_src, d_in)] 正样本图源顶点的输入特征
|
||
:param pos: tensor(B, N) 布尔张量,每个顶点的正样本
|
||
(B是batch大小,真正的目标顶点;N是B个目标顶点加上其正样本后的顶点数)
|
||
:return: float, tensor(B, d_out) 对比损失,目标顶点输出特征
|
||
"""
|
||
z_sc = self.sc_encoder(blocks, feats) # (N, d_hid)
|
||
z_pg = self.pg_encoder(mgs, mg_feats) # (N, d_hid)
|
||
loss = self.contrast(z_sc, z_pg, pos)
|
||
return loss, self.predict(z_sc[:pos.shape[0]])
|
||
|
||
@torch.no_grad()
|
||
def get_embeds(self, g, batch_size, device):
|
||
"""计算目标顶点的最终嵌入(z_sc)
|
||
|
||
:param g: DGLGraph 异构图
|
||
:param batch_size: int 批大小
|
||
:param device torch.device GPU设备
|
||
:return: tensor(N_tgt, d_out) 目标顶点的最终嵌入
|
||
"""
|
||
sampler = MultiLayerFullNeighborSampler(len(self.sc_encoder.layers))
|
||
loader = NodeDataLoader(
|
||
g, {self.predict_ntype: g.nodes(self.predict_ntype)}, sampler,
|
||
device=device, batch_size=batch_size
|
||
)
|
||
embeds = torch.zeros(g.num_nodes(self.predict_ntype), self.hidden_dim, device=device)
|
||
for input_nodes, output_nodes, blocks in loader:
|
||
z_sc = self.sc_encoder(blocks, blocks[0].srcdata['feat'])
|
||
embeds[output_nodes[self.predict_ntype]] = z_sc
|
||
return self.predict(embeds)


class RHCOFull(RHCO):
|
||
"""Full-batch RHCO"""
|
||
|
||
def forward(self, g, feats, mgs, mg_feat, pos):
|
||
return super().forward(
|
||
[g] * len(self.sc_encoder.layers), feats, mgs, [mg_feat] * len(mgs), pos
|
||
)
|
||
|
||
@torch.no_grad()
|
||
def get_embeds(self, g, *args):
|
||
return self.predict(self.sc_encoder([g] * len(self.sc_encoder.layers), g.ndata['feat']))


class RHCOsc(RHCO):
|
||
"""RHCO消融实验变体:仅使用网络结构编码器"""
|
||
|
||
def forward(self, blocks, feats, mgs, mg_feats, pos):
|
||
z_sc = self.sc_encoder(blocks, feats) # (N, d_hid)
|
||
loss = self.contrast(z_sc, z_sc, pos)
|
||
return loss, self.predict(z_sc[:pos.shape[0]])


class RHCOpg(RHCO):
|
||
"""RHCO消融实验变体:仅使用正样本图编码器"""
|
||
|
||
def forward(self, blocks, feats, mgs, mg_feats, pos):
|
||
z_pg = self.pg_encoder(mgs, mg_feats) # (N, d_hid)
|
||
loss = self.contrast(z_pg, z_pg, pos)
|
||
return loss, self.predict(z_pg[:pos.shape[0]])
|
||
|
||
def get_embeds(self, mgs, feat, batch_size, device):
|
||
sampler = MultiLayerFullNeighborSampler(1)
|
||
mg_loaders = [
|
||
NodeDataLoader(mg, mg.nodes(self.predict_ntype), sampler, device=device, batch_size=batch_size)
|
||
for mg in mgs
|
||
]
|
||
embeds = torch.zeros(mgs[0].num_nodes(self.predict_ntype), self.hidden_dim, device=device)
|
||
for mg_blocks in zip(*mg_loaders):
|
||
output_nodes = mg_blocks[0][1]
|
||
mg_feats = [feat[i] for i, _, _ in mg_blocks]
|
||
mg_blocks = [b[0] for _, _, b in mg_blocks]
|
||
embeds[output_nodes] = self.pg_encoder(mg_blocks, mg_feats)
|
||
return self.predict(embeds)