GNNRecom/rank/management/commands/loadoagcs.py

96 lines
3.7 KiB
Python
Raw Permalink Normal View History

2021-11-16 07:04:52 +00:00
import dgl
import dgl.function as fn
from django.core.management import BaseCommand
from tqdm import trange
from gnnrec.config import DATA_DIR
from gnnrec.kgrec.data import OAGCSDataset
from gnnrec.kgrec.utils import iter_json
from rank.models import Venue, Institution, Field, Author, Paper, Writes
class Command(BaseCommand):
    """Management command that loads the oag-cs dataset into the database.

    Venues, institutions, fields, authors and papers are read from the raw
    JSON files under ``DATA_DIR / 'oag/cs'``.  The ``writes``, ``has_field``
    and ``cites`` relations are taken from the edges of the first graph of
    ``OAGCSDataset``.  Every node row is stored with its zero-based position
    in the corresponding raw file as primary key.
    """

    help = '将oag-cs数据集导入数据库'

    def add_arguments(self, parser):
        """Register the ``--batch-size`` option (int, default 2000)."""
        parser.add_argument('--batch-size', type=int, default=2000, help='批大小')

    def handle(self, *args, **options):
        """Import the node tables first, then the relation tables.

        Progress messages go through ``self.stdout`` so that callers using
        ``call_command(..., stdout=...)`` can capture them.
        """
        batch_size = options['batch_size']
        raw_path = DATA_DIR / 'oag/cs'

        self.stdout.write('正在导入期刊数据...')
        vid_map = self._import_named(Venue, raw_path / 'mag_venues.txt', batch_size)

        self.stdout.write('正在导入机构数据...')
        oid_map = self._import_named(
            Institution, raw_path / 'mag_institutions.txt', batch_size
        )

        # Fields need no raw-id map, so only their names are read here.
        self.stdout.write('正在导入领域数据...')
        Field.objects.bulk_create([
            Field(id=idx, name=record['name'])
            for idx, record in enumerate(iter_json(raw_path / 'mag_fields.txt'))
        ], batch_size=batch_size)

        data = OAGCSDataset()
        g = data[0]
        author_citation = self._author_citations(g)

        self.stdout.write('正在导入学者数据...')
        Author.objects.bulk_create([
            Author(
                id=idx,
                name=record['name'],
                n_citation=author_citation[idx],
                # Authors whose 'org' is None get no institution.
                institution_id=(
                    oid_map[record['org']] if record['org'] is not None else None
                ),
            )
            for idx, record in enumerate(iter_json(raw_path / 'mag_authors.txt'))
        ], batch_size=batch_size)

        self.stdout.write('正在导入论文数据...')
        Paper.objects.bulk_create([
            Paper(
                id=idx,
                title=record['title'],
                venue_id=vid_map[record['venue']],
                year=record['year'],
                abstract=record['abstract'],
                n_citation=record['n_citation'],
            )
            for idx, record in enumerate(iter_json(raw_path / 'mag_papers.txt'))
        ], batch_size=batch_size)

        self.stdout.write('正在导入论文关联数据(很慢)...')

        self.stdout.write('writes')
        src, dst = g.edges(etype='writes')
        order = g.edges['writes'].data['order']
        self._import_edges(
            Writes,
            ('author_id', 'paper_id', 'order'),
            list(zip(src.tolist(), dst.tolist(), order.tolist())),
            batch_size,
        )

        self.stdout.write('has_field')
        src, dst = g.edges(etype='has_field')
        self._import_edges(
            Paper.fos.through,
            ('paper_id', 'field_id'),
            list(zip(src.tolist(), dst.tolist())),
            batch_size,
        )

        self.stdout.write('cites')
        src, dst = g.edges(etype='cites')
        self._import_edges(
            Paper.references.through,
            ('from_paper_id', 'to_paper_id'),
            list(zip(src.tolist(), dst.tolist())),
            batch_size,
        )

        self.stdout.write('导入完成')

    @staticmethod
    def _import_named(model, path, batch_size):
        """Bulk-create ``model(id, name)`` rows from the JSON records at ``path``.

        The file is read once: each record contributes one model row keyed by
        its zero-based position, and one entry mapping the record's raw
        ``'id'`` to that position.

        Returns:
            dict mapping raw record id -> database primary key.
        """
        id_map = {}
        rows = []
        for idx, record in enumerate(iter_json(path)):
            id_map[record['id']] = idx
            rows.append(model(id=idx, name=record['name']))
        model.objects.bulk_create(rows, batch_size=batch_size)
        return id_map

    @staticmethod
    def _author_citations(g):
        """Sum the paper ``'citation'`` feature onto authors along ``writes`` edges.

        The (author, writes, paper) relation is reversed so that messages
        flow from paper nodes to author nodes, then each paper's citation
        value is copied along its edges and summed at the author.

        Returns:
            list of ints, indexed by author node id.
        """
        apg = dgl.reverse(g['author', 'writes', 'paper'], copy_ndata=False)
        apg.nodes['paper'].data['c'] = g.nodes['paper'].data['citation'].float()
        apg.update_all(fn.copy_u('c', 'm'), fn.sum('m', 'c'))
        return apg.nodes['author'].data['c'].int().tolist()

    @staticmethod
    def _import_edges(model, field_names, edges, batch_size):
        """Bulk-create one ``model`` row per edge tuple, ``batch_size`` at a time.

        Args:
            model: model class to instantiate (a relation / through table).
            field_names: keyword names matching the positions of each tuple.
            edges: list of tuples, one per row to insert.
            batch_size: number of rows sent per ``bulk_create`` call.
        """
        for start in trange(0, len(edges), batch_size):
            model.objects.bulk_create([
                model(**dict(zip(field_names, edge)))
                for edge in edges[start:start + batch_size]
            ])