forked from nyu-mll/PRPN-Analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReadingNetwork.py
executable file
·54 lines (42 loc) · 2.06 KB
/
ReadingNetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from LSTMCell import LSTMCell
from blocks import softmax
class ReadingNetwork(nn.Module):
    """Recurrent reader that attends over a fixed-size stack of memory slots.

    On every step the reader:
      1. scores each stored hidden state against a key projected from the
         current input and the newest slot,
      2. takes attention-weighted sums of the stored hidden and cell states,
      3. feeds those sums as the previous state into an LSTM-style cell, and
      4. pushes the new state into slot 0, dropping the last slot.

    Args:
        ninp: input feature size. Together with ``nout`` it sets the input
            width (``ninp + nout``) of the key projector.
        nout: hidden size of the cell and of every memory slot.
        nslots: number of memory slots allocated by ``init_hidden``.
        dropout: dropout probability applied to the input before the cell.
        idropout: dropout probability applied before and after the key
            projection inside ``attention``.
    """

    def __init__(self, ninp, nout, nslots, dropout, idropout):
        super(ReadingNetwork, self).__init__()
        self.ninp, self.nout, self.nslots = ninp, nout, nslots

        # NOTE: keep attribute names and construction order as-is.
        # - Names are the state_dict keys.
        # - Order decides any RNG use during parameter init.
        # - Order also decides which parameter init_hidden reads for
        #   dtype/device.
        self.drop = nn.Dropout(dropout)
        self.memory_rnn = LSTMCell(ninp, nout)
        self.projector_summ = nn.Sequential(
            nn.Dropout(idropout),
            nn.Linear(ninp + nout, nout),
            nn.Dropout(idropout),
        )

    def forward(self, input, memory, gate_time, rmask):
        """Run one reading step.

        Args:
            input: features for this step. They are concatenated along dim 1
                with a memory slot, so presumably (batch, ninp). TODO confirm.
            memory: ``(memory_h, memory_c)`` pair, presumably laid out like
                ``init_hidden``'s output: (batch, nslots, nout).
            gate_time: forwarded unchanged as ``gate`` to ``attention``.
            rmask: multiplier applied to the attended hidden state (not the
                cell state) before the cell update. Presumably a mask; verify
                against caller.

        Returns:
            ``(h, (memory_h, memory_c), attention)``:
              - ``h``: the cell's new hidden state.
              - ``(memory_h, memory_c)``: the memory with the new h/c placed
                in slot 0 and the last slot removed.
              - ``attention``: the weights computed by ``attention``.
        """
        mem_h, mem_c = memory

        # Attention sees the un-dropped input; dropout only feeds the cell.
        read_h, read_c, attn = self.attention(input, mem_h, mem_c, gate=gate_time)

        cell_in = self.drop(input)
        h_new, c_new = self.memory_rnn(cell_in, (read_h * rmask, read_c))

        # Update memory: newest state goes to slot 0, the last slot is dropped.
        shifted_h = torch.cat((h_new.unsqueeze(1), mem_h[:, :-1, :]), 1)
        shifted_c = torch.cat((c_new.unsqueeze(1), mem_c[:, :-1, :]), 1)
        return h_new, (shifted_h, shifted_c), attn

    def attention(self, input, memory_h, memory_c, gate=None):
        """Select a soft mixture of memory slots.

        A key is projected from ``input`` concatenated with slot 0 of
        ``memory_h``. Each slot's hidden state is dot-producted with that key,
        and the scores are divided by sqrt(nout). The scores are then passed
        with ``gate`` to ``softmax`` from ``blocks``, whose exact semantics
        live there. The resulting weights combine the slots.

        Args:
            input: step features, presumably (batch, ninp). TODO confirm.
            memory_h: slot hidden states. Must be 3-D for ``torch.bmm``;
                presumably (batch, nslots, nout).
            memory_c: slot cell states, same layout as ``memory_h``.
            gate: optional extra argument forwarded to ``softmax``.

        Returns:
            ``(selected_h, selected_c, weights)``:
              - ``selected_h``, ``selected_c``: weighted sums of
                ``memory_h`` / ``memory_c`` over dim 1.
              - ``weights``: the weights used for those sums.
        """
        query = self.projector_summ(torch.cat((input, memory_h[:, 0, :]), 1))
        # bmm: (B, S, D) x (B, D, 1) -> (B, S, 1); squeeze -> (B, S)
        scores = torch.bmm(memory_h, query.unsqueeze(2)).squeeze(2)
        scores = scores / math.sqrt(self.nout)
        weights = softmax(scores, gate)

        w = weights.unsqueeze(2)
        return (memory_h * w).sum(dim=1), (memory_c * w).sum(dim=1), weights

    def init_hidden(self, bsz):
        """Return a zeroed ``(memory_h, memory_c)`` pair.

        Each tensor has shape (bsz, nslots, nout). Both take dtype and device
        from the module's first registered parameter and are distinct tensors.
        """
        ref = next(self.parameters()).data
        shape = (bsz, self.nslots, self.nout)
        return tuple(Variable(ref.new(*shape).zero_()) for _ in range(2))