add ability to read the state at a layer earlier than the layer in which it is written to next
lucidrains committed Apr 18, 2023
1 parent 528a30a commit 8f0c925
Showing 2 changed files with 100 additions and 35 deletions.
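
The feature: a layer that writes recurrent state can now have that state read by an earlier layer, chosen via the new read_recurrent_layers argument (one read layer per write layer, each read layer at or below its write layer). Below is a minimal usage sketch, assuming the BlockRecurrentTransformer class name and README-style constructor/forward usage from this repository; anything not visible in the diff below is an assumption, not part of this commit.

import torch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer

model = BlockRecurrentTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 1024,
    block_width = 512,
    num_state_vectors = 512,
    recurrent_layers = (4,),       # layer 4 writes the recurrent state at the end of each block
    read_recurrent_layers = (2,),  # new in this commit: layer 2 reads that state, earlier than where it is written
)

seq = torch.randint(0, 20000, (1, 1024))
out = model(seq)  # logits (plus memories/states, depending on the version's return convention)

When read_recurrent_layers is omitted it defaults to recurrent_layers, which reproduces the previous behavior of each recurrent layer reading its own state just before writing it.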
@@ -1,7 +1,7 @@
import math
from random import random
from functools import wraps, partial
- from collections import namedtuple
+ from collections import namedtuple, defaultdict
from packaging import version

from einops import rearrange
@@ -24,6 +24,9 @@ def exists(val):
def default(val, d):
return val if exists(val) else d

def all_unique(arr):
return len(arr) == len(set(arr))

def eval_decorator(fn):
def inner(self, *args, **kwargs):
was_training = self.training
@@ -221,18 +224,23 @@ def __init__(
# since each read should be followed by a write, just store cache in the container

self.cache = None
self.next_read_state = None

- def read(
+ def set_next_read_state(
self,
- x,
- *,
- states = None,
+ states
):
- # use initial state if no states were passed in

if not exists(states):
states = self.init_state

self.next_read_state = (states,)

def read(self, x):
assert exists(self.next_read_state), 'states to be read must be set with .set_next_read_state'

states, = self.next_read_state
self.next_read_state = None

# pre norm state for attention

normed_states = self.state_norm(states)
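
In the hunk above, reading is split into two steps: the state is parked on the container with set_next_read_state, and the following read(x) consumes it, falling back to the learned initial state when None is passed. The split exists because the state can now be set by the transformer's forward pass on behalf of whichever layer will read it, which may not be the layer that owns the container. A toy sketch of that contract follows; the class is purely illustrative and skips what the real StateContainer does (norming the state and cross-attending from x into it).

class ToyStateContainer:
    def __init__(self, init_state):
        self.init_state = init_state
        self.next_read_state = None

    def set_next_read_state(self, states):
        # fall back to the initial state when nothing is passed in
        if states is None:
            states = self.init_state
        self.next_read_state = (states,)

    def read(self, x):
        # every read must be preceded by set_next_read_state
        assert self.next_read_state is not None, 'states to be read must be set with .set_next_read_state'
        states, = self.next_read_state
        self.next_read_state = None
        return x, states  # the real container cross-attends from x to the states here

container = ToyStateContainer(init_state = 'learned init state')
container.set_next_read_state(None)
print(container.read('current block'))  # ('current block', 'learned init state')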
@@ -523,8 +531,10 @@ def __init__(
heads = 8,
qk_rmsnorm = False,
qk_rmsnorm_scale = 8,
+ use_flash_attn = False,
num_state_vectors = 0,
- use_flash_attn = False
+ num_external_state_reads = 0,
+ state_read_before_write = True # this will be defaulted to on as in the paper, but will be turned off in the case the researcher wants to test out reading the state at a lower layer
):
super().__init__()
inner_dim = dim_head * heads
@@ -541,11 +551,17 @@ def __init__(
self.block_width = block_width
self.is_recurrent_layer = num_state_vectors > 0

- self.to_out = nn.Linear(inner_dim * (2 if self.is_recurrent_layer else 1), dim, bias = False)
+ # decide how many states this attention layer is going to read from

+ num_state_reads = int(self.is_recurrent_layer and state_read_before_write) + num_external_state_reads

+ self.to_out = nn.Linear(inner_dim * (1 + num_state_reads), dim, bias = False)

if not self.is_recurrent_layer:
return

self.state_read_before_write = state_read_before_write

self.state_container = StateContainer(
dim,
dim_head = dim_head,
@@ -567,7 +583,7 @@ def forward(
xpos_scale = None,
attn_mask = None,
xl_memories: Optional[torch.Tensor] = None,
- states: Optional[torch.Tensor] = None
+ read_from_state_containers: List[StateContainer] = []
):
batch, seq_len, _, width, device = *x.shape, self.block_width, self.device

@@ -615,20 +631,29 @@ def forward(

# early return if not a recurrent layer

- if not self.is_recurrent_layer:
+ if not self.is_recurrent_layer and len(read_from_state_containers) == 0:
return self.to_out(out), memories, None

- # read from the states ...
+ # whether to read from own state container, default to on, but may pass in more

+ if self.is_recurrent_layer and self.state_read_before_write:
+ read_from_state_containers = [self.state_container, *read_from_state_containers]

+ for read_state_container in read_from_state_containers:
+ # read from the states ...

- to_state_out = self.state_container.read(x, states = states)
+ to_state_out = read_state_container.read(x)

- # and concat it to the output of self-attention
+ # and concat it to the output of self-attention

- out = torch.cat((out, to_state_out), dim = -1)
+ out = torch.cat((out, to_state_out), dim = -1)

- # then write to the states as well
+ new_states = None

- new_states = self.state_container.write(memories = memories)
+ if self.is_recurrent_layer:
+ # then write to the states as well if need be

+ new_states = self.state_container.write(memories = memories)

return self.to_out(out), memories, new_states
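
Because every state read is concatenated onto the self-attention output, the to_out projection's input width scales with the number of reads the layer performs. A quick worked check of the sizing rule from the constructor above, using illustrative head sizes (not values fixed by this diff):

dim_head, heads = 64, 8
inner_dim = dim_head * heads  # 512

def to_out_in_features(is_recurrent_layer, state_read_before_write, num_external_state_reads):
    # same sizing rule as the constructor above
    num_state_reads = int(is_recurrent_layer and state_read_before_write) + num_external_state_reads
    return inner_dim * (1 + num_state_reads)

# a recurrent (write) layer whose state is routed to an earlier layer (state_read_before_write = False)
print(to_out_in_features(True, False, 0))    # 512: only the self-attention branch feeds to_out

# the earlier layer doing the reading: not recurrent itself, but handed one external state container
print(to_out_in_features(False, False, 1))   # 1024: attention output concatenated with one state read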

@@ -649,49 +674,84 @@ def __init__(
max_seq_len = 1024,
block_width = 512,
recurrent_layers: Optional[Tuple[int, ...]] = None,
read_recurrent_layers: Optional[Tuple[int, ...]] = None,
num_state_vectors = None,
ignore_index = -100,
use_flash_attn = False
):
super().__init__()
num_state_vectors = default(num_state_vectors, block_width)

# set recurrent layers

recurrent_layers = default(recurrent_layers, (depth // 2,)) # default to one recurrent layer at the middle of the network
- self.recurrent_layers = set(recurrent_layers)

- assert all([0 < layer <= depth for layer in recurrent_layers])
+ assert all([0 < layer <= depth for layer in recurrent_layers]), f'recurrent layers must range from 1 to the depth {depth}'
+ assert all_unique(recurrent_layers), 'recurrent layers must be all unique. no duplicate layers'

+ self.recurrent_layers = recurrent_layers

# set read recurrent layers

read_recurrent_layers = default(read_recurrent_layers, recurrent_layers)

assert all([read_layer <= write_layer for read_layer, write_layer in zip(read_recurrent_layers, recurrent_layers)]), 'the recurrent read layer must be always less than or equal to the write layer'
assert all([0 < layer <= depth for layer in read_recurrent_layers])
assert len(read_recurrent_layers) == len(recurrent_layers)

self.read_recurrent_layers = read_recurrent_layers

# token embedding

self.token_emb = nn.Embedding(num_tokens, dim)

self.rotary_pos_emb = RotaryEmbedding(dim = dim_head)

self.layers = nn.ModuleList([])

self.write_to_read_map = {write_layer: read_layer for write_layer, read_layer in zip(recurrent_layers, read_recurrent_layers)}

self.read_state_router = defaultdict(list)

for layer in range(1, depth + 1):
is_recurrent_layer = layer in self.recurrent_layers

layer_num_state_vectors = num_state_vectors if is_recurrent_layer else 0

num_external_state_reads = sum([int(layer == read_layer) for read_layer in read_recurrent_layers])

# only layers with xl memories
# or has recurrence in horizontal direction
# use qk rmsnorm (in paper, they use cosine sim attention, but i think qk rmsnorm is more proven given the ViT-22B paper)
# one can also override to use all qk rmsnorm by setting all_layers_qk_rmsnorm = True

qk_rmsnorm = all_layers_qk_rmsnorm or is_recurrent_layer

attn_block = AttentionBlock(
dim,
block_width = block_width,
dim_head = dim_head,
heads = heads,
qk_rmsnorm = qk_rmsnorm,
num_state_vectors = layer_num_state_vectors,
use_flash_attn = use_flash_attn,
num_external_state_reads = num_external_state_reads,
state_read_before_write = False,
)

ff_block = FeedForward(dim, mult = ff_mult)

if is_recurrent_layer:
read_layer = self.write_to_read_map[layer]
self.read_state_router[read_layer].append(attn_block.state_container)

self.layers.append(nn.ModuleList([
- AttentionBlock(
- dim,
- block_width = block_width,
- dim_head = dim_head,
- heads = heads,
- qk_rmsnorm = qk_rmsnorm,
- num_state_vectors = layer_num_state_vectors,
- use_flash_attn = use_flash_attn
- ),
- FeedForward(dim, mult = ff_mult)
+ attn_block,
+ ff_block
]))

# to logits

self.to_logits = nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, num_tokens, bias = False)
@@ -835,27 +895,32 @@ def forward(
next_xl_memories = []
next_states = []

# set the states on the appropriate state containers

for attn, _ in self.layers:
if not attn.is_recurrent_layer:
continue

attn.state_container.set_next_read_state(next(states, None))

# go through layers

for ind, (attn, ff) in enumerate(self.layers):

# determine if the layer requires transformer xl memories

layer = ind + 1
- is_state_layer = attn.is_recurrent_layer

# whether to pass in xl memories

attn_kwargs = dict(
rotary_pos_emb = rotary_pos_emb,
xpos_scale = xpos_scale,
attn_mask = attn_mask,
- xl_memories = next(xl_memories, None)
+ xl_memories = next(xl_memories, None),
+ read_from_state_containers = self.read_state_router[layer]
)

- if is_state_layer:
- attn_kwargs.update(states = next(states, None))

# attention layer

residual = input_block
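Tying the constructor and forward changes together: each write layer is mapped to its read layer, the write layer's state container is registered in read_state_router under that read layer, and the forward pass hands every layer the containers it should read from. A standalone sketch of that routing pattern follows, with plain strings standing in for the real state container objects (the layer numbers here are illustrative).

from collections import defaultdict

depth = 6
recurrent_layers = (4, 6)       # layers that write state
read_recurrent_layers = (2, 6)  # where each of those states is read (read layer <= write layer)

write_to_read_map = {w: r for w, r in zip(recurrent_layers, read_recurrent_layers)}

read_state_router = defaultdict(list)
for write_layer, read_layer in write_to_read_map.items():
    # the real constructor appends the write layer's attn_block.state_container here
    read_state_router[read_layer].append(f'state container of layer {write_layer}')

for layer in range(1, depth + 1):
    print(layer, read_state_router[layer])
# layer 2 reads the state written by layer 4, while layer 6 still reads its own
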
setup.py: 2 changes (1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'block-recurrent-transformer-pytorch',
packages = find_packages(exclude=[]),
- version = '0.3.2',
+ version = '0.3.3',
license='MIT',
description = 'Block Recurrent Transformer - Pytorch',
author = 'Phil Wang',
