29 lines
1.3 KiB
Python
29 lines
1.3 KiB
Python
import torch
|
||
from dgl.dataloading import MultiLayerNeighborSampler
|
||
|
||
|
||
class PositiveSampler(MultiLayerNeighborSampler):
|
||
|
||
def __init__(self, fanouts, pos):
|
||
"""用于HeCo模型的邻居采样器
|
||
|
||
对于每个batch的目标顶点,将其正样本添加到目标顶点并生成block
|
||
|
||
:param fanouts: 每层的邻居采样数(见MultiLayerNeighborSampler)
|
||
:param pos: tensor(N, T_pos) 每个顶点的正样本id,N是目标顶点数
|
||
"""
|
||
super().__init__(fanouts)
|
||
self.pos = pos
|
||
|
||
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
|
||
# 如果g是异构图则seed_nodes是字典,应当只有目标顶点类型
|
||
if not g.is_homogeneous:
|
||
assert len(seed_nodes) == 1, 'PositiveSampler: 异构图只能指定目标顶点这一种类型'
|
||
ntype, seed_nodes = next(iter(seed_nodes.items()))
|
||
pos_samples = self.pos[seed_nodes].flatten() # (B, T_pos) -> (B*T_pos,)
|
||
added = list(set(pos_samples.tolist()) - set(seed_nodes.tolist()))
|
||
seed_nodes = torch.cat([seed_nodes, torch.tensor(added, device=seed_nodes.device)])
|
||
if not g.is_homogeneous:
|
||
seed_nodes = {ntype: seed_nodes}
|
||
return super().sample_blocks(g, seed_nodes, exclude_eids)
|