diff --git a/.github/container/Dockerfile.maxtext.amd64 b/.github/container/Dockerfile.maxtext.amd64
index 1d5e92677..8ec4403f1 100644
--- a/.github/container/Dockerfile.maxtext.amd64
+++ b/.github/container/Dockerfile.maxtext.amd64
@@ -14,6 +14,16 @@ ARG SRC_PATH_MAXTEXT
 RUN <<"EOF" bash -ex
 cat ${MANIFEST_FILE}
 get-source.sh -l maxtext -m ${MANIFEST_FILE} 
+EOF
+
+###############################################################################
+## Apply patch
+###############################################################################
+
+ADD maxtext-mha.patch /opt
+RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff
+
+RUN <<"EOF" bash -ex
 cat "${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
 EOF
 
diff --git a/.github/container/maxtext-mha.patch b/.github/container/maxtext-mha.patch
new file mode 100644
index 000000000..311a81fa3
--- /dev/null
+++ b/.github/container/maxtext-mha.patch
@@ -0,0 +1,130 @@
+diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
+index 05145c5..9ad5b2c 100644
+--- a/MaxText/layers/attentions.py
++++ b/MaxText/layers/attentions.py
+@@ -182,7 +182,7 @@ class AttentionOp(nn.Module):
+       if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
+         raise ValueError("""Decode not supported with flash attention.
+                            Use `dot_product` instead.""")
+-      return self.cudnn_flash_attention(query, key, value), None, None
++      return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None
+     else:
+       raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.')
+ 
+@@ -255,58 +255,39 @@ class AttentionOp(nn.Module):
+     x = jnp.transpose(x, axes=(0, 2, 1, 3))
+     return x
+ 
+-  def cudnn_flash_attention(self,
+-                            query: Array,
+-                            key: Array,
+-                            value: Array) -> Array:
++  def cudnn_flash_attention(
++    self,
++    query: Array,
++    key: Array,
++    value: Array,
++    decoder_segment_ids: Array | None,
++    model_mode: str = common_types.MODEL_MODE_TRAIN,
++    ) -> Array:
+     """CUDNN Flash Attention with Transformer Engine.
+-
+-    It is an unstable API. In future release, the API can get changed
+-    A stable flash attention API will be included soon. Currently,
+-    1. It does not support GQA, num_query_heads == num_kv_heads
+-    2. It supports head_dim till 128
+-    GQA support with head_dim=256 will be added soon 
++      1. Stable API, supports GQA
++      2. Supports head_dim up to 128; head_dim=256 support will be added soon
+     """
+-
+     # These imports are only meant to work in a GPU build.
+-    import transformer_engine.jax.fused_attn as fused_attn  # pytype: disable=import-error
+-    from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout  # pytype: disable=import-error
+-    from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available  # pytype: disable=import-error
+-
+-    batch, s_q, n_heads, head_dim = query.shape  # pylint: disable=unused-variable
+-    _, s_kv, _, _ = key.shape
+-
+-    qkv_layout = QKVLayout.BS3HD
+-    attn_mask_type = AttnMaskType.CAUSAL_MASK
+-    attn_bias_type = AttnBiasType.NO_BIAS
+-
+-    has_fused_attn_kernel = is_fused_attn_kernel_available(
+-       self.dtype, self.dtype, qkv_layout,
+-       attn_bias_type,
+-       attn_mask_type,
+-       self.dropout_rate, self.num_query_heads,
+-       self.num_kv_heads, s_q,
+-       s_kv, head_dim)
+-
+-    if not has_fused_attn_kernel:
+-      raise ValueError("Flash attention is not supported for current config i.e. head_dim, seq_len, n_heads etc."
+-          "Please see transformer_engine/common/fused_attn/fused_attn.cpp:NVTE_Fused_Attn_Backend for details")
+-
+-    q = jnp.reshape(query, (*query.shape[:2], 1, *query.shape[-2:]))
+-    k = jnp.reshape(key, (*query.shape[:2], 1, *query.shape[-2:]))
+-    v = jnp.reshape(value, (*query.shape[:2], 1, *query.shape[-2:]))
+-    qkv = jnp.concatenate((q, k, v), axis=2)  # to make it (b, s, 3, h, d)
+-
+-    return fused_attn.self_fused_attn(
+-        qkv=qkv,
+-        bias=None,
+-        mask=jnp.zeros((batch, 1, s_q, s_kv)),  # no padding
+-        seed=None,
+-        attn_bias_type=attn_bias_type,
+-        attn_mask_type=attn_mask_type,
+-        scaling_factor=1.0/math.sqrt(head_dim),
+-        dropout_probability=self.dropout_rate,
+-        is_training=True)
++    from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
++
++    _, _, _, head_dim = query.shape  # pylint: disable=unused-variable
++
++    # Generate the attention mask for the given segment IDs and model mode.
++    attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
++
++    dpa_layer = DotProductAttention(head_dim=head_dim,
++                  num_attention_heads=self.num_query_heads,
++                  num_gqa_groups=self.num_kv_heads,
++                  attn_mask_type='causal', # 'causal' or 'padding'
++                  attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
++                  attention_dropout=self.dropout_rate,
++                  dropout_rng_name='aqt',
++                  dtype=self.dtype,
++                  float32_logits=self.float32_logits,
++                  qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
++                  scale_factor=1.0/math.sqrt(head_dim),
++                  transpose_batch_sequence=False)
++    return dpa_layer(query, key, value, mask=attn_mask)
+ 
+   def compute_local_attention(self,
+                               attn_weights: Array, 
+diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
+index d15f764..cfb87a1 100644
+--- a/MaxText/layers/models.py
++++ b/MaxText/layers/models.py
+@@ -90,7 +90,7 @@ class DecoderLayer(nn.Module):
+       dropout_rate=cfg.dropout_rate,
+       name='self_attention',
+       quant=self.quant,
+-      quantize_kvcache=self.quantize_kvcache)
++      quantize_kvcache=cfg.quantize_kvcache)
+ 
+ 
+     attention_lnx = attention_layer(
+diff --git a/requirements.txt b/requirements.txt
+index cae6c73..4b7a214 100644
+--- a/requirements.txt
++++ b/requirements.txt
+@@ -17,8 +17,8 @@ pylint
+ pytest
+ pytype
+ sentencepiece==0.1.97
+-tensorflow-text>=2.13.0
+-tensorflow>=2.13.0
++tensorflow-text==2.13.0
++tensorflow==2.13.0
+ tensorflow-datasets
+ tensorboardx
+ tensorboard-plugin-profile
diff --git a/.github/workflows/_test_maxtext.yaml b/.github/workflows/_test_maxtext.yaml
index c71e334ff..4d2b3fc05 100644
--- a/.github/workflows/_test_maxtext.yaml
+++ b/.github/workflows/_test_maxtext.yaml
@@ -100,7 +100,7 @@ jobs:
           #SBATCH --time=00:10:00
           #SBATCH --output=${{ steps.meta.outputs.LOG_FILE }}
           #SBATCH --export="VOCAB_PATH=gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model,ENROOT_PASSWORD=${{ secrets.GITHUB_TOKEN }},NVCR_TOKEN=${{ secrets.NVCR_TOKEN }}"
-          
+
           # preload enroot container using one task per node
           time srun \
             --ntasks-per-node=1 \
diff --git a/.github/workflows/ngc-release-testing.yaml b/.github/workflows/ngc-release-testing.yaml
index 2cedab6c4..5f6d783c5 100644
--- a/.github/workflows/ngc-release-testing.yaml
+++ b/.github/workflows/ngc-release-testing.yaml
@@ -34,16 +34,22 @@ permissions:
   packages: read # to upload container
 
 jobs:
-
   test-jax:
     if: inputs.JAX_IMAGE != ''
     uses: ./.github/workflows/_test_unit.yaml
     with:
-      IMAGE: ${{ inputs.JAX_IMAGE }}
       TEST_NAME: jax
       EXECUTE: |
-        docker run --shm-size=1g --gpus all ${{ inputs.JAX_IMAGE }} test-jax.sh -b backend-independent | tee test-backend-independent.log
-        docker run --shm-size=1g --gpus all ${{ inputs.JAX_IMAGE }} test-jax.sh -b gpu | tee test-gpu.log
+        docker run -i --shm-size=1g --gpus all \
+        ${{ inputs.JAX_IMAGE }} \
+        bash <<"EOF" |& tee test-backend-independent.log
+          test-jax.sh -b backend-independent 
+        EOF
+        docker run -i --shm-size=1g --gpus all \
+        ${{ inputs.JAX_IMAGE }} \
+        bash <<"EOF" |& tee tee test-gpu.log
+          test-jax.sh -b gpu
+        EOF
       STATISTICS_SCRIPT: |
         errors=$(cat test-*.log | grep -c 'ERROR:' || true)
         failed_tests=$(cat test-*.log | grep -c 'FAILED in' || true)
@@ -76,10 +82,9 @@ jobs:
     if: inputs.LEVANTER_IMAGE  != ''
     uses: ./.github/workflows/_test_unit.yaml
     with:
-      IMAGE: ${{ inputs.LEVANTER_IMAGE }}
       TEST_NAME: levanter
       EXECUTE: |
-        docker run --gpus all --shm-size=1g ${{ needs.build-levanter.outputs.DOCKER_TAG_FINAL }} \
+        docker run --gpus all --shm-size=1g ${{ inputs.LEVANTER_IMAGE }} \
         bash -ec \
         "pip install pytest && PYTHONPATH=/opt/levanter/tests:$PYTHONPATH pytest /opt/levanter/tests" | tee test-levanter.log
       STATISTICS_SCRIPT: |
diff --git a/README.md b/README.md
index facf8ca62..e8d598626 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,9 @@
   <tbody>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=base%3D%7BCUDA%2CcuDNN%2CNCCL%2COFED%2CEFA%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=base%3D%7BCUDA%2CcuDNN%2CNCCL%2COFED%2CEFA%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:base</code>
@@ -35,7 +37,9 @@
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=core%3D%7Bbase%2CJAX%2CFlax%2CTE%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=core%3D%7Bbase%2CJAX%2CFlax%2CTE%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:jax</code>
@@ -45,20 +49,36 @@
         <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-jax-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-jax-build-arm64.json&logo=docker&label=arm64"></a>
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-jax-unit-test-V100.json&logo=nvidia&label=V100">
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-jax-unit-test-A100.json&logo=nvidia&label=A100">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-jax-unit-test-V100.json&logo=nvidia&label=V100">
+        </picture>
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-jax-unit-test-A100.json&logo=nvidia&label=A100">
+        </picture>
         <br>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-unit-test-V100.json&logo=nvidia&label=TE%20V100">
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-unit-test-A100.json&logo=nvidia&label=TE%20A100">
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-multigpu-test.json&logo=nvidia&label=TE%20Multi%20GPU">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-unit-test-V100.json&logo=nvidia&label=TE%20V100">
+        </picture>
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-unit-test-A100.json&logo=nvidia&label=TE%20A100">
+        </picture>
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-te-multigpu-test.json&logo=nvidia&label=TE%20Multi%20GPU">
+        </picture>
         <br>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-pallas-unit-test-V100.json&logo=nvidia&label=Pallas V100">
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-pallas-unit-test-A100.json&logo=nvidia&label=Pallas A100">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-pallas-unit-test-V100.json&logo=nvidia&label=Pallas V100">
+        </picture>
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-pallas-unit-test-A100.json&logo=nvidia&label=Pallas A100">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Levanter%3D%7Bcore%2CLevanter%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Levanter%3D%7Bcore%2CLevanter%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:levanter</code>
@@ -68,13 +88,19 @@
         <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-levanter-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-levanter-build-arm64.json&logo=docker&label=arm64"></a>
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-levanter-unit-test-V100.json&logo=nvidia&label=V100">
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-levanter-unit-test-A100.json&logo=nvidia&label=A100">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-levanter-unit-test-V100.json&logo=nvidia&label=V100">
+        </picture>
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-levanter-unit-test-A100.json&logo=nvidia&label=A100">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Equinox%3D%7Bcore%2CEquinox%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Equinox%3D%7Bcore%2CEquinox%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:equinox</code>
@@ -90,7 +116,9 @@
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Triton%3D%7Bcore%2CJAX-Triton%2CTriton%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Triton%3D%7Bcore%2CJAX-Triton%2CTriton%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:triton</code>
@@ -100,13 +128,19 @@
         <!-- <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-triton-build-arm64.json&logo=docker&label=arm64"> -->
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-triton-unit-test-V100.json&logo=nvidia&label=JAX-Triton V100">
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-triton-unit-test-A100.json&logo=nvidia&label=JAX-Triton A100">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-triton-unit-test-V100.json&logo=nvidia&label=JAX-Triton V100">
+        </picture>
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-triton-unit-test-A100.json&logo=nvidia&label=JAX-Triton A100">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Upstream%20T5X%3D%7Bcore%2CT5X%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Upstream%20T5X%3D%7Bcore%2CT5X%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:upstream-t5x</code>
@@ -116,12 +150,16 @@
         <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-upstream-t5x-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-t5x-build-arm64.json&logo=docker&label=arm64"></a>
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-t5x-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-t5x-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Rosetta%20T5X%3D%7Bcore%2CT5X%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Rosetta%20T5X%3D%7Bcore%2CT5X%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:t5x</code>
@@ -131,12 +169,16 @@
         <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-t5x-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-t5x-arm64.json&logo=docker&label=arm64"></a>
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-t5x-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-t5x-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Upstream%20PAX%3D%7Bcore%2Cpaxml%2Cpraxis%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Upstream%20PAX%3D%7Bcore%2Cpaxml%2Cpraxis%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:upstream-pax</code>
@@ -146,12 +188,16 @@
         <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-upstream-pax-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-pax-build-arm64.json&logo=docker&label=arm64"></a>
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-pax-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-upstream-pax-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Rosetta%20PAX%3D%7Bcore%2Cpaxml%2Cpraxis%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Rosetta%20PAX%3D%7Bcore%2Cpaxml%2Cpraxis%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:pax</code>
@@ -161,12 +207,16 @@
         <a href="https://gist.github.com/nvjax/913c2af68649fe568e9711c2dabb23ae/#file-final-pax-md"><img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-build-pax-arm64.json&logo=docker&label=arm64"></a>
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-pax-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-rosetta-pax-mgmn-test.json&logo=nvidia&label=A100%20distributed">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=MaxText%3D%7Bcore%2CMaxText%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=MaxText%3D%7Bcore%2CMaxText%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:maxtext</code>
@@ -176,12 +226,16 @@
         <!-- <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-maxtext-build-arm64.json&logo=docker&label=arm64"> -->
       </td>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-maxtext-test.json&logo=nvidia&label=A100%20distributed">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/endpoint?url=https%3A%2F%2Fgist.githubusercontent.com%2Fnvjax%2F913c2af68649fe568e9711c2dabb23ae%2Fraw%2Fbadge-maxtext-test.json&logo=nvidia&label=A100%20distributed">
+        </picture>
       </td>
     </tr>
     <tr>
       <td>
-        <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Grok%3D%7Bcore%2CGrok-1%7D">
+        <picture>
+          <img style="height:1em;" src="https://img.shields.io/static/v1?label=&color=gray&logo=docker&message=Grok%3D%7Bcore%2CGrok-1%7D">
+        </picture>
       </td>
       <td>
         <code>ghcr.io/nvidia/jax:grok</code>
diff --git a/rosetta/docs/GPU_performance.md b/rosetta/docs/GPU_performance.md
new file mode 100644
index 000000000..62b198595
--- /dev/null
+++ b/rosetta/docs/GPU_performance.md
@@ -0,0 +1,135 @@
+# Tips for High-Performance LLMs with JAX and XLA 
+
+This page documents the various flags in XLA and JAX to improve performance for LLMs on GPUs. The XLA flags are defined with their default values in [xla/debug_options_flags.cc](https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc)
+
+The flags can be set via the `XLA_FLAGS` environment variable, e.g. `XLA_FLAGS="--xla_flag1=true --xla_flag2=false"`, on the command line or in your script.
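+
+A minimal sketch of this mechanism (the entry point `train.py` is a placeholder; substitute whichever flags from this page apply to your model):
+
+```bash
+# Pass XLA flags for a single run via the XLA_FLAGS environment variable.
+XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true" \
+  python train.py
+```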
+
+
+## Flags to manage memory used in JAX/XLA
+
+- `XLA_PYTHON_CLIENT_MEM_FRACTION` is an XLA environment variable that sets the fraction of GPU memory allocated for JAX/XLA.
+  - Ideally it would be 1, but in practice it must be lower because NVIDIA libraries and the JAX framework also use some GPU memory.
+  - We typically set it to 0.9 or 0.8; at 0.9, XLA gets 90% of GPU memory. A minimal sketch follows this list.
+
+- The `xla_gpu_memory_limit_slop_factor` flag controls the memory budget XLA uses when deriving its default heuristics for scheduling and rematerialization. The default value is recommended.
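+
+A minimal sketch of the memory setting above (0.9 is just the typical value mentioned; tune it for your workload):
+
+```bash
+# Let XLA/JAX allocate 90% of GPU memory; lower this if other libraries need more headroom.
+export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
+```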
+
+
+## General CUDA/NCCL flags 
+
+### CUDA configuration
+
+The following environment variable restricts CUDA to a single work queue and is useful when a strict ordering of operations is required for best performance. It is recommended when using latency-hiding optimizations with asynchronous collectives.
+- CUDA_DEVICE_MAX_CONNECTIONS=1
+  
+### NCCL configuration 
+
+See [NCCL Environment Variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html) for more details.
+- NCCL_PROTO: SIMPLE,LL,LL128
+
+The following variable accelerates the all-reduce collective on NVLink4/H100. It requires additional GPU memory, so you may need to reduce `XLA_PYTHON_CLIENT_MEM_FRACTION` to avoid OOMs when it is enabled. A combined sketch of these CUDA/NCCL settings follows the list below.
+- NCCL_NVLS_ENABLE:1 
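+
+A combined sketch of the CUDA/NCCL environment described above (these are the values discussed on this page, not universal defaults):
+
+```bash
+# Single CUDA work queue for strict ordering with async collectives.
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+# NCCL protocol selection and NVLS-accelerated all-reduce on NVLink4/H100.
+export NCCL_PROTO=SIMPLE,LL,LL128
+export NCCL_NVLS_ENABLE=1  # may require lowering XLA_PYTHON_CLIENT_MEM_FRACTION
+```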
+
+
+## XLA flags to enable Latency Hiding Scheduler, and asynchronous collective communication
+
+To overlap communication with computation for models in JAX/XLA, we must enable the latency-hiding scheduler and asynchronous communication.
+
+To enable latency hiding optimizations with XLA, turn on the following flag: 
+
+- --xla_gpu_enable_latency_hiding_scheduler=true 
+
+To enable asynchronous communication for all collectives, the following flags are recommended (they are set by default in XLA):
+
+- --xla_gpu_enable_async_collectives=true
+- --xla_gpu_enable_highest_priority_async_stream=true
+
+For more fine-grained control over which collectives are asynchronous, use:
+
+- --xla_gpu_enable_async_all_reduce=<>
+- --xla_gpu_enable_async_all_gather=<>
+- --xla_gpu_enable_async_reduce_scatter=<> 
+- --xla_gpu_enable_async_collective_permute=<>
+
+
+### Flags to enable optimizations for FSDP communication 
+
+With FSDP in JAX/XLA, the following additional optimizations are available:
+
+- scan loop unrolling and loop double buffering 
+    - --xla_gpu_enable_while_loop_double_buffering=true
+      
+- optimized pipelining of all-gather and reduce-scatter for latency hiding in FSDP
+    - --xla_gpu_enable_pipelined_all_gather=true
+    - --xla_gpu_enable_pipelined_reduce_scatter=true
+    - --xla_gpu_enable_pipelined_all_reduce=true 
+    - --xla_gpu_enable_pipelined_collectives=false (if true, overrides the three flags above)
+      
+- combining tensors that are sharded along different dimensions. Within a transformer layer, tensors can be sharded row-wise or column-wise, and by default XLA generates separate collective calls for tensors sharded along different dimensions. The following flags combine tensors regardless of their sharding dimension and map them to a grouped NCCL call with a large cumulative size, which achieves higher communication efficiency.
+    - --xla_gpu_enable_all_gather_combine_by_dim=false
+    - --xla_gpu_enable_reduce_scatter_combine_by_dim=false
+      
+- Combine threshold values in XLA that determine when an all-gather (AG) or reduce-scatter (RS) is triggered. Set these values to at least the size of the weights (AG) or gradients (RS) in a single transformer layer, since large communication buffers achieve higher link bandwidth utilization. For example, LLAMA2-7B has 32 transformer layers with ~218M weights each; with BF16 weights and gradients, that is ~436MB per layer, so these thresholds should be at least 436MB.
+    - --xla_gpu_all_gather_combine_threshold_bytes=8589934592
+    - --xla_gpu_reduce_scatter_combine_threshold_bytes=8589934592
+      
+- Combine threshold values in XLA that determine when an all-reduce (AR) is triggered. This is typically used to overlap the AR of gradients with backward-pass compute. Set it as large as possible for communication efficiency, but small enough to retain maximum overlap. Depending on your system's interconnect, try several threshold values in factors of 2, from say 16MB up to the total gradient size. A combined sketch of these FSDP-related flags follows this list.
+    - --xla_gpu_all_reduce_combine_threshold_bytes=8589934592
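+
+Putting the scheduler, FSDP pipelining, and combine-threshold flags together, a sketch of an `XLA_FLAGS` setting for an FSDP run might look like the following (the 512MB thresholds assume the ~218M-weights-per-layer LLAMA2-7B example above; tune them, especially the all-reduce threshold, for your model and interconnect):
+
+```bash
+# Sketch: latency hiding + FSDP pipelining + per-layer-sized combine thresholds.
+# 218M BF16 weights per layer * 2 bytes ~= 436MB, rounded up here to 512MB (536870912 bytes).
+export XLA_FLAGS="
+  --xla_gpu_enable_latency_hiding_scheduler=true
+  --xla_gpu_enable_while_loop_double_buffering=true
+  --xla_gpu_enable_pipelined_all_gather=true
+  --xla_gpu_enable_pipelined_reduce_scatter=true
+  --xla_gpu_enable_pipelined_all_reduce=true
+  --xla_gpu_enable_all_gather_combine_by_dim=false
+  --xla_gpu_enable_reduce_scatter_combine_by_dim=false
+  --xla_gpu_all_gather_combine_threshold_bytes=536870912
+  --xla_gpu_reduce_scatter_combine_threshold_bytes=536870912
+  --xla_gpu_all_reduce_combine_threshold_bytes=536870912
+"
+```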
+
+
+### Flags to enable async collective permute 
+
+The following flags enable overlap of pipeline-parallel send/recv communication with computation.
+- --xla_gpu_enable_pipelined_p2p=true  (false by default)
+- --xla_gpu_collective_permute_decomposer_threshold=1024
+- --xla_gpu_lhs_enable_gpu_async_tracker=true
+
+### Flags to enable collective matmul
+
+The following flags enable overlap of tensor-parallel communication with GEMMs/matmul by splitting GEMMs into smaller chunks and triggering each chunk's collective right after that chunk's GEMM completes. The threshold determines the GEMM output-buffer size at which this optimization becomes active (0 enables collective matmul for all GEMM-collective patterns). A sketch follows the flags below.
+- --xla_gpu_multi_streamed_windowed_einsum=true
+- --xla_gpu_threshold_for_windowed_einsum_mib=0
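+
+A sketch of enabling collective matmul for all GEMM-collective patterns, appended to any flags already set:
+
+```bash
+# Threshold 0 turns on collective matmul for every GEMM-collective pattern.
+export XLA_FLAGS="${XLA_FLAGS:-} --xla_gpu_multi_streamed_windowed_einsum=true --xla_gpu_threshold_for_windowed_einsum_mib=0"
+```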
+
+### Profile Guided Latency Estimator (PGLE)
+
+The following flag enables use of PGLE with JAX/XLA. Please see [PGLE notes](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/PGLE.md) for more details.
+- --xla_gpu_pgle_profile_file_or_directory_path=filename
+
+## Other XLA Flags 
+
+### CUDA graphs
+
+The flag below enables CUDA graph support for JAX/XLA workloads and is enabled by default.
+- --xla_gpu_enable_command_buffer (Set to "" to disable)
+
+
+
+### Dynamic-Update Slice Fusion
+
+The following flag removes extra copies introduced by DUS (dynamic-update slice) when used in conjunction with custom NVIDIA kernels (like cuBLAS for GEMMs).
+- --xla_gpu_enable_custom_fusions=true
+
+### NCCL Optimizations
+
+Enables user buffers in NCCL for zero-copy collectives and send/recv. Requires NCCL_NVLS_ENABLE=1 for AG, AR, and RS.
+- --xla_gpu_enable_nccl_user_buffers=true
+
+Flags to reduce memory consumed by NCCL.
+- --xla_gpu_enable_nccl_comm_splitting=false  
+- --xla_gpu_enable_nccl_per_stream_comms=false [https://github.com/openxla/xla/pull/9845](https://github.com/openxla/xla/pull/9845)
+
+Fine-grained control to improve performance by initializing an NCCL communicator to use only max_nchannels (SMs). The default value of 0 uses NCCL's default number of SMs per collective.
+- --xla_gpu_nccl_collective_max_nchannels
+- --xla_gpu_nccl_p2p_max_nchannels
+
+### Debug flags 
+- --xla_dump_to=some/path
+- --xla_dump_latency_hiding_schedule=true
+
+### Miscellaneous flags 
+- --xla_gpu_cudnn_gemm_fusion=true (enables GEMM/bias fusion via cuDNN)
+- --xla_gpu_enable_cudnn_fmha=false (enables XLA pattern matcher to detect multi-headed attention pattern in JAX)
+- --xla_disable_hlo_passes=<> (turns off specific HLO passes; can be used for debugging)
+
+
+
+