GNNRecom/gnnrec/kgrec/data/preprocess/extract_cs.py

import argparse
import json

from gnnrec.config import DATA_DIR
from gnnrec.kgrec.data.config import CS, CS_FIELD_L2
from gnnrec.kgrec.data.preprocess.utils import iter_lines


def extract_papers(raw_path):
    """Iterates over raw OAG paper records and yields cleaned computer science papers."""
    valid_keys = ['title', 'authors', 'venue', 'year', 'indexed_abstract', 'fos', 'references']
    cs_fields = set(CS_FIELD_L2)
    for p in iter_lines(raw_path, 'paper'):
        # Skip records with missing or empty required fields.
        if not all(p.get(k) for k in valid_keys):
            continue
        fos = {f['name'] for f in p['fos']}
        abstract = parse_abstract(p['indexed_abstract'])
        # Keep papers tagged with CS plus at least one level-2 CS subfield,
        # published 2010-2021, with bounded title/abstract/author/reference sizes.
        if CS in fos and not fos.isdisjoint(cs_fields) \
                and 2010 <= p['year'] <= 2021 \
                and len(p['title']) <= 200 and len(abstract) <= 4000 \
                and 1 <= len(p['authors']) <= 20 and 1 <= len(p['references']) <= 100:
            try:
                yield {
                    'id': p['id'],
                    'title': p['title'],
                    'authors': [a['id'] for a in p['authors']],
                    'venue': p['venue']['id'],
                    'year': p['year'],
                    'abstract': abstract,
                    'fos': list(fos),
                    'references': p['references'],
                    'n_citation': p.get('n_citation', 0),
                }
            except KeyError:
                # Some author/venue entries lack an 'id'; drop such papers.
                pass
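
# For reference, a record yielded by extract_papers has this shape (the values
# below are illustrative placeholders, not taken from the real dataset):
# {'id': 1, 'title': 'An Example Paper', 'authors': [10, 11], 'venue': 100,
#  'year': 2015, 'abstract': '...', 'fos': ['Computer science', 'Data mining'],
#  'references': [2, 3], 'n_citation': 0}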


def parse_abstract(indexed_abstract):
    """Reconstructs the abstract text from OAG's inverted-index representation."""
    try:
        abstract = json.loads(indexed_abstract)
        words = [''] * abstract['IndexLength']
        # The inverted index maps each word to the positions where it occurs.
        for w, idx in abstract['InvertedIndex'].items():
            for i in idx:
                words[i] = w
        return ' '.join(words)
    except json.JSONDecodeError:
        return ''
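
# A minimal worked example (hypothetical input, not a real record):
#   parse_abstract('{"IndexLength": 3, "InvertedIndex": {"Deep": [0], "learning": [1, 2]}}')
# places each word at its listed positions, giving ['Deep', 'learning', 'learning'],
# and returns 'Deep learning learning'.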


def extract_authors(raw_path, author_ids):
    """Yields authors whose ids appear in the extracted papers."""
    for a in iter_lines(raw_path, 'author'):
        if a['id'] in author_ids:
            yield {
                'id': a['id'],
                'name': a['name'],
                'org': int(a['last_known_aff_id']) if 'last_known_aff_id' in a else None
            }


def extract_venues(raw_path, venue_ids):
    """Yields venues whose ids appear in the extracted papers."""
    for v in iter_lines(raw_path, 'venue'):
        if v['id'] in venue_ids:
            yield {'id': v['id'], 'name': v['DisplayName']}


def extract_institutions(raw_path, institution_ids):
    """Yields institutions whose ids appear in the extracted authors."""
    for i in iter_lines(raw_path, 'affiliation'):
        if i['id'] in institution_ids:
            yield {'id': i['id'], 'name': i['DisplayName']}


def extract(args):
    print('Extracting computer science papers...')
    paper_ids, author_ids, venue_ids, fields = set(), set(), set(), set()
    output_path = DATA_DIR / 'oag/cs'
    with open(output_path / 'mag_papers.txt', 'w', encoding='utf8') as f:
        for p in extract_papers(args.raw_path):
            paper_ids.add(p['id'])
            author_ids.update(p['authors'])
            venue_ids.add(p['venue'])
            fields.update(p['fos'])
            json.dump(p, f, ensure_ascii=False)
            f.write('\n')
    print(f'Paper extraction finished, saved to {f.name}')
    print(f'{len(paper_ids)} papers, {len(author_ids)} authors, {len(venue_ids)} venues, {len(fields)} fields')

    print('Extracting authors...')
    institution_ids = set()
    with open(output_path / 'mag_authors.txt', 'w', encoding='utf8') as f:
        for a in extract_authors(args.raw_path, author_ids):
            if a['org']:
                institution_ids.add(a['org'])
            json.dump(a, f, ensure_ascii=False)
            f.write('\n')
    print(f'Author extraction finished, saved to {f.name}')
    print(f'{len(institution_ids)} institutions')

    print('Extracting venues...')
    with open(output_path / 'mag_venues.txt', 'w', encoding='utf8') as f:
        for v in extract_venues(args.raw_path, venue_ids):
            json.dump(v, f, ensure_ascii=False)
            f.write('\n')
    print(f'Venue extraction finished, saved to {f.name}')

    print('Extracting institutions...')
    with open(output_path / 'mag_institutions.txt', 'w', encoding='utf8') as f:
        for i in extract_institutions(args.raw_path, institution_ids):
            json.dump(i, f, ensure_ascii=False)
            f.write('\n')
    print(f'Institution extraction finished, saved to {f.name}')

    print('Extracting fields...')
    # CS itself is tagged on every extracted paper, so keep only the subfields.
    fields.remove(CS)
    fields = sorted(fields)
    with open(output_path / 'mag_fields.txt', 'w', encoding='utf8') as f:
        for i, field in enumerate(fields):
            json.dump({'id': i, 'name': field}, f, ensure_ascii=False)
            f.write('\n')
    print(f'Field extraction finished, saved to {f.name}')


def main():
    parser = argparse.ArgumentParser(description='Extract the computer science subset of the OAG dataset')
    parser.add_argument('raw_path', help='directory containing the raw zip files')
    args = parser.parse_args()
    extract(args)


if __name__ == '__main__':
    main()
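
# Usage sketch (the module path follows this file's location in the repo; the
# argument is assumed to be the directory holding the downloaded OAG zip files):
#   python -m gnnrec.kgrec.data.preprocess.extract_cs /path/to/oag/raw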