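"""Extract the computer science subset of the OAG dataset.

Filters papers, authors, venues, institutions and fields from the raw
OAG files and writes each entity type as JSON lines under DATA_DIR / 'oag/cs'.
"""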
import argparse
import json

from gnnrec.config import DATA_DIR
from gnnrec.kgrec.data.config import CS, CS_FIELD_L2
from gnnrec.kgrec.data.preprocess.utils import iter_lines

def extract_papers(raw_path):
    """Iterate over raw paper records and yield cleaned CS papers.

    A paper is kept only if it has all required fields, belongs to the CS field
    plus at least one CS level-2 subfield, was published in 2010-2021, and has
    a reasonable title/abstract length and number of authors and references.
    """
    valid_keys = ['title', 'authors', 'venue', 'year', 'indexed_abstract', 'fos', 'references']
    cs_fields = set(CS_FIELD_L2)
    for p in iter_lines(raw_path, 'paper'):
        if not all(p.get(k) for k in valid_keys):
            continue
        fos = {f['name'] for f in p['fos']}
        abstract = parse_abstract(p['indexed_abstract'])
        if CS in fos and not fos.isdisjoint(cs_fields) \
                and 2010 <= p['year'] <= 2021 \
                and len(p['title']) <= 200 and len(abstract) <= 4000 \
                and 1 <= len(p['authors']) <= 20 and 1 <= len(p['references']) <= 100:
            try:
                yield {
                    'id': p['id'],
                    'title': p['title'],
                    'authors': [a['id'] for a in p['authors']],
                    'venue': p['venue']['id'],
                    'year': p['year'],
                    'abstract': abstract,
                    'fos': list(fos),
                    'references': p['references'],
                    'n_citation': p.get('n_citation', 0),
                }
            except KeyError:
                # Some author/venue entries lack an 'id'; skip such papers.
                pass

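# Note: each record yielded by extract_papers() has the shape (values illustrative):
#   {'id': 1, 'title': '...', 'authors': [2, 3], 'venue': 4, 'year': 2015,
#    'abstract': '...', 'fos': ['Computer science', ...], 'references': [5, 6], 'n_citation': 0}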
def parse_abstract(indexed_abstract):
    """Reconstruct the abstract text from its inverted index representation."""
    try:
        abstract = json.loads(indexed_abstract)
        words = [''] * abstract['IndexLength']
        for w, idx in abstract['InvertedIndex'].items():
            for i in idx:
                words[i] = w
        return ' '.join(words)
    except json.JSONDecodeError:
        return ''

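# Note: parse_abstract() turns an inverted index such as
#   '{"IndexLength": 3, "InvertedIndex": {"Hello": [0], "world": [1, 2]}}'
# into the plain string 'Hello world world' (illustrative input, not from the dataset).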
def extract_authors(raw_path, author_ids):
    """Yield authors whose id appears in author_ids, with their last known affiliation (or None)."""
    for a in iter_lines(raw_path, 'author'):
        if a['id'] in author_ids:
            yield {
                'id': a['id'],
                'name': a['name'],
                'org': int(a['last_known_aff_id']) if 'last_known_aff_id' in a else None
            }

def extract_venues(raw_path, venue_ids):
    """Yield venues whose id appears in venue_ids."""
    for v in iter_lines(raw_path, 'venue'):
        if v['id'] in venue_ids:
            yield {'id': v['id'], 'name': v['DisplayName']}

def extract_institutions(raw_path, institution_ids):
    """Yield institutions (affiliations) whose id appears in institution_ids."""
    for i in iter_lines(raw_path, 'affiliation'):
        if i['id'] in institution_ids:
            yield {'id': i['id'], 'name': i['DisplayName']}

def extract(args):
    print('Extracting computer science papers...')
    paper_ids, author_ids, venue_ids, fields = set(), set(), set(), set()
    output_path = DATA_DIR / 'oag/cs'
    output_path.mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
    with open(output_path / 'mag_papers.txt', 'w', encoding='utf8') as f:
        for p in extract_papers(args.raw_path):
            paper_ids.add(p['id'])
            author_ids.update(p['authors'])
            venue_ids.add(p['venue'])
            fields.update(p['fos'])
            json.dump(p, f, ensure_ascii=False)
            f.write('\n')
    print(f'Finished extracting papers, saved to {f.name}')
    print(f'{len(paper_ids)} papers, {len(author_ids)} authors, {len(venue_ids)} venues, {len(fields)} fields')

    print('Extracting authors...')
    institution_ids = set()
    with open(output_path / 'mag_authors.txt', 'w', encoding='utf8') as f:
        for a in extract_authors(args.raw_path, author_ids):
            if a['org'] is not None:
                institution_ids.add(a['org'])
            json.dump(a, f, ensure_ascii=False)
            f.write('\n')
    print(f'Finished extracting authors, saved to {f.name}')
    print(f'{len(institution_ids)} institutions')

    print('Extracting venues...')
    with open(output_path / 'mag_venues.txt', 'w', encoding='utf8') as f:
        for v in extract_venues(args.raw_path, venue_ids):
            json.dump(v, f, ensure_ascii=False)
            f.write('\n')
    print(f'Finished extracting venues, saved to {f.name}')

    print('Extracting institutions...')
    with open(output_path / 'mag_institutions.txt', 'w', encoding='utf8') as f:
        for i in extract_institutions(args.raw_path, institution_ids):
            json.dump(i, f, ensure_ascii=False)
            f.write('\n')
    print(f'Finished extracting institutions, saved to {f.name}')

    print('Extracting fields...')
    fields.discard(CS)  # drop the top-level CS field itself; discard() is safe even if nothing was extracted
    fields = sorted(fields)
    with open(output_path / 'mag_fields.txt', 'w', encoding='utf8') as f:
        for i, field in enumerate(fields):
            json.dump({'id': i, 'name': field}, f, ensure_ascii=False)
            f.write('\n')
    print(f'Finished extracting fields, saved to {f.name}')

def main():
    parser = argparse.ArgumentParser(description='Extract the computer science subset of the OAG dataset')
    parser.add_argument('raw_path', help='directory containing the raw zip files')
    args = parser.parse_args()
    extract(args)


if __name__ == '__main__':
    main()
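# Example invocation (the module path is an assumption based on this repository's package layout):
#   python -m gnnrec.kgrec.data.preprocess.extract /path/to/oag/raw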