diff --git a/README.md b/README.md
index 03e1f47..98ae209 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,6 @@ model = CompressiveTransformer(
     gru_gated_residual = True,    # whether to gate the residual intersection, from 'Stabilizing Transformer for RL' paper
     mogrify_gru = False,          # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
     memory_layers = range(6, 13), # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
-    one_head_kv = True,           # share one key/value head for all queries, from Shazeers 'One Write-Head is All You Need'
     ff_glu = True                 # use GLU variant for feedforward
 )
@@ -91,37 +90,37 @@ sample = model.generate(prime, 4096)
 ```bibtex
 @misc{rae2019compressive,
-    title={Compressive Transformers for Long-Range Sequence Modelling},
-    author={Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
-    year={2019},
-    eprint={1911.05507},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+    title = {Compressive Transformers for Long-Range Sequence Modelling},
+    author = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
+    year = {2019},
+    eprint = {1911.05507},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
 }
 ```
 
 ```bibtex
 @misc{parisotto2019stabilizing,
-    title={Stabilizing Transformers for Reinforcement Learning},
-    author={Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
-    year={2019},
-    eprint={1910.06764},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+    title = {Stabilizing Transformers for Reinforcement Learning},
+    author = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
+    year = {2019},
+    eprint = {1910.06764},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
 }
 ```
 
 ```bibtex
 @inproceedings{rae-razavi-2020-transformers,
-    title = "Do Transformers Need Deep Long-Range Memory?",
-    author = "Rae, Jack and
+    title     = "Do Transformers Need Deep Long-Range Memory?",
+    author    = "Rae, Jack and
       Razavi, Ali",
     booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
-    month = jul,
-    year = "2020",
+    month     = jul,
+    year      = "2020",
     address = "Online",
     publisher = "Association for Computational Linguistics",
-    url = "https://www.aclweb.org/anthology/2020.acl-main.672"
+    url       = "https://www.aclweb.org/anthology/2020.acl-main.672"
 }
 ```
@@ -152,3 +151,14 @@ sample = model.generate(prime, 4096)
     url = {https://arxiv.org/abs/1909.11942}
 }
 ```
+
+```bibtex
+@misc{ding2021erniedoc,
+    title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
+    author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
+    year = {2021},
+    eprint = {2012.15688},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL}
+}
+```
diff --git a/compressive_transformer_pytorch/compressive_transformer_pytorch.py b/compressive_transformer_pytorch/compressive_transformer_pytorch.py
index 33c4caf..a6aafd6 100644
--- a/compressive_transformer_pytorch/compressive_transformer_pytorch.py
+++ b/compressive_transformer_pytorch/compressive_transformer_pytorch.py
@@ -167,7 +167,7 @@ def forward(self, x, **kwargs):
 # attention.
 
 class SelfAttention(nn.Module):
-    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, attn_dropout = 0., dropout = 0., reconstruction_attn_dropout = 0., one_kv_head = False):
+    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, attn_dropout = 0., dropout = 0., reconstruction_attn_dropout = 0.):
         super().__init__()
         assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'
 
@@ -182,9 +182,7 @@ def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, a
         self.compress_mem_fn = ConvCompress(dim, cmem_ratio)
 
         self.to_q = nn.Linear(dim, dim, bias = False)
-
-        kv_dim = self.dim_head if one_kv_head else dim
-        self.to_kv = nn.Linear(dim, kv_dim * 2, bias = False)
+        self.to_kv = nn.Linear(dim, dim * 2, bias = False)
         self.to_out = nn.Linear(dim, dim)
 
         self.attn_dropout = nn.Dropout(attn_dropout)
@@ -291,7 +289,28 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
 # transformer
 
 class CompressiveTransformer(nn.Module):
-    def __init__(self, num_tokens, dim, seq_len, depth, emb_dim = None, memory_layers = None, mem_len = None, cmem_len = None, cmem_ratio = 4, heads = 8, gru_gated_residual = True, mogrify_gru = False, attn_dropout = 0., ff_glu = False, ff_dropout = 0., attn_layer_dropout = 0., reconstruction_attn_dropout = 0., reconstruction_loss_weight = 1., one_kv_head = False):
+    def __init__(
+        self,
+        num_tokens,
+        dim,
+        seq_len,
+        depth,
+        emb_dim = None,
+        memory_layers = None,
+        enhanced_recurrence = True,
+        mem_len = None,
+        cmem_len = None,
+        cmem_ratio = 4,
+        heads = 8,
+        gru_gated_residual = True,
+        mogrify_gru = False,
+        attn_dropout = 0.,
+        ff_glu = False,
+        ff_dropout = 0.,
+        attn_layer_dropout = 0.,
+        reconstruction_attn_dropout = 0.,
+        reconstruction_loss_weight = 1.
+    ):
         super().__init__()
         emb_dim = default(emb_dim, dim)
         mem_len = default(mem_len, seq_len)
@@ -306,6 +325,7 @@ def __init__(self, num_tokens, dim, seq_len, depth, emb_dim = None, memory_layer
         self.depth = depth
         self.memory_layers = list(memory_layers)
+        self.enhanced_recurrence = enhanced_recurrence
 
         self.token_emb = nn.Embedding(num_tokens, emb_dim)
         self.to_model_dim = nn.Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)
 
@@ -320,7 +340,7 @@ def __init__(self, num_tokens, dim, seq_len, depth, emb_dim = None, memory_layer
         wrapper = partial(GRUGating, dim, mogrify = mogrify_gru) if gru_gated_residual else Residual
 
-        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout, reconstruction_attn_dropout = reconstruction_attn_dropout, one_kv_head = one_kv_head))) for _ in range(depth)])
+        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout, reconstruction_attn_dropout = reconstruction_attn_dropout))) for _ in range(depth)])
         self.ff_layers = nn.ModuleList([wrapper(PreNorm(dim, FeedForward(dim, dropout = ff_dropout, glu = ff_glu))) for _ in range(depth)])
 
         self.reconstruction_loss_weight = reconstruction_loss_weight
 
@@ -347,6 +367,10 @@ def forward(self, x, memories = None, mask = None):
         next_cmem = []
         aux_loss = torch.tensor(0., requires_grad = True, **to(x))
 
+        if self.enhanced_recurrence:
+            mem = torch.roll(mem, -1, 0)
+            cmem = torch.roll(cmem, -1, 0)
+
         mem_iter, cmem_iter = map(iterate_tensor, (mem, cmem))
 
         for ind, (attn, ff) in enumerate(zip(self.attn_layers, self.ff_layers)):
diff --git a/setup.py b/setup.py
index e0f1d9d..f8f0417 100644
--- a/setup.py
+++ b/setup.py
@@ -3,13 +3,18 @@
 setup(
   name = 'compressive-transformer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.21',
+  version = '0.4.0',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/compressive-transformer-pytorch',
-  keywords = ['attention', 'artificial intelligence', 'transformer', 'deep learning'],
+  keywords = [
+    'attention',
+    'artificial intelligence',
+    'transformer',
+    'deep learning'
+  ],
   install_requires=[
     'torch',
     'mogrifier'
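
For context on the `enhanced_recurrence` flag this patch introduces: it applies ERNIE-Doc's enhanced recurrence by rolling the stacked layer memories one position along the layer axis before they are consumed, so each attention layer reads the memory its successor wrote on the previous segment, with the last layer wrapping around to the first. The snippet below is a minimal, illustrative sketch of that roll on a dummy memory tensor; the shape and values are invented for the example and are not taken from the repository.

```python
import torch

# Dummy stacked memories: (memory layers, batch, memory length, dim).
# Each layer's memory is filled with its layer index so the shift is easy to see.
layers, batch, mem_len, dim = 6, 1, 4, 8
mem = torch.arange(layers).float().view(layers, 1, 1, 1).expand(layers, batch, mem_len, dim)

# Enhanced recurrence (ERNIE-Doc): shift memories down one layer, wrapping around,
# using the same torch.roll(mem, -1, 0) call the patch adds to CompressiveTransformer.forward.
shifted = torch.roll(mem, -1, 0)

print(mem[:, 0, 0, 0])      # tensor([0., 1., 2., 3., 4., 5.])
print(shifted[:, 0, 0, 0])  # tensor([1., 2., 3., 4., 5., 0.]) - layer 0 now reads layer 1's memory
```

At the model level the behaviour is on by default (`enhanced_recurrence = True`) and can be switched off by passing `enhanced_recurrence = False` to `CompressiveTransformer`.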