154 lines
6.2 KiB
Python
154 lines
6.2 KiB
Python
import os
|
||
|
||
import dgl
|
||
import pandas as pd
|
||
import torch
|
||
from dgl.data import DGLDataset, extract_archive
|
||
from dgl.data.utils import save_graphs, load_graphs
|
||
|
||
from gnnrec.kgrec.utils import iter_json
|
||
|
||
|
||
class OAGCSDataset(DGLDataset):
    """Computer-science subset of the OAG MAG dataset (https://www.aminer.cn/oag-2-1).

    The dataset contains exactly one heterogeneous graph.

    Statistics
    -----
    Vertices

    * 2248205 author
    * 1852225 paper
    * 11177 venue
    * 13747 institution
    * 120992 field

    Edges

    * 6349317 author-writes->paper
    * 1852225 paper-published_at->venue
    * 17250107 paper-has_field->field
    * 9194781 paper-cites->paper
    * 1726212 author-affiliated_with->institution

    paper vertex features
    -----
    * feat: tensor(N_paper, 128) pre-trained word vectors of title and abstract
    * year: tensor(N_paper) publication year (2010~2021)
    * citation: tensor(N_paper) citation count
    * no labels

    field vertex features
    -----
    * feat: tensor(N_field, 128) pre-trained field vectors

    writes edge features
    -----
    * order: tensor(N_writes) author position on the paper (starting from 1)
    """

    def __init__(self, **kwargs):
        super().__init__('oag-cs', 'https://pan.baidu.com/s/1ayH3tQxsiDDnqPoXhR0Ekg', **kwargs)

    def _graph_path(self):
        """Return the path of the cached DGL graph file (shared by save/load/has_cache)."""
        return os.path.join(self.save_path, self.name + '_dgl_graph.bin')

    def download(self):
        """Extract the manually downloaded archive ``raw_dir/oag-cs.zip`` into ``raw_path``.

        Raises:
            FileNotFoundError: if the archive is not present. The file must be
                downloaded by hand from ``self.url`` (extraction code in the message).
        """
        zip_file_path = os.path.join(self.raw_dir, 'oag-cs.zip')
        if not os.path.exists(zip_file_path):
            raise FileNotFoundError('请手动下载文件 {} 提取码:2ylp 并保存为 {}'.format(
                self.url, zip_file_path
            ))
        extract_archive(zip_file_path, self.raw_path)

    def save(self):
        """Serialize the processed graph to the cache file."""
        save_graphs(self._graph_path(), [self.g])

    def load(self):
        """Load the processed graph from the cache file."""
        self.g = load_graphs(self._graph_path())[0][0]

    def process(self):
        """Read the raw JSON-lines files and build the heterogeneous graph ``self.g``."""
        # Order matters: the authors reader needs _oid_map, the papers reader
        # needs _vid_map, _fid_map and _aid_map.
        self._vid_map = self._read_venues()  # {raw id: vertex id}
        self._oid_map = self._read_institutions()  # {raw id: vertex id}
        self._fid_map = self._read_fields()  # {field name: vertex id}
        self._aid_map, author_inst = self._read_authors()  # {raw id: vertex id}, R(aid, oid)
        # PA(pid, aid), PV(pid, vid), PF(pid, fid), PP(pid, rid), [years], [citation counts]
        paper_author, paper_venue, paper_field, paper_ref, paper_year, paper_citation = self._read_papers()
        self.g = self._build_graph(paper_author, paper_venue, paper_field, paper_ref, author_inst, paper_year, paper_citation)

    def _iter_json(self, filename):
        """Yield the records of ``raw_path/filename`` via ``gnnrec.kgrec.utils.iter_json``."""
        yield from iter_json(os.path.join(self.raw_path, filename))

    def _read_venues(self):
        """Map each venue's raw id to its vertex id (its line index in the file)."""
        print('正在读取期刊数据...')
        # line number = index = vertex id
        return {v['id']: i for i, v in enumerate(self._iter_json('mag_venues.txt'))}

    def _read_institutions(self):
        """Map each institution's raw id to its vertex id (its line index in the file)."""
        print('正在读取机构数据...')
        return {o['id']: i for i, o in enumerate(self._iter_json('mag_institutions.txt'))}

    def _read_fields(self):
        """Map each field's name to its vertex id (the record's own ``id`` value)."""
        print('正在读取领域数据...')
        return {f['name']: f['id'] for f in self._iter_json('mag_fields.txt')}

    def _read_authors(self):
        """Read authors and their institution affiliations.

        Returns:
            (dict, DataFrame): ``{raw author id: vertex id}`` (vertex id = line
            index), and a frame with columns ``aid``/``oid`` holding one row for
            every author whose ``org`` is not None.
        """
        print('正在读取学者数据...')
        author_id_map, author_inst = {}, []
        for i, a in enumerate(self._iter_json('mag_authors.txt')):
            author_id_map[a['id']] = i
            if a['org'] is not None:
                author_inst.append([i, self._oid_map[a['org']]])
        return author_id_map, pd.DataFrame(author_inst, columns=['aid', 'oid'])

    def _read_papers(self):
        """Read papers and their relations.

        The file is scanned twice. References may point to papers that appear
        later in the file, so the full raw-id -> vertex-id map must exist before
        citation edges can be resolved. References to papers outside the
        dataset are dropped.

        Returns:
            Tuple of (paper-author frame with ``pid``/``aid``/``order``,
            deduplicated on (pid, aid); paper-venue frame; paper-field frame;
            paper-reference frame; list of years; list of citation counts).
        """
        print('正在读取论文数据...')
        paper_id_map, paper_author, paper_venue, paper_field = {}, [], [], []
        paper_year, paper_citation = [], []
        for i, p in enumerate(self._iter_json('mag_papers.txt')):
            paper_id_map[p['id']] = i
            # author order is 1-based
            paper_author.extend([i, self._aid_map[a], r + 1] for r, a in enumerate(p['authors']))
            paper_venue.append([i, self._vid_map[p['venue']]])
            # unknown field names are skipped
            paper_field.extend([i, self._fid_map[f]] for f in p['fos'] if f in self._fid_map)
            paper_year.append(p['year'])
            paper_citation.append(p['n_citation'])

        paper_ref = []
        for i, p in enumerate(self._iter_json('mag_papers.txt')):
            paper_ref.extend([i, paper_id_map[r]] for r in p['references'] if r in paper_id_map)
        return (
            pd.DataFrame(paper_author, columns=['pid', 'aid', 'order']).drop_duplicates(subset=['pid', 'aid']),
            pd.DataFrame(paper_venue, columns=['pid', 'vid']),
            pd.DataFrame(paper_field, columns=['pid', 'fid']),
            pd.DataFrame(paper_ref, columns=['pid', 'rid']),
            paper_year, paper_citation
        )

    @staticmethod
    def _num_nodes(id_map):
        """Smallest vertex count for which every vertex id in ``id_map``'s values is valid."""
        return max(id_map.values(), default=-1) + 1

    def _build_graph(self, paper_author, paper_venue, paper_field, paper_ref, author_inst, paper_year, paper_citation):
        """Assemble the DGL heterograph and attach vertex/edge features.

        Node counts are passed explicitly. Otherwise DGL infers each type's
        count from the largest id seen in its edges, which silently drops
        trailing vertices that have no edges. That also breaks the
        ``field_feat.pkl`` assignment, since the file has one row per field.
        """
        print('正在构造异构图...')
        pa_p, pa_a = paper_author['pid'].to_list(), paper_author['aid'].to_list()
        pv_p, pv_v = paper_venue['pid'].to_list(), paper_venue['vid'].to_list()
        pf_p, pf_f = paper_field['pid'].to_list(), paper_field['fid'].to_list()
        pp_p, pp_r = paper_ref['pid'].to_list(), paper_ref['rid'].to_list()
        ai_a, ai_i = author_inst['aid'].to_list(), author_inst['oid'].to_list()
        num_nodes_dict = {
            'author': self._num_nodes(self._aid_map),
            'paper': len(paper_year),  # one entry per line of mag_papers.txt
            'venue': self._num_nodes(self._vid_map),
            'institution': self._num_nodes(self._oid_map),
            'field': self._num_nodes(self._fid_map),
        }
        g = dgl.heterograph({
            ('author', 'writes', 'paper'): (pa_a, pa_p),
            ('paper', 'published_at', 'venue'): (pv_p, pv_v),
            ('paper', 'has_field', 'field'): (pf_p, pf_f),
            ('paper', 'cites', 'paper'): (pp_p, pp_r),
            ('author', 'affiliated_with', 'institution'): (ai_a, ai_i)
        }, num_nodes_dict=num_nodes_dict)
        # NOTE: torch.load unpickles its input; only use trusted files here.
        g.nodes['paper'].data['feat'] = torch.load(os.path.join(self.raw_path, 'paper_feat.pkl'))
        g.nodes['paper'].data['year'] = torch.tensor(paper_year)
        g.nodes['paper'].data['citation'] = torch.tensor(paper_citation)
        g.nodes['field'].data['feat'] = torch.load(os.path.join(self.raw_path, 'field_feat.pkl'))
        # 'writes' edges were created in paper_author row order, so 'order' lines up
        g.edges['writes'].data['order'] = torch.tensor(paper_author['order'].to_list())
        return g

    def has_cache(self):
        """Return True if the processed graph file already exists."""
        return os.path.exists(self._graph_path())

    def __getitem__(self, idx):
        """Return the single graph; any index other than 0 raises IndexError."""
        if idx != 0:
            raise IndexError('This dataset has only one graph')
        return self.g

    def __len__(self):
        return 1