Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssm_enhancement #689

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion axlearn/common/quantized_dot_general/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import jax
from absl import logging
from aqt.jax.v2 import aqt_dot_general
from aqt.jax.v2 import utils as aqt_utils
from jax import numpy as jnp
from jax.lax import DotDimensionNumbers, Precision
from jax.typing import DTypeLike
Expand Down Expand Up @@ -79,7 +80,7 @@ def __call__(
dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None,
context: aqt_dot_general.Context = aqt_dot_general.Context(key=None, train_step=None),
context: aqt_utils.Context = aqt_utils.Context(key=None, train_step=None),
) -> Tensor:
...

Expand Down
71 changes: 47 additions & 24 deletions axlearn/common/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,30 +984,6 @@ def __init__(self, cfg: Config, *, parent: Module):
self._add_child("b_norm", cfg.b_norm.set(input_dim=cfg.state_dim))
self._add_child("c_norm", cfg.c_norm.set(input_dim=cfg.state_dim))

def _ssm_parameters(self, inputs: Tensor) -> MambaMixerLayer.SSMParameters:
"""Computes layer-normed versions of the input-dependent SSM parameters.

Args:
inputs: [batch_size, seq_len, inner_dim]

Returns:
An instance of MambaMixerLayer.SSMParameters.
"""
cfg = self.config
x_dbl = self.x_proj(inputs) # [batch_size, seq_len, dt_rank, state_dim*2]
dt, b, c = jnp.split(
x_dbl,
(
self.dt_rank,
self.dt_rank + cfg.state_dim,
),
axis=-1,
)
dt, b, c = self.dt_norm(dt), self.b_norm(b), self.c_norm(c)
delta = jax.nn.softplus(self.dt_proj(dt)) # [batch_size, seq_len, inner_dim]
a = -jnp.exp(_at_least_float32(self.parameters["log_a"])).astype(inputs.dtype)
return MambaMixerLayer.SSMParameters(a=a, b=b, c=c, delta=delta, d=self.parameters["d"])


class BaseSSMLayer(BaseLayer):
"""An abstract class representing SSM layers.
Expand Down Expand Up @@ -1445,3 +1421,50 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
for i in range(cfg.num_layers)
]
super().__init__(cfg.set(layer=layers), parent=parent)


class HybridMambaRecurrence(BaseMambaRecurrence):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for these new classes. Do people use either the hybrid recurrences or alternative recurrences defined below? Is there evidence that they are useful empirically? If not, I think it would be simpler to leave these classes out for now, and if necessary let people define them in downstream experiment files which import axlearn.common.ssm.

Copy link
Author

@vishesh9131 vishesh9131 Sep 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @swiseman , Thank you for your valuable input. I've reviewed the hybrid recurrences and alternative recurrences, and it seems that they haven't been used extensively in practice. Based on your benchmarking results, it appears that the AssociativeScanMambaRecurrence is more efficient than the HybridMambaRecurrence.

Given the lack of empirical evidence and the performance advantage of the AssociativeScanMambaRecurrence, I agree that it's reasonable to remove the HybridMambaRecurrence and other less-used recurrences from the core axlearn.common.ssm module for now.

This will simplify the codebase and make it easier for users to understand and use. If there's a strong need for these recurrences in the future, they can be defined in downstream experiment files as you suggested.

-Vishesh

"""A layer that combines different recurrence methods to leverage their strengths."""

@config_class
class Config(BaseMambaRecurrence.Config):
"""Configures a HybridMambaRecurrence."""

primary_recurrence: BaseMambaRecurrence = LinearScanMambaRecurrence.default_config()
secondary_recurrence: BaseMambaRecurrence = AssociativeScanMambaRecurrence.default_config()

def __init__(self, cfg: Config, *, parent: Module):
super().__init__(cfg, parent=parent)
self._add_child("primary_recurrence", cfg.primary_recurrence)
self._add_child("secondary_recurrence", cfg.secondary_recurrence)

def forward(
self, x: Tensor, *, a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor
) -> BaseMambaRecurrence.Output:
primary_output = self.primary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d)
secondary_output = self.secondary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d)
combined_data = (primary_output.data + secondary_output.data) / 2
combined_states = (
(primary_output.states + secondary_output.states) / 2
if primary_output.states is not None and secondary_output.states is not None
else None
)
return BaseMambaRecurrence.Output(data=combined_data, states=combined_states)


class AlternativeMambaRecurrence(BaseMambaRecurrence):
"""A layer that implements an alternative recurrence method."""

def forward(
self, x: Tensor, *, a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor
) -> BaseMambaRecurrence.Output:
# Implement an alternative recurrence method here.
# For demonstration, let's use a simple weighted sum of inputs and parameters.
weighted_sum = jnp.einsum("btd,sd->btsd", x, a) + jnp.einsum("bts,sd->btsd", b, c)
y = jnp.sum(weighted_sum, axis=-2) + d * x
states = (
weighted_sum
if self.config.output_mode == MambaRecurrenceOutputMode.OUTPUTS_AND_STATES
else None
)
return BaseMambaRecurrence.Output(data=y, states=states)
166 changes: 166 additions & 0 deletions axlearn/common/ssm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
from axlearn.common.config import InstantiableConfig
from axlearn.common.module import functional as F
from axlearn.common.ssm import (
AlternativeMambaRecurrence,
AssociativeScanMambaRecurrence,
BlockResidualMode,
HybridMambaRecurrence,
JambaMambaBlock,
LinearScanMambaRecurrence,
MambaBlock,
Expand Down Expand Up @@ -509,6 +511,64 @@ def test_prefill_states(self, dtype: jnp.dtype):

assert_allclose(decoder_output, forward_outputs.data, atol=1e-6)

@parameterized.parameters(jnp.float32, jnp.bfloat16)
vishesh9131 marked this conversation as resolved.
Show resolved Hide resolved
def test_hybrid_recurrence(self, dtype: jnp.dtype):
model_dim = 4
state_dim = 16
cfg = MambaMixerLayer.default_config().set(
input_dim=model_dim,
state_dim=state_dim,
cache_dtype=dtype,
dtype=dtype,
)
layer: MambaMixerLayer = cfg.set(name="test").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
layer_params = cast_floats(layer_params, to_dtype=dtype)
batch_size, tgt_len = 2, 6
query = jax.random.normal(
jax.random.PRNGKey(1),
[batch_size, tgt_len, model_dim],
dtype=dtype,
)
inputs = dict(query=query)
forward_outputs, _ = F(
layer,
state=layer_params,
is_training=False,
prng_key=jax.random.PRNGKey(2),
inputs=inputs,
)
assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim)

@parameterized.parameters(jnp.float32, jnp.bfloat16)
def test_alternative_recurrence(self, dtype: jnp.dtype):
model_dim = 4
state_dim = 16
cfg = MambaMixerLayer.default_config().set(
input_dim=model_dim,
state_dim=state_dim,
cache_dtype=dtype,
dtype=dtype,
)
layer: MambaMixerLayer = cfg.set(name="test").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
layer_params = cast_floats(layer_params, to_dtype=dtype)
batch_size, tgt_len = 2, 6
query = jax.random.normal(
jax.random.PRNGKey(1),
[batch_size, tgt_len, model_dim],
dtype=dtype,
)
inputs = dict(query=query)
forward_outputs, _ = F(
layer,
state=layer_params,
is_training=False,
prng_key=jax.random.PRNGKey(2),
inputs=inputs,
)
assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim)


def _test_extend_step(layer_cfg: InstantiableConfig, *, model_dim: int, dtype: jnp.dtype):
"""Tests extend for composite layers."""
Expand Down Expand Up @@ -834,6 +894,58 @@ def test_prefill(self, block_klass: MambaBlock, dtype: jnp.dtype):

_test_prefill_states(cfg, model_dim=model_dim, dtype=dtype)

@parameterized.product(
block_klass=(MambaBlock, JambaMambaBlock),
dtype=(jnp.float32, jnp.bfloat16),
)
def test_hybrid_recurrence_in_block(self, block_klass: MambaBlock, dtype: jnp.dtype):
model_dim = 16
state_dim = 16
hidden_dim = 32
num_layers = 3

cfg = StackedSSMLayer.default_config().set(
input_dim=model_dim,
num_layers=num_layers,
layer=block_klass.default_config().set(
state_dim=state_dim,
mamba_layer=MambaMixerLayer.default_config().set(
recurrence=HybridMambaRecurrence.default_config()
),
),
)
cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
if hasattr(cfg.layer, "feed_forward"):
cfg.layer.feed_forward.hidden_dim = hidden_dim

_test_extend_step(cfg, model_dim=model_dim, dtype=dtype)

@parameterized.product(
block_klass=(MambaBlock, JambaMambaBlock),
dtype=(jnp.float32, jnp.bfloat16),
)
def test_alternative_recurrence_in_block(self, block_klass: MambaBlock, dtype: jnp.dtype):
model_dim = 16
state_dim = 16
hidden_dim = 32
num_layers = 3

cfg = StackedSSMLayer.default_config().set(
input_dim=model_dim,
num_layers=num_layers,
layer=block_klass.default_config().set(
state_dim=state_dim,
mamba_layer=MambaMixerLayer.default_config().set(
recurrence=AlternativeMambaRecurrence.default_config()
),
),
)
cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
if hasattr(cfg.layer, "feed_forward"):
cfg.layer.feed_forward.hidden_dim = hidden_dim

_test_extend_step(cfg, model_dim=model_dim, dtype=dtype)


class StackedMixedSSMTransformerTest(TestCase):
"""Tests that mixing SSM layers and transformer layers behaves as expected."""
Expand Down Expand Up @@ -927,3 +1039,57 @@ def test_prefill(self, dtype: jnp.dtype):
cfg.layer.self_attention.attention.num_heads = num_heads
cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
_test_prefill_states(cfg, model_dim=model_dim, dtype=dtype)

@parameterized.parameters(jnp.float32, jnp.bfloat16)
def test_hybrid_recurrence_in_mixed_layer(self, dtype: jnp.dtype):
model_dim = 16
state_dim = 16
num_heads = 4
hidden_dim = 32
num_layers = 4
cfg = StackedMixedSSMTransformerLayer.default_config().set(
input_dim=model_dim,
num_layers=num_layers,
transformer_layer_period=3,
transformer_layer_offset=1,
ssm_layer=JambaMambaBlock.default_config().set(
state_dim=state_dim,
mamba_layer=MambaMixerLayer.default_config().set(
recurrence=HybridMambaRecurrence.default_config()
),
),
dtype=dtype,
)
cfg.ssm_layer.feed_forward.hidden_dim = hidden_dim
cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
cfg.layer.feed_forward.hidden_dim = hidden_dim
cfg.layer.self_attention.attention.num_heads = num_heads
cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
_test_extend_step(cfg, model_dim=model_dim, dtype=dtype)

@parameterized.parameters(jnp.float32, jnp.bfloat16)
def test_alternative_recurrence_in_mixed_layer(self, dtype: jnp.dtype):
model_dim = 16
state_dim = 16
num_heads = 4
hidden_dim = 32
num_layers = 4
cfg = StackedMixedSSMTransformerLayer.default_config().set(
input_dim=model_dim,
num_layers=num_layers,
transformer_layer_period=3,
transformer_layer_offset=1,
ssm_layer=JambaMambaBlock.default_config().set(
state_dim=state_dim,
mamba_layer=MambaMixerLayer.default_config().set(
recurrence=AlternativeMambaRecurrence.default_config()
),
),
dtype=dtype,
)
cfg.ssm_layer.feed_forward.hidden_dim = hidden_dim
cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
cfg.layer.feed_forward.hidden_dim = hidden_dim
cfg.layer.self_attention.attention.num_heads = num_heads
cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
_test_extend_step(cfg, model_dim=model_dim, dtype=dtype)