import dgl
import dgl.function as fn
from django.core.management import BaseCommand
from tqdm import trange

from gnnrec.config import DATA_DIR
from gnnrec.kgrec.data import OAGCSDataset
from gnnrec.kgrec.utils import iter_json
from rank.models import Venue, Institution, Field, Author, Paper, Writes


class Command(BaseCommand):
    help = 'Import the oag-cs dataset into the database'

    def add_arguments(self, parser):
        parser.add_argument('--batch-size', type=int, default=2000, help='batch size')

    def handle(self, *args, **options):
        batch_size = options['batch_size']
        raw_path = DATA_DIR / 'oag/cs'

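        # Venues get consecutive integer primary keys (their position in mag_venues.txt);
        # vid_map translates the original venue ids in the raw data to these database ids.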
        print('Importing venue data...')
        Venue.objects.bulk_create([
            Venue(id=i, name=v['name'])
            for i, v in enumerate(iter_json(raw_path / 'mag_venues.txt'))
        ], batch_size=batch_size)
        vid_map = {v['id']: i for i, v in enumerate(iter_json(raw_path / 'mag_venues.txt'))}

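        # Institutions follow the same pattern; oid_map is reused below to resolve author affiliations.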
        print('Importing institution data...')
        Institution.objects.bulk_create([
            Institution(id=i, name=o['name'])
            for i, o in enumerate(iter_json(raw_path / 'mag_institutions.txt'))
        ], batch_size=batch_size)
        oid_map = {o['id']: i for i, o in enumerate(iter_json(raw_path / 'mag_institutions.txt'))}

        print('Importing field data...')
        Field.objects.bulk_create([
            Field(id=i, name=f['name'])
            for i, f in enumerate(iter_json(raw_path / 'mag_fields.txt'))
        ], batch_size=batch_size)

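        # Author citation counts are not stored in the raw files, so they are derived from the graph:
        # reverse the author->paper 'writes' edges and sum each author's paper citation counts
        # with one round of DGL message passing (copy_u + sum).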
        data = OAGCSDataset()
        g = data[0]
        apg = dgl.reverse(g['author', 'writes', 'paper'], copy_ndata=False)
        apg.nodes['paper'].data['c'] = g.nodes['paper'].data['citation'].float()
        apg.update_all(fn.copy_u('c', 'm'), fn.sum('m', 'c'))
        author_citation = apg.nodes['author'].data['c'].int().tolist()

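        # Authors link to institutions through oid_map; authors without an 'org' value get a null institution.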
        print('Importing author data...')
        Author.objects.bulk_create([
            Author(
                id=i, name=a['name'], n_citation=author_citation[i],
                institution_id=oid_map[a['org']] if a['org'] is not None else None
            ) for i, a in enumerate(iter_json(raw_path / 'mag_authors.txt'))
        ], batch_size=batch_size)

        print('Importing paper data...')
        Paper.objects.bulk_create([
            Paper(
                id=i, title=p['title'], venue_id=vid_map[p['venue']], year=p['year'],
                abstract=p['abstract'], n_citation=p['n_citation']
            ) for i, p in enumerate(iter_json(raw_path / 'mag_papers.txt'))
        ], batch_size=batch_size)

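        # Relation tables are filled from the graph's edge lists. Inserts are chunked by hand so that
        # each bulk_create stays at batch_size rows and trange can report progress.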
        print('Importing paper relation data (slow)...')
        print('writes')
        u, v = g.edges(etype='writes')
        order = g.edges['writes'].data['order']
        edges = list(zip(u.tolist(), v.tolist(), order.tolist()))
        for i in trange(0, len(edges), batch_size):
            Writes.objects.bulk_create([
                Writes(author_id=a, paper_id=p, order=r)
                for a, p, r in edges[i:i + batch_size]
            ])

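        # has_field and cites are plain many-to-many relations, so rows are written directly
        # to the auto-generated through tables (Paper.fos.through and Paper.references.through).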
        print('has_field')
        u, v = g.edges(etype='has_field')
        edges = list(zip(u.tolist(), v.tolist()))
        HasField = Paper.fos.through
        for i in trange(0, len(edges), batch_size):
            HasField.objects.bulk_create([
                HasField(paper_id=p, field_id=f)
                for p, f in edges[i:i + batch_size]
            ])

        print('cites')
        u, v = g.edges(etype='cites')
        edges = list(zip(u.tolist(), v.tolist()))
        Cites = Paper.references.through
        for i in trange(0, len(edges), batch_size):
            Cites.objects.bulk_create([
                Cites(from_paper_id=p, to_paper_id=r)
                for p, r in edges[i:i + batch_size]
            ])
        print('Import finished')