GNNRecom/gnnrec/kgrec/data/oagcs.py

154 lines
6.2 KiB
Python
Raw Normal View History

2021-11-16 07:04:52 +00:00
import os
import dgl
import pandas as pd
import torch
from dgl.data import DGLDataset, extract_archive
from dgl.data.utils import save_graphs, load_graphs
from gnnrec.kgrec.utils import iter_json
class OAGCSDataset(DGLDataset):
    """Computer-science subset of the OAG MAG dataset (https://www.aminer.cn/oag-2-1).

    The dataset consists of a single heterogeneous graph.

    Statistics
    -----
    Vertices

    * 2248205 author
    * 1852225 paper
    * 11177 venue
    * 13747 institution
    * 120992 field

    Edges

    * 6349317 author-writes->paper
    * 1852225 paper-published_at->venue
    * 17250107 paper-has_field->field
    * 9194781 paper-cites->paper
    * 1726212 author-affiliated_with->institution

    paper vertex attributes
    -----
    * feat: tensor(N_paper, 128), pretrained title and abstract word vectors
    * year: tensor(N_paper), publication year, 2010~2021
    * citation: tensor(N_paper), citation count
    * no labels are included

    field vertex attributes
    -----
    * feat: tensor(N_field, 128), pretrained field vectors

    writes edge attributes
    -----
    * order: tensor(N_writes), author order, starting from 1
    """

    def __init__(self, **kwargs):
        """Forward the fixed dataset name and download URL to ``DGLDataset``.

        :param kwargs: extra keyword arguments passed through unchanged to
            ``DGLDataset.__init__``
        """
        super().__init__('oag-cs', 'https://pan.baidu.com/s/1ayH3tQxsiDDnqPoXhR0Ekg', **kwargs)

    def _graph_file(self):
        """Return the path of the binary file the graph is cached in."""
        return os.path.join(self.save_path, self.name + '_dgl_graph.bin')

    def download(self):
        """Extract ``oag-cs.zip`` from the raw directory into the raw path.

        The archive is never fetched automatically; it has to be placed in
        the raw directory by hand.

        :raises FileNotFoundError: if the archive is not present; the message
            tells the user where to download it and where to save it
        """
        archive = os.path.join(self.raw_dir, 'oag-cs.zip')
        if os.path.exists(archive):
            extract_archive(archive, self.raw_path)
            return
        raise FileNotFoundError('请手动下载文件 {} 提取码2ylp 并保存为 {}'.format(
            self.url, archive
        ))

    def save(self):
        """Write the processed graph to the cache file."""
        save_graphs(self._graph_file(), [self.g])

    def load(self):
        """Read the graph back from the cache file."""
        graphs, _ = load_graphs(self._graph_file())
        self.g = graphs[0]

    def process(self):
        """Read every raw file and assemble the heterogeneous graph into ``self.g``.

        The readers run in a fixed order because later ones look up ids in
        the maps produced by earlier ones (authors need the institution map,
        papers need the author, venue and field maps).
        """
        # {raw id: vertex id}
        self._vid_map = self._read_venues()
        # {raw id: vertex id}
        self._oid_map = self._read_institutions()
        # {field name: vertex id}
        self._fid_map = self._read_fields()
        # {raw id: vertex id}, plus the relation R(aid, oid)
        self._aid_map, author_inst = self._read_authors()
        # PA(pid, aid), PV(pid, vid), PF(pid, fid), PP(pid, rid), [years], [citation counts]
        (paper_author, paper_venue, paper_field,
         paper_ref, paper_year, paper_citation) = self._read_papers()
        self.g = self._build_graph(
            paper_author, paper_venue, paper_field, paper_ref,
            author_inst, paper_year, paper_citation
        )

    def _iter_json(self, filename):
        """Yield the records ``iter_json`` produces for a file under the raw path.

        :param filename: file name relative to ``self.raw_path``
        """
        for record in iter_json(os.path.join(self.raw_path, filename)):
            yield record

    def _read_venues(self):
        """Return ``{raw venue id: vertex id}`` built from ``mag_venues.txt``."""
        print('正在读取期刊数据...')
        # The position of a record in the file is used as its vertex id.
        venue_ids = {}
        for index, venue in enumerate(self._iter_json('mag_venues.txt')):
            venue_ids[venue['id']] = index
        return venue_ids

    def _read_institutions(self):
        """Return ``{raw institution id: vertex id}`` built from ``mag_institutions.txt``."""
        print('正在读取机构数据...')
        inst_ids = {}
        for index, inst in enumerate(self._iter_json('mag_institutions.txt')):
            inst_ids[inst['id']] = index
        return inst_ids

    def _read_fields(self):
        """Return ``{field name: id}`` built from ``mag_fields.txt``.

        Unlike the other readers, the value is the ``id`` stored in the
        record itself rather than the record's position in the file.
        """
        print('正在读取领域数据...')
        # NOTE(review): presumably the ids in mag_fields.txt are already
        # contiguous vertex ids -- confirm against the raw data.
        field_ids = {}
        for field in self._iter_json('mag_fields.txt'):
            field_ids[field['name']] = field['id']
        return field_ids

    def _read_authors(self):
        """Read ``mag_authors.txt``.

        :return: ``({raw author id: vertex id}, DataFrame(aid, oid))`` where
            the frame holds one row per author whose ``org`` is not ``None``
        """
        print('正在读取学者数据...')
        author_ids = {}
        affiliations = []
        for index, author in enumerate(self._iter_json('mag_authors.txt')):
            author_ids[author['id']] = index
            org = author['org']
            if org is None:
                continue
            affiliations.append([index, self._oid_map[org]])
        return author_ids, pd.DataFrame(affiliations, columns=['aid', 'oid'])

    def _read_papers(self):
        """Read ``mag_papers.txt`` and collect every paper-centred relation.

        The file is traversed twice: citations can only be resolved once the
        id of every paper is known.

        :return: tuple ``(PA, PV, PF, PP, years, citations)`` where the first
            four are DataFrames with columns ``(pid, aid, order)``,
            ``(pid, vid)``, ``(pid, fid)`` and ``(pid, rid)``, and the last
            two are plain lists indexed by paper vertex id
        """
        print('正在读取论文数据...')
        paper_ids = {}
        pa_rows, pv_rows, pf_rows = [], [], []
        years, citations = [], []

        # First pass: assign vertex ids and gather author/venue/field links.
        for pid, paper in enumerate(self._iter_json('mag_papers.txt')):
            paper_ids[paper['id']] = pid
            for rank, author in enumerate(paper['authors'], start=1):
                pa_rows.append([pid, self._aid_map[author], rank])
            pv_rows.append([pid, self._vid_map[paper['venue']]])
            for field in paper['fos']:
                # Fields absent from the field map are skipped.
                if field in self._fid_map:
                    pf_rows.append([pid, self._fid_map[field]])
            years.append(paper['year'])
            citations.append(paper['n_citation'])

        # Second pass: keep only references that point to a known paper.
        pp_rows = []
        for pid, paper in enumerate(self._iter_json('mag_papers.txt')):
            for ref in paper['references']:
                if ref in paper_ids:
                    pp_rows.append([pid, paper_ids[ref]])

        # An author listed more than once on a paper keeps the first entry.
        writes = pd.DataFrame(pa_rows, columns=['pid', 'aid', 'order'])
        writes = writes.drop_duplicates(subset=['pid', 'aid'])
        published = pd.DataFrame(pv_rows, columns=['pid', 'vid'])
        has_field = pd.DataFrame(pf_rows, columns=['pid', 'fid'])
        cites = pd.DataFrame(pp_rows, columns=['pid', 'rid'])
        return writes, published, has_field, cites, years, citations

    def _build_graph(self, paper_author, paper_venue, paper_field, paper_ref, author_inst, paper_year, paper_citation):
        """Create the heterogeneous graph and attach vertex/edge attributes.

        :param paper_author: DataFrame with columns ``pid``, ``aid``, ``order``
        :param paper_venue: DataFrame with columns ``pid``, ``vid``
        :param paper_field: DataFrame with columns ``pid``, ``fid``
        :param paper_ref: DataFrame with columns ``pid``, ``rid``
        :param author_inst: DataFrame with columns ``aid``, ``oid``
        :param paper_year: per-paper values stored as the ``year`` attribute
        :param paper_citation: per-paper values stored as the ``citation`` attribute
        :return: the graph returned by ``dgl.heterograph``
        """
        print('正在构造异构图...')

        def endpoints(frame, src_col, dst_col):
            # (source ids, destination ids) as two plain Python lists
            return frame[src_col].to_list(), frame[dst_col].to_list()

        # NOTE(review): no explicit vertex counts are passed, so the number
        # of vertices per type is whatever DGL derives from these edge
        # lists -- confirm it matches the feature tensors loaded below.
        relations = {
            ('author', 'writes', 'paper'): endpoints(paper_author, 'aid', 'pid'),
            ('paper', 'published_at', 'venue'): endpoints(paper_venue, 'pid', 'vid'),
            ('paper', 'has_field', 'field'): endpoints(paper_field, 'pid', 'fid'),
            ('paper', 'cites', 'paper'): endpoints(paper_ref, 'pid', 'rid'),
            ('author', 'affiliated_with', 'institution'): endpoints(author_inst, 'aid', 'oid'),
        }
        g = dgl.heterograph(relations)

        paper_feat_file = os.path.join(self.raw_path, 'paper_feat.pkl')
        g.nodes['paper'].data['feat'] = torch.load(paper_feat_file)
        g.nodes['paper'].data['year'] = torch.tensor(paper_year)
        g.nodes['paper'].data['citation'] = torch.tensor(paper_citation)

        field_feat_file = os.path.join(self.raw_path, 'field_feat.pkl')
        g.nodes['field'].data['feat'] = torch.load(field_feat_file)

        author_order = paper_author['order'].to_list()
        g.edges['writes'].data['order'] = torch.tensor(author_order)
        return g

    def has_cache(self):
        """Return whether the cache file already exists."""
        return os.path.exists(self._graph_file())

    def __getitem__(self, idx):
        """Return the graph; index 0 is the only valid index.

        :raises IndexError: for any index other than 0
        """
        if idx != 0:
            raise IndexError('This dataset has only one graph')
        return self.g

    def __len__(self):
        """Return the number of graphs in the dataset, which is always 1."""
        return 1