-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdatasets.py
29 lines (24 loc) · 933 Bytes
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from torch.utils.data import Dataset
import torch
import os
class HANDataset(Dataset):
    """
    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.

    The whole split is loaded into memory once, at construction time, from
    ``<data_folder>/<SPLIT>_data.pth.tar``. The loaded object is indexed with
    the keys ``'docs'``, ``'sentences_per_document'``, ``'words_per_sentence'``
    and ``'labels'``, each holding one entry per document.
    """

    def __init__(self, data_folder, split):
        """
        :param data_folder: folder where data files are stored
        :param split: split, one of 'TRAIN' or 'TEST' (case-insensitive)
        :raises ValueError: if ``split`` is not one of the accepted values
        """
        split_name = split.upper()
        # Validate explicitly instead of with ``assert``: assertions are removed
        # under ``python -O``, which would let a bad split fall through to a
        # confusing file-not-found error from ``torch.load``.
        if split_name not in ('TRAIN', 'TEST'):
            raise ValueError(
                "split must be one of 'TRAIN' or 'TEST' (case-insensitive), "
                "got {!r}".format(split)
            )
        self.split = split_name

        # Load data
        # NOTE(review): torch.load unpickles the file, so only point this at
        # data files from a trusted source.
        data_path = os.path.join(data_folder, split_name + '_data.pth.tar')
        self.data = torch.load(data_path)

    def __getitem__(self, i):
        """
        Return the ``i``-th sample as a tuple of four ``LongTensor`` objects.

        :param i: index of the document
        :return: ``(doc, sentences_in_doc, words_per_sentence, label)``; the
            sentence count and the label are each wrapped in a one-element
            list, so they come back as tensors holding a single value
        """
        # NOTE(review): presumably docs[i] and words_per_sentence[i] are already
        # padded to a fixed size so that samples can be batched - confirm
        # against the preprocessing code that writes the data file.
        doc = torch.LongTensor(self.data['docs'][i])
        sentences_in_doc = torch.LongTensor([self.data['sentences_per_document'][i]])
        words_per_sentence = torch.LongTensor(self.data['words_per_sentence'][i])
        label = torch.LongTensor([self.data['labels'][i]])
        return doc, sentences_in_doc, words_per_sentence, label

    def __len__(self):
        """
        :return: number of samples, taken as the number of stored labels
        """
        return len(self.data['labels'])