GNNRecom/gnnrec/kgrec/data/contrast.py

31 lines
990 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torch.utils.data import Dataset
from gnnrec.kgrec.utils import iter_json
class OAGCSContrastDataset(Dataset):
    """oag-cs paper title / keyword contrastive-learning dataset.

    The raw data contains no keywords, so each paper's research-field
    (``fos``) list is used as its keywords, joined into one string with
    ``'; '``. Item ``i`` is the pair ``(title, keywords)``.
    """

    # Papers with year <= SPLIT_YEAR go to "train", later ones to "valid".
    SPLIT_YEAR = 2016

    def __init__(self, raw_file, split='train'):
        """Load the papers of one split from the raw data file.

        :param raw_file: str raw paper data file, read record by record via
            ``iter_json``; each record must provide ``'title'``, ``'fos'``
            (iterable of str) and, unless ``split='all'``, ``'year'``
        :param split: str "train" (year <= SPLIT_YEAR), "valid"
            (year > SPLIT_YEAR) or "all" (every paper)
        :raises ValueError: if ``split`` is not one of the values above
        """
        # Resolve (and validate) the split before touching the file, so a typo
        # fails fast instead of silently producing an empty dataset.
        keep = self._split_filter(split)
        self.titles = []
        self.keywords = []
        for p in iter_json(raw_file):
            if keep(p):
                self.titles.append(p['title'])
                self.keywords.append('; '.join(p['fos']))

    @classmethod
    def _split_filter(cls, split):
        """Return a predicate on a paper record that is true iff it belongs to ``split``.

        For ``'all'`` the predicate never reads ``'year'``, so records without
        a year are accepted there, as in a plain "keep everything" filter.
        """
        if split == 'train':
            return lambda p: p['year'] <= cls.SPLIT_YEAR
        if split == 'valid':
            return lambda p: p['year'] > cls.SPLIT_YEAR
        if split == 'all':
            return lambda p: True
        raise ValueError(f"split must be one of 'train', 'valid', 'all', got {split!r}")

    def __getitem__(self, item):
        """Return the ``(title, keywords)`` pair at index ``item``."""
        return self.titles[item], self.keywords[item]

    def __len__(self):
        """Return the number of papers in this split."""
        return len(self.titles)