forked from KaiyangZhou/pytorch-vsumm-reinforce
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
21 lines (18 loc) · 746 Bytes
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn
from torch.nn import functional as F
# Public API of this module: a star-import exposes only the names listed here.
__all__ = ['DSN']
class DSN(nn.Module):
    """Deep Summarization Network.

    A bidirectional recurrent layer (LSTM or GRU) followed by a linear
    layer and a sigmoid, so that every time step of the input sequence is
    mapped to a single value in [0, 1].

    Args:
        in_dim: Size of each input feature vector. Defaults to 1024.
        hid_dim: Hidden size of the recurrent layer, per direction.
            Defaults to 256.
        num_layers: Number of stacked recurrent layers. Defaults to 1.
        cell: Recurrent cell type, either ``'lstm'`` or ``'gru'``.

    Raises:
        AssertionError: If ``cell`` is neither ``'lstm'`` nor ``'gru'``.
    """

    def __init__(self, in_dim=1024, hid_dim=256, num_layers=1, cell='lstm'):
        super(DSN, self).__init__()
        # Kept as an assert (not a ValueError) so the exception type seen by
        # existing callers does not change.
        assert cell in ['lstm', 'gru'], "cell must be either 'lstm' or 'gru'"

        # Both cell types share the same constructor arguments, so pick the
        # class once and build the layer with a single call.
        rnn_cls = nn.LSTM if cell == 'lstm' else nn.GRU
        self.rnn = rnn_cls(
            in_dim,
            hid_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        # The bidirectional RNN concatenates forward and backward states,
        # hence the linear layer consumes 2 * hid_dim features.
        self.fc = nn.Linear(hid_dim * 2, 1)

    def forward(self, x):
        """Compute a per-time-step score for the input sequence.

        Args:
            x: Input tensor. The RNN is built with ``batch_first=True``, so
                batched input is laid out as ``(batch, seq_len, in_dim)``.

        Returns:
            Tensor of sigmoid outputs with a trailing dimension of size 1
            (``(batch, seq_len, 1)`` for batched input).
        """
        # Only the per-step outputs are needed; the final hidden/cell state
        # returned by the RNN is discarded.
        h, _ = self.rnn(x)
        # torch.sigmoid is the canonical spelling; the functional alias
        # F.sigmoid triggers a deprecation warning on some PyTorch releases.
        return torch.sigmoid(self.fc(h))