31 lines
990 B
Python
31 lines
990 B
Python
|
from torch.utils.data import Dataset
|
|||
|
|
|||
|
from gnnrec.kgrec.utils import iter_json
|
|||
|
|
|||
|
|
|||
|
class OAGCSContrastDataset(Dataset):
    """OAG-CS paper title / keyword contrastive-learning dataset.

    The raw data contains no keywords, so each paper's fields of study
    (its ``fos`` field) are joined with ``'; '`` and used as the keywords.
    Papers are split by publication year: records with
    ``year <= SPLIT_YEAR`` form the training split and records with
    ``year > SPLIT_YEAR`` form the validation split.
    """

    # Last publication year included in the training split.
    SPLIT_YEAR = 2016

    def __init__(self, raw_file, split='train'):
        """Load (title, keywords) pairs for the requested split.

        :param raw_file: str raw paper data file, read record by record via ``iter_json``
        :param split: str, one of "train", "valid", "all"
        :raises ValueError: if ``split`` is not one of the supported values
        """
        # Choose the record filter once instead of re-checking ``split`` per record.
        # Only the train/valid filters read ``year``, so "all" works even for
        # records that lack that field.
        if split == 'train':
            def keep(paper):
                return paper['year'] <= self.SPLIT_YEAR
        elif split == 'valid':
            def keep(paper):
                return paper['year'] > self.SPLIT_YEAR
        elif split == 'all':
            def keep(paper):
                return True
        else:
            # Fail fast: an unknown split would otherwise scan the whole file
            # and silently yield an empty dataset.
            raise ValueError(
                f"split must be one of 'train', 'valid', 'all'; got {split!r}"
            )

        self.titles = []
        self.keywords = []
        for p in iter_json(raw_file):
            if keep(p):
                self.titles.append(p['title'])
                self.keywords.append('; '.join(p['fos']))

    def __getitem__(self, item):
        """Return the ``(title, keywords)`` pair at index ``item``."""
        return self.titles[item], self.keywords[item]

    def __len__(self):
        """Return the number of papers in this split."""
        return len(self.titles)
|