GNNRecom/gnnrec/kgrec/data/oagcs.py


import os

import dgl
import pandas as pd
import torch
from dgl.data import DGLDataset, extract_archive
from dgl.data.utils import save_graphs, load_graphs

from gnnrec.kgrec.utils import iter_json

class OAGCSDataset(DGLDataset):
    """Computer science subset of the OAG MAG dataset (https://www.aminer.cn/oag-2-1); contains a single heterogeneous graph.

    Statistics
    -----
    Vertices
    * 2248205 author
    * 1852225 paper
    * 11177 venue
    * 13747 institution
    * 120992 field

    Edges
    * 6349317 author-writes->paper
    * 1852225 paper-published_at->venue
    * 17250107 paper-has_field->field
    * 9194781 paper-cites->paper
    * 1726212 author-affiliated_with->institution

    paper vertex attributes
    -----
    * feat: tensor(N_paper, 128), pretrained word embeddings of title and abstract
    * year: tensor(N_paper), publication year, 2010~2021
    * citation: tensor(N_paper), number of citations
    * no labels included

    field vertex attributes
    -----
    * feat: tensor(N_field, 128), pretrained field embeddings

    writes edge attributes
    -----
    * order: tensor(N_writes), author order, starting from 1

    See the usage sketch at the end of this file for how to load the dataset
    and access these attributes.
    """

    def __init__(self, **kwargs):
        super().__init__('oag-cs', 'https://pan.baidu.com/s/1ayH3tQxsiDDnqPoXhR0Ekg', **kwargs)

    def download(self):
        zip_file_path = os.path.join(self.raw_dir, 'oag-cs.zip')
        if not os.path.exists(zip_file_path):
            raise FileNotFoundError('Please manually download {} (extraction code: 2ylp) and save it as {}'.format(
                self.url, zip_file_path
            ))
        extract_archive(zip_file_path, self.raw_path)

    def save(self):
        save_graphs(os.path.join(self.save_path, self.name + '_dgl_graph.bin'), [self.g])

    def load(self):
        self.g = load_graphs(os.path.join(self.save_path, self.name + '_dgl_graph.bin'))[0][0]

    def process(self):
        self._vid_map = self._read_venues()  # {original id: vertex id}
        self._oid_map = self._read_institutions()  # {original id: vertex id}
        self._fid_map = self._read_fields()  # {field name: vertex id}
        self._aid_map, author_inst = self._read_authors()  # {original id: vertex id}, R(aid, oid)
        # PA(pid, aid), PV(pid, vid), PF(pid, fid), PP(pid, rid), [year], [citation count]
        paper_author, paper_venue, paper_field, paper_ref, paper_year, paper_citation = self._read_papers()
        self.g = self._build_graph(
            paper_author, paper_venue, paper_field, paper_ref, author_inst, paper_year, paper_citation
        )

    def _iter_json(self, filename):
        yield from iter_json(os.path.join(self.raw_path, filename))
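
    # The raw files are expected to contain one JSON record per line. The field
    # names below are inferred from the reader methods that follow, not from an
    # official schema, so treat them as assumptions:
    #   mag_venues.txt       {'id': ...}
    #   mag_institutions.txt {'id': ...}
    #   mag_fields.txt       {'name': ..., 'id': ...}
    #   mag_authors.txt      {'id': ..., 'org': ...}
    #   mag_papers.txt       {'id': ..., 'authors': [...], 'venue': ..., 'fos': [...],
    #                         'year': ..., 'n_citation': ..., 'references': [...]}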

    def _read_venues(self):
        print('Reading venue data...')
        # line number = index = vertex id
        return {v['id']: i for i, v in enumerate(self._iter_json('mag_venues.txt'))}

    def _read_institutions(self):
        print('Reading institution data...')
        return {o['id']: i for i, o in enumerate(self._iter_json('mag_institutions.txt'))}

    def _read_fields(self):
        print('Reading field data...')
        # unlike the other files, mag_fields.txt appears to store the vertex id directly in 'id'
        return {f['name']: f['id'] for f in self._iter_json('mag_fields.txt')}

    def _read_authors(self):
        print('Reading author data...')
        author_id_map, author_inst = {}, []
        for i, a in enumerate(self._iter_json('mag_authors.txt')):
            author_id_map[a['id']] = i
            if a['org'] is not None:
                author_inst.append([i, self._oid_map[a['org']]])
        return author_id_map, pd.DataFrame(author_inst, columns=['aid', 'oid'])

    def _read_papers(self):
        print('Reading paper data...')
        paper_id_map, paper_author, paper_venue, paper_field = {}, [], [], []
        paper_year, paper_citation = [], []
        for i, p in enumerate(self._iter_json('mag_papers.txt')):
            paper_id_map[p['id']] = i
            paper_author.extend([i, self._aid_map[a], r + 1] for r, a in enumerate(p['authors']))
            paper_venue.append([i, self._vid_map[p['venue']]])
            paper_field.extend([i, self._fid_map[f]] for f in p['fos'] if f in self._fid_map)
            paper_year.append(p['year'])
            paper_citation.append(p['n_citation'])

        # Second pass: citations can only be resolved after all paper ids are known,
        # since a paper may reference one that appears later in the file.
        paper_ref = []
        for i, p in enumerate(self._iter_json('mag_papers.txt')):
            paper_ref.extend([i, paper_id_map[r]] for r in p['references'] if r in paper_id_map)
        return (
            # a paper may list the same author more than once; keep the first
            # occurrence (smallest order), which is drop_duplicates' default
            pd.DataFrame(paper_author, columns=['pid', 'aid', 'order']).drop_duplicates(subset=['pid', 'aid']),
            pd.DataFrame(paper_venue, columns=['pid', 'vid']),
            pd.DataFrame(paper_field, columns=['pid', 'fid']),
            pd.DataFrame(paper_ref, columns=['pid', 'rid']),
            paper_year, paper_citation
        )

    def _build_graph(self, paper_author, paper_venue, paper_field, paper_ref, author_inst, paper_year, paper_citation):
        print('Building heterogeneous graph...')
        pa_p, pa_a = paper_author['pid'].to_list(), paper_author['aid'].to_list()
        pv_p, pv_v = paper_venue['pid'].to_list(), paper_venue['vid'].to_list()
        pf_p, pf_f = paper_field['pid'].to_list(), paper_field['fid'].to_list()
        pp_p, pp_r = paper_ref['pid'].to_list(), paper_ref['rid'].to_list()
        ai_a, ai_i = author_inst['aid'].to_list(), author_inst['oid'].to_list()
        g = dgl.heterograph({
            ('author', 'writes', 'paper'): (pa_a, pa_p),
            ('paper', 'published_at', 'venue'): (pv_p, pv_v),
            ('paper', 'has_field', 'field'): (pf_p, pf_f),
            ('paper', 'cites', 'paper'): (pp_p, pp_r),
            ('author', 'affiliated_with', 'institution'): (ai_a, ai_i)
        })
        g.nodes['paper'].data['feat'] = torch.load(os.path.join(self.raw_path, 'paper_feat.pkl'))
        g.nodes['paper'].data['year'] = torch.tensor(paper_year)
        g.nodes['paper'].data['citation'] = torch.tensor(paper_citation)
        g.nodes['field'].data['feat'] = torch.load(os.path.join(self.raw_path, 'field_feat.pkl'))
        g.edges['writes'].data['order'] = torch.tensor(paper_author['order'].to_list())
        return g

    def has_cache(self):
        return os.path.exists(os.path.join(self.save_path, self.name + '_dgl_graph.bin'))

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError('This dataset has only one graph')
        return self.g

    def __len__(self):
        return 1
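

# Minimal usage sketch (not part of the original module); assumes oag-cs.zip has
# been downloaded manually from the Baidu pan link above (extraction code 2ylp)
# and placed in the dataset's raw directory, as required by download().
if __name__ == '__main__':
    dataset = OAGCSDataset()
    g = dataset[0]
    print(g)  # prints node/edge types and counts, matching the class docstring
    print(g.nodes['paper'].data['feat'].shape)  # torch.Size([N_paper, 128])
    print(g.edges['writes'].data['order'][:10])  # author order starts at 1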