From c58eaeb9ff8590b23a91f6f8fc8754d53d1e4925 Mon Sep 17 00:00:00 2001
From: Gene
Date: Fri, 20 Mar 2020 00:14:51 +0800
Subject: [PATCH] Update RIM.py

This is based on the description given in the paper: "Based on the softmax
values in (4), we select the top k_A RIMs (out of the total K RIMs) to be
activated for each step, which have the least attention on the null input, ..."
Hence we have to apply the softmax before constructing the mask.
---
 RIM.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/RIM.py b/RIM.py
index a06fcb9..57e773f 100644
--- a/RIM.py
+++ b/RIM.py
@@ -182,6 +182,7 @@ def input_attention_mask(self, x, h):
         attention_scores = torch.mean(attention_scores, dim = 1)
         mask_ = torch.zeros(x.size(0), self.num_units).to(self.device)
 
+        attention_scores = nn.Softmax(dim = -1)(attention_scores)
         not_null_scores = attention_scores[:,:, 0]
         topk1 = torch.topk(not_null_scores,self.k, dim = 1)
         row_index = np.arange(x.size(0))
@@ -189,7 +190,7 @@ def input_attention_mask(self, x, h):
 
         mask_[row_index, topk1.indices.view(-1)] = 1
 
-        attention_probs = self.input_dropout(nn.Softmax(dim = -1)(attention_scores))
+        attention_probs = self.input_dropout(attention_scores)
         inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)
 
         return inputs, mask_
@@ -345,4 +346,4 @@ def forward(self, x, h = None, c = None):
         if cs is not None:
             cs = torch.stack(cs, dim = 0)
             return x, hs, cs
-        return x, hs
\ No newline at end of file
+        return x, hs
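
For reference, below is a minimal standalone sketch of the ordering this patch establishes: softmax first, then the top-k mask built from the softmaxed null/non-null scores, and dropout applied to those same probabilities (so no second softmax downstream). The shapes and names used here (batch_size, num_units, k, value_size, input_dropout) are illustrative assumptions for a toy example, not the module's actual configuration.

    # Illustrative sketch only: mirrors the patched ordering in input_attention_mask
    # with assumed toy dimensions; not the repository's actual module.
    import numpy as np
    import torch
    import torch.nn as nn

    batch_size, num_units, k, value_size = 4, 6, 3, 32         # assumed toy dimensions
    attention_scores = torch.randn(batch_size, num_units, 2)   # scores vs. (real input, null input)
    value_layer = torch.randn(batch_size, 2, value_size)
    input_dropout = nn.Dropout(p=0.1)

    # 1) Softmax over (real, null) BEFORE building the mask, per Eq. (4) in the paper.
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    # 2) Select the top-k RIMs with the highest probability on the real input
    #    (equivalently, the least attention on the null input).
    not_null_scores = attention_probs[:, :, 0]
    topk = torch.topk(not_null_scores, k, dim=1)

    mask_ = torch.zeros(batch_size, num_units)
    row_index = np.repeat(np.arange(batch_size), k)
    mask_[row_index, topk.indices.view(-1)] = 1

    # 3) Dropout is applied to the already-softmaxed probabilities;
    #    the masked, attention-weighted values go to the active RIMs.
    attention_probs = input_dropout(attention_probs)
    inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)
    print(inputs.shape)  # torch.Size([4, 6, 32])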