
External test trigger #683

Closed · wants to merge 17 commits
10 changes: 10 additions & 0 deletions .github/container/Dockerfile.maxtext.amd64
@@ -14,6 +14,16 @@ ARG SRC_PATH_MAXTEXT
RUN <<"EOF" bash -ex
cat ${MANIFEST_FILE}
get-source.sh -l maxtext -m ${MANIFEST_FILE}
EOF

###############################################################################
## Apply patch
###############################################################################

ADD maxtext-mha.patch /opt
RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff

RUN <<"EOF" bash -ex
cat "${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
EOF

130 changes: 130 additions & 0 deletions .github/container/maxtext-mha.patch
@@ -0,0 +1,130 @@
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 05145c5..9ad5b2c 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -182,7 +182,7 @@ class AttentionOp(nn.Module):
if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
raise ValueError("""Decode not supported with flash attention.
Use `dot_product` instead.""")
- return self.cudnn_flash_attention(query, key, value), None, None
+ return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None
else:
raise ValueError(f'Unexpected attention kernel {self.attention_kernel=}.')

@@ -255,58 +255,39 @@ class AttentionOp(nn.Module):
x = jnp.transpose(x, axes=(0, 2, 1, 3))
return x

- def cudnn_flash_attention(self,
- query: Array,
- key: Array,
- value: Array) -> Array:
+ def cudnn_flash_attention(
+ self,
+ query: Array,
+ key: Array,
+ value: Array,
+ decoder_segment_ids: Array | None,
+ model_mode: str = common_types.MODEL_MODE_TRAIN,
+ ) -> Array:
"""CUDNN Flash Attention with Transformer Engine.
-
- It is an unstable API. In future release, the API can get changed
- A stable flash attention API will be included soon. Currently,
- 1. It does not support GQA, num_query_heads == num_kv_heads
- 2. It supports head_dim till 128
- GQA support with head_dim=256 will be added soon
+ 1. Stable API, supports GQA
+ 2. Supports head_dim till 128; head_dim=256 support will be added soon
"""
-
# These imports are only meant to work in a GPU build.
- import transformer_engine.jax.fused_attn as fused_attn # pytype: disable=import-error
- from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout # pytype: disable=import-error
- from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available # pytype: disable=import-error
-
- batch, s_q, n_heads, head_dim = query.shape # pylint: disable=unused-variable
- _, s_kv, _, _ = key.shape
-
- qkv_layout = QKVLayout.BS3HD
- attn_mask_type = AttnMaskType.CAUSAL_MASK
- attn_bias_type = AttnBiasType.NO_BIAS
-
- has_fused_attn_kernel = is_fused_attn_kernel_available(
- self.dtype, self.dtype, qkv_layout,
- attn_bias_type,
- attn_mask_type,
- self.dropout_rate, self.num_query_heads,
- self.num_kv_heads, s_q,
- s_kv, head_dim)
-
- if not has_fused_attn_kernel:
- raise ValueError("Flash attention is not supported for current config i.e. head_dim, seq_len, n_heads etc."
- "Please see transformer_engine/common/fused_attn/fused_attn.cpp:NVTE_Fused_Attn_Backend for details")
-
- q = jnp.reshape(query, (*query.shape[:2], 1, *query.shape[-2:]))
- k = jnp.reshape(key, (*query.shape[:2], 1, *query.shape[-2:]))
- v = jnp.reshape(value, (*query.shape[:2], 1, *query.shape[-2:]))
- qkv = jnp.concatenate((q, k, v), axis=2) # to make it (b, s, 3, h, d)
-
- return fused_attn.self_fused_attn(
- qkv=qkv,
- bias=None,
- mask=jnp.zeros((batch, 1, s_q, s_kv)), # no padding
- seed=None,
- attn_bias_type=attn_bias_type,
- attn_mask_type=attn_mask_type,
- scaling_factor=1.0/math.sqrt(head_dim),
- dropout_probability=self.dropout_rate,
- is_training=True)
+ from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
+
+ _, _, _, head_dim = query.shape # pylint: disable=unused-variable
+
+ #generate attn_mask
+ attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
+
+ dpa_layer = DotProductAttention(head_dim=head_dim,
+ num_attention_heads=self.num_query_heads,
+ num_gqa_groups=self.num_kv_heads,
+ attn_mask_type='causal', # 'causal' or 'padding'
+ attn_bias_type='NO_BIAS', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+ attention_dropout=self.dropout_rate,
+ dropout_rng_name='aqt',
+ dtype=self.dtype,
+ float32_logits=self.float32_logits,
+ qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+ scale_factor=1.0/math.sqrt(head_dim),
+ transpose_batch_sequence=False)
+ return dpa_layer(query, key, value, mask=attn_mask)

def compute_local_attention(self,
attn_weights: Array,
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index d15f764..cfb87a1 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -90,7 +90,7 @@ class DecoderLayer(nn.Module):
dropout_rate=cfg.dropout_rate,
name='self_attention',
quant=self.quant,
- quantize_kvcache=self.quantize_kvcache)
+ quantize_kvcache=cfg.quantize_kvcache)


attention_lnx = attention_layer(
diff --git a/requirements.txt b/requirements.txt
index cae6c73..4b7a214 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,8 +17,8 @@ pylint
pytest
pytype
sentencepiece==0.1.97
-tensorflow-text>=2.13.0
-tensorflow>=2.13.0
+tensorflow-text==2.13.0
+tensorflow==2.13.0
tensorflow-datasets
tensorboardx
tensorboard-plugin-profile
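
For reference, the attentions.py hunk above swaps Transformer Engine's low-level `fused_attn.self_fused_attn` call for the higher-level `DotProductAttention` flax module, which handles GQA (separate query and KV head counts) with a causal mask over BSHD-layout tensors. The sketch below is a minimal pure-JAX reference for that computation, not the Transformer Engine kernel itself; the function name, shapes, and test values are illustrative only.

```python
# Pure-JAX reference for causal, GQA-aware scaled dot-product attention
# over BSHD tensors, i.e. the math the patched cudnn_flash_attention now
# delegates to Transformer Engine's DotProductAttention. Illustrative only.
import math

import jax
import jax.numpy as jnp


def causal_gqa_attention(query, key, value):
    """query: [B, S, Hq, D]; key/value: [B, S, Hkv, D] with Hq % Hkv == 0."""
    b, s, hq, d = query.shape
    hkv = key.shape[2]
    groups = hq // hkv  # query heads sharing one KV head (GQA)

    # Broadcast each KV head across its group of query heads.
    key = jnp.repeat(key, groups, axis=2)      # -> [B, S, Hq, D]
    value = jnp.repeat(value, groups, axis=2)  # -> [B, S, Hq, D]

    # Scaled dot-product logits, matching scale_factor=1/sqrt(head_dim).
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) / math.sqrt(d)

    # Causal mask: position q may only attend to keys k <= q.
    causal = jnp.tril(jnp.ones((s, s), dtype=bool))
    logits = jnp.where(causal[None, None, :, :], logits, -1e10)

    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, value)  # [B, S, Hq, D]


if __name__ == "__main__":
    rng = jax.random.PRNGKey(0)
    q = jax.random.normal(rng, (2, 16, 8, 64))   # 8 query heads
    kv = jax.random.normal(rng, (2, 16, 2, 64))  # 2 KV heads
    print(causal_gqa_attention(q, kv, kv).shape)  # (2, 16, 8, 64)
```

Run on its own, the script prints `(2, 16, 8, 64)`: eight query heads attend causally while sharing two KV heads, which is the grouped-query behaviour the patch enables via `num_gqa_groups=self.num_kv_heads`.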
2 changes: 1 addition & 1 deletion .github/workflows/_test_maxtext.yaml
@@ -100,7 +100,7 @@ jobs:
#SBATCH --time=00:10:00
#SBATCH --output=${{ steps.meta.outputs.LOG_FILE }}
#SBATCH --export="VOCAB_PATH=gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model,ENROOT_PASSWORD=${{ secrets.GITHUB_TOKEN }},NVCR_TOKEN=${{ secrets.NVCR_TOKEN }}"

# preload enroot container using one task per node
time srun \
--ntasks-per-node=1 \
17 changes: 11 additions & 6 deletions .github/workflows/ngc-release-testing.yaml
@@ -34,16 +34,22 @@ permissions:
packages: read # to upload container

jobs:

test-jax:
if: inputs.JAX_IMAGE != ''
uses: ./.github/workflows/_test_unit.yaml
with:
IMAGE: ${{ inputs.JAX_IMAGE }}
TEST_NAME: jax
EXECUTE: |
docker run --shm-size=1g --gpus all ${{ inputs.JAX_IMAGE }} test-jax.sh -b backend-independent | tee test-backend-independent.log
docker run --shm-size=1g --gpus all ${{ inputs.JAX_IMAGE }} test-jax.sh -b gpu | tee test-gpu.log
docker run -i --shm-size=1g --gpus all \
${{ inputs.JAX_IMAGE }} \
bash <<"EOF" |& tee test-backend-independent.log
test-jax.sh -b backend-independent
EOF
docker run -i --shm-size=1g --gpus all \
${{ inputs.JAX_IMAGE }} \
bash <<"EOF" |& tee test-gpu.log
test-jax.sh -b gpu
EOF
STATISTICS_SCRIPT: |
errors=$(cat test-*.log | grep -c 'ERROR:' || true)
failed_tests=$(cat test-*.log | grep -c 'FAILED in' || true)
@@ -76,10 +82,9 @@ jobs:
if: inputs.LEVANTER_IMAGE != ''
uses: ./.github/workflows/_test_unit.yaml
with:
IMAGE: ${{ inputs.LEVANTER_IMAGE }}
TEST_NAME: levanter
EXECUTE: |
docker run --gpus all --shm-size=1g ${{ needs.build-levanter.outputs.DOCKER_TAG_FINAL }} \
docker run --gpus all --shm-size=1g ${{ inputs.LEVANTER_IMAGE }} \
bash -ec \
"pip install pytest && PYTHONPATH=/opt/levanter/tests:$PYTHONPATH pytest /opt/levanter/tests" | tee test-levanter.log
STATISTICS_SCRIPT: |