Implementation of batch normalization LSTM in pytorch.
h-jia/batch_normalized_LSTM
An Implementation of Batch-Normalized LSTM in PyTorch

Tim Cooijmans et al., Recurrent Batch Normalization (arXiv:1603.09025)
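For reference, the batch-normalized LSTM step described in the cited paper can be summarized as below. This is a sketch in the paper's notation, not a description of this repository's code: $W_h$, $W_x$, $b$ are the recurrent weights, input weights, and bias; $\gamma$, $\beta$ are the learned batch-normalization scale and shift; $\epsilon$ is a small constant for numerical stability.

```latex
\begin{aligned}
\mathrm{BN}(h;\gamma,\beta) &= \beta + \gamma \odot \frac{h - \widehat{\mathbb{E}}[h]}{\sqrt{\widehat{\mathrm{Var}}[h] + \epsilon}} \\
(\tilde{f}_t,\ \tilde{i}_t,\ \tilde{o}_t,\ \tilde{g}_t) &= \mathrm{BN}(W_h h_{t-1};\gamma_h,\beta_h) + \mathrm{BN}(W_x x_t;\gamma_x,\beta_x) + b \\
c_t &= \sigma(\tilde{f}_t) \odot c_{t-1} + \sigma(\tilde{i}_t) \odot \tanh(\tilde{g}_t) \\
h_t &= \sigma(\tilde{o}_t) \odot \tanh\!\big(\mathrm{BN}(c_t;\gamma_c,\beta_c)\big)
\end{aligned}
```

The recurrent and input projections are normalized separately, and the paper fixes $\beta_h = \beta_x = 0$ because the bias $b$ already provides the shift.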

Forked from sysuNie

Modified to be compatible with PyTorch 1.0.0

To use:

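import torch
import torch.nn as nn
from batch_normalization_LSTM import BNLSTMCell, LSTM


# Build an LSTM that uses the batch-normalized cell (BNLSTMCell).
model = LSTM(cell_class=BNLSTMCell, input_size=28, hidden_size=512, batch_first=True, max_length=152)

if __name__ == "__main__":
    size = 28  # last input dimension; must match input_size above
    # Random input tensor of shape (300, 2, 28).
    dummy = torch.rand(300, 2, size)
    out = model(dummy)
    print(model)
    print(out[0])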
