feat: add Reformer modules;
WenjieDu committed Jun 15, 2024
1 parent 487f8f8 commit b5f95ca
Showing 5 changed files with 1,156 additions and 0 deletions.
25 changes: 25 additions & 0 deletions pypots/nn/modules/reformer/__init__.py
@@ -0,0 +1,25 @@
"""
The package including the modules of Reformer.
Refer to the paper
`Nikita Kitaev, Lukasz Kaiser, and Anselm Levskaya.
"Reformer: The Efficient Transformer".
In International Conference on Learning Representations, 2020.
<https://openreview.net/pdf?id=rkgNKkHtvB>`_
Notes
-----
This implementation is inspired by the official one https://github.com/google/trax/tree/master/trax/models/reformer and
https://github.com/lucidrains/reformer-pytorch
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .autoencoder import ReformerEncoder

__all__ = [
"ReformerEncoder",
]
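
The re-export above lets downstream code import ReformerEncoder from the package level rather than from the submodule. A minimal check, assuming PyPOTS with this commit installed:

from pypots.nn.modules.reformer import ReformerEncoder
from pypots.nn.modules.reformer.autoencoder import ReformerEncoder as ReformerEncoderDirect

# Both import paths resolve to the same class because __init__.py re-exports it.
assert ReformerEncoder is ReformerEncoderDirect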
54 changes: 54 additions & 0 deletions pypots/nn/modules/reformer/autoencoder.py
@@ -0,0 +1,54 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from .layers import ReformerLayer


class ReformerEncoder(nn.Module):
    def __init__(
        self,
        n_steps,
        n_layers,
        d_model,
        n_heads,
        bucket_size,
        n_hashes,
        causal,
        d_ffn,
        dropout,
    ):
        super().__init__()

        assert (
            n_steps % (bucket_size * 2) == 0
        ), f"Sequence length ({n_steps}) needs to be divisible by target bucket size x 2 - {bucket_size * 2}"

        self.enc_layer_stack = nn.ModuleList(
            [
                ReformerLayer(
                    d_model,
                    n_heads,
                    bucket_size,
                    n_hashes,
                    causal,
                    d_ffn,
                    dropout,
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        enc_output = x

        for layer in self.enc_layer_stack:
            enc_output = layer(enc_output)

        return enc_output
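
For context, a minimal usage sketch of the ReformerEncoder defined above, not part of this commit. The hyperparameter values, the (batch, n_steps, d_model) input layout, and the import path are illustrative assumptions; the only hard constraint taken from the code is that n_steps must be divisible by bucket_size * 2.

import torch

from pypots.nn.modules.reformer import ReformerEncoder

# Illustrative hyperparameters: 48 % (8 * 2) == 0, so the assert in __init__ passes.
encoder = ReformerEncoder(
    n_steps=48,
    n_layers=2,
    d_model=64,
    n_heads=4,
    bucket_size=8,
    n_hashes=4,
    causal=False,
    d_ffn=128,
    dropout=0.1,
)

x = torch.randn(16, 48, 64)  # assumed (batch, n_steps, d_model) input
out = encoder(x)             # layers are applied sequentially; the shape is preserved
print(out.shape)             # torch.Size([16, 48, 64])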
52 changes: 52 additions & 0 deletions pypots/nn/modules/reformer/layers.py
@@ -0,0 +1,52 @@
"""
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch
import torch.nn as nn

from .lsh_attention import LSHSelfAttention
from ..transformer import PositionWiseFeedForward


class ReformerLayer(nn.Module):
    def __init__(
        self,
        d_model,
        n_heads,
        bucket_size,
        n_hashes,
        causal,
        d_ffn,
        dropout,
    ):
        super().__init__()
        self.attn = LSHSelfAttention(
            dim=d_model,
            heads=n_heads,
            bucket_size=bucket_size,
            n_hashes=n_hashes,
            causal=causal,
        )
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout)

    def forward(
        self,
        enc_input: torch.Tensor,
    ):
        enc_output = self.attn(enc_input)

        # apply dropout and residual connection
        enc_output = self.dropout(enc_output)
        enc_output += enc_input

        # apply layer-norm
        enc_output = self.layer_norm(enc_output)

        enc_output = self.pos_ffn(enc_output)
        return enc_output
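
And a corresponding sketch for a single ReformerLayer, again with illustrative values and an assumed (batch, seq_len, d_model) layout. The sequence length is chosen to be divisible by bucket_size * 2, matching the constraint ReformerEncoder asserts for the stack as a whole.

import torch

from pypots.nn.modules.reformer.layers import ReformerLayer

layer = ReformerLayer(
    d_model=64,
    n_heads=4,
    bucket_size=8,
    n_hashes=4,
    causal=False,
    d_ffn=128,
    dropout=0.1,
)

x = torch.randn(2, 32, 64)  # assumed (batch, seq_len, d_model); 32 % (8 * 2) == 0
out = layer(x)              # LSH attention -> dropout -> residual add -> LayerNorm -> FFN
print(out.shape)            # torch.Size([2, 32, 64])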