diff --git a/RIM.py b/RIM.py
index a06fcb9..57e773f 100644
--- a/RIM.py
+++ b/RIM.py
@@ -182,6 +182,7 @@ def input_attention_mask(self, x, h):
         attention_scores = torch.mean(attention_scores, dim = 1)
         mask_ = torch.zeros(x.size(0), self.num_units).to(self.device)
 
+        attention_scores = nn.Softmax(dim = -1)(attention_scores)
         not_null_scores = attention_scores[:,:, 0]
         topk1 = torch.topk(not_null_scores, self.k, dim = 1)
         row_index = np.arange(x.size(0))
@@ -189,7 +190,7 @@ def input_attention_mask(self, x, h):
 
         mask_[row_index, topk1.indices.view(-1)] = 1
 
-        attention_probs = self.input_dropout(nn.Softmax(dim = -1)(attention_scores))
+        attention_probs = self.input_dropout(attention_scores)
         inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)
         return inputs, mask_
@@ -345,4 +346,4 @@ def forward(self, x, h = None, c = None):
         if cs is not None:
             cs = torch.stack(cs, dim = 0)
             return x, hs, cs
-        return x, hs
\ No newline at end of file
+        return x, hs
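
Note on the change (not part of the diff): the softmax is moved ahead of the top-k unit selection, so not_null_scores ranks units by the probability mass each assigns to the actual (non-null) input rather than by raw logits, and the already-normalized scores are then reused for the dropout-and-weighted-sum step instead of being softmaxed a second time. Below is a minimal standalone sketch of that pattern; the shapes and names (batch, num_units, num_inputs, k) are illustrative assumptions, not the repository's values.

    # Sketch only, not the repository's code: softmax before top-k selection.
    import torch
    import torch.nn as nn

    batch, num_units, num_inputs, k = 4, 6, 2, 3

    # Raw attention logits: one score per (unit, input slot); slot 0 is
    # assumed here to hold the score for the real (non-null) input.
    attention_scores = torch.randn(batch, num_units, num_inputs)

    # Normalize first, so cross-unit comparisons in the top-k are made on
    # probabilities; the same normalized scores feed the value-weighted sum.
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    not_null_scores = attention_probs[:, :, 0]
    topk = torch.topk(not_null_scores, k, dim=1)

    # Build a 0/1 activation mask: 1 marks the k selected units per batch row.
    mask_ = torch.zeros(batch, num_units)
    row_index = torch.arange(batch).repeat_interleave(k)
    mask_[row_index, topk.indices.reshape(-1)] = 1
    print(mask_)

The ordering matters because each unit's row is softmaxed independently: raw logits are not comparable across units, whereas post-softmax probabilities are, so moving the softmax before torch.topk can change which units are activated, not just rescale their scores.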