-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMylstm.py
36 lines (31 loc) · 1.42 KB
/
Mylstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Load the serialized text field once, at import time. MyLSTM reads
# len(text_field.vocab) to size its embedding table.
# The path is relative, so it is resolved against the current working
# directory of the process importing this module.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load field files from a trusted source.
# NOTE(review): presumably a torchtext Field saved during preprocessing;
# confirm against the script that writes ./field/text_field.pth.
text_field = torch.load('./field/text_field.pth')
class MyLSTM(nn.Module):
    """Two-layer bidirectional LSTM text classifier.

    Token indices are embedded, run through a packed bidirectional LSTM, and
    the forward state at each sequence's last valid step is concatenated with
    the backward state at step 0. The result goes through dropout, a linear
    layer and a ReLU.

    Args:
        dimension: Hidden size of each LSTM direction.
        vocab_size: Number of rows in the embedding table. When ``None``
            (the default) it is taken from the module-level
            ``text_field.vocab``.
        embed_dim: Size of each token embedding (also the LSTM input size).
        num_classes: Number of output units of the final linear layer.
    """

    def __init__(self, dimension=128, vocab_size=None, embed_dim=300,
                 num_classes=6):
        super(MyLSTM, self).__init__()
        if vocab_size is None:
            # Default path: size the embedding from the globally loaded field.
            vocab_size = len(text_field.vocab)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dimension = dimension
        self.lstm = nn.LSTM(input_size=embed_dim,
                            hidden_size=dimension,
                            num_layers=2,
                            dropout=0.3,
                            batch_first=True,
                            bidirectional=True)
        self.drop = nn.Dropout(p=0.3)
        # Forward and backward states are concatenated, hence 2 * dimension.
        self.fc = nn.Linear(2 * dimension, num_classes)

    def forward(self, text, text_len):
        """Score a padded batch of token-index sequences.

        Args:
            text: Integer tensor of token indices, batch first
                (``batch x seq_len``), padded on the right.
            text_len: True length of every sequence in ``text``; a tensor on
                any device or a plain sequence of ints. Every length must be
                at least 1 (``pack_padded_sequence`` rejects empty
                sequences).

        Returns:
            Tensor of shape ``batch x num_classes`` holding the ReLU of the
            linear layer's output. When ``num_classes`` is 1 the trailing
            axis is squeezed away.
        """
        # Accept lists as well as tensors so the arithmetic below works.
        lengths = torch.as_tensor(text_len, dtype=torch.long)

        text_emb = self.embedding(text)
        # pack_padded_sequence requires the lengths on the CPU, even when the
        # embedded batch lives on an accelerator.
        packed_input = pack_padded_sequence(text_emb, lengths.cpu(),
                                            batch_first=True,
                                            enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # Forward direction: state at the last real (non-padded) time step.
        batch_idx = torch.arange(output.size(0), device=output.device)
        last_step = lengths.to(output.device) - 1
        out_forward = output[batch_idx, last_step, :self.dimension]
        # Backward direction: it finishes reading at time step 0.
        out_reverse = output[:, 0, self.dimension:]

        out_reduced = torch.cat((out_forward, out_reverse), 1)
        text_fea = self.drop(out_reduced)
        text_fea = self.fc(text_fea)
        # No-op unless the linear layer has a single output unit.
        text_fea = torch.squeeze(text_fea, 1)
        # NOTE(review): ReLU clamps negative scores to zero before they reach
        # the loss; kept for compatibility -- confirm this is intended.
        text_out = torch.nn.functional.relu(text_fea)
        return text_out