-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbuild_field.py
39 lines (31 loc) · 1.39 KB
/
build_field.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
import torch
import argparse
import os
from torchtext.data import Field, TabularDataset
def get_parser():
    """Create the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` that accepts a positional
        ``corpus_path`` and an optional ``-s/--saving-directory``
        (default ``./field``).
    """
    cli = argparse.ArgumentParser(
        description='build data field files by corpus',
    )
    cli.add_argument(
        'corpus_path',
        type=str,
        help='specify the path of the corpus file',
    )
    cli.add_argument(
        '-s',
        '--saving-directory',
        type=str,
        default='./field',
        help='specify saving directory for field files, default ./field',
    )
    return cli
def build_field(corpus_path, save_dir, min_freq=3):
    """Build torchtext ``Field`` objects from a CSV corpus and save them to disk.

    The text field tokenizes with spaCy (``en_core_web_sm``), lower-cases,
    includes sequence lengths, and is batch-first. Its vocabulary is built
    from the ``text`` column read out of *corpus_path*. The label field
    (non-sequential, no vocabulary, ``torch.float`` dtype) is created and
    saved alongside it, but is not built from the corpus.

    Args:
        corpus_path: Path to a CSV file with a header row (the header is
            skipped). Only one column is mapped, as ``text``.
        save_dir: Directory in which ``text_field.pth`` and
            ``label_field.pth`` are written; created (with parents) if missing.
        min_freq: Minimum token frequency for a token to enter the text
            vocabulary. Defaults to 3, the value previously hard-coded.
    """
    label_field = Field(sequential=False, use_vocab=False, batch_first=True,
                        dtype=torch.float)
    text_field = Field(tokenize='spacy', tokenizer_language='en_core_web_sm',
                       lower=True, include_lengths=True, batch_first=True)

    # Only a single ('text', ...) field is mapped. NOTE(review): presumably
    # torchtext pairs fields with CSV columns positionally, so this reads the
    # first column and ignores any others — confirm against the corpus layout.
    corpus = TabularDataset(path=corpus_path,
                            format='CSV',
                            fields=[('text', text_field)],
                            skip_header=True)
    text_field.build_vocab(corpus, min_freq=min_freq)

    # exist_ok=True replaces the racy exists()-then-makedirs() check.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(text_field, os.path.join(save_dir, 'text_field.pth'))
    torch.save(label_field, os.path.join(save_dir, 'label_field.pth'))
def main(args):
    """Script entry point: build and save fields from parsed CLI arguments.

    Args:
        args: Namespace produced by ``get_parser().parse_args()``; its
            ``corpus_path`` and ``saving_directory`` attributes are passed
            to ``build_field``.
    """
    corpus_path = args.corpus_path
    out_dir = args.saving_directory
    build_field(corpus_path, out_dir)
if __name__ == '__main__':
    # Parse the command line and hand the resulting namespace to main().
    main(get_parser().parse_args())