Skip to content

Commit

Permalink
adding single-node+multi-gpu transformer training pipeline configurat…
Browse files Browse the repository at this point in the history
…ion example and k8s test case
  • Loading branch information
SebastianScherer88 committed Nov 30, 2024
1 parent b3bdd1e commit 9db62a7
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 4 deletions.
1 change: 1 addition & 0 deletions sdk/bettmensch_ai/pipelines/pipeline/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .annotated_transformer import ( # noqa: F401
train_transformer_pipeline_1n_1p,
train_transformer_pipeline_1n_2p,
train_transformer_pipeline_2n_1p,
train_transformer_pipeline_2n_2p,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .annotated_transformer import ( # noqa: F401
train_transformer_pipeline_1n_1p,
train_transformer_pipeline_1n_2p,
train_transformer_pipeline_2n_1p,
train_transformer_pipeline_2n_2p,
)
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def train_transformer_pipeline(


train_transformer_pipeline_1n_1p = get_train_transformer_pipeline()
train_transformer_pipeline_1n_2p = get_train_transformer_pipeline(
name="test-train-pipeline-1n-2p-", n_nodes=1, n_proc_per_node=2
)
train_transformer_pipeline_2n_1p = get_train_transformer_pipeline(
name="test-train-pipeline-2n-1p-", n_nodes=2, n_proc_per_node=1
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from bettmensch_ai.pipelines.pipeline.examples import (
train_transformer_pipeline_1n_1p,
train_transformer_pipeline_1n_2p,
train_transformer_pipeline_2n_1p,
train_transformer_pipeline_2n_2p,
)
Expand Down Expand Up @@ -49,8 +50,50 @@ def test_train_transformer_pipeline_1n_1p_decorator_and_register_and_run(


@pytest.mark.train_transformer
@pytest.mark.multi_node_single_gpu
@pytest.mark.single_node_multi_gpu
@pytest.mark.order(12)
def test_train_transformer_pipeline_1n_2p_decorator_and_register_and_run(
test_output_dir, test_namespace
):
train_transformer_pipeline_1n_2p.export(test_output_dir)

assert not train_transformer_pipeline_1n_2p.registered
assert train_transformer_pipeline_1n_2p.registered_id is None
assert train_transformer_pipeline_1n_2p.registered_name is None
assert train_transformer_pipeline_1n_2p.registered_namespace is None

train_transformer_pipeline_1n_2p.register()

assert train_transformer_pipeline_1n_2p.registered
assert train_transformer_pipeline_1n_2p.registered_id is not None
assert train_transformer_pipeline_1n_2p.registered_name.startswith(
f"pipeline-{train_transformer_pipeline_1n_2p.name}-"
)
assert (
train_transformer_pipeline_1n_2p.registered_namespace == test_namespace
) # noqa: E501

train_transformer_flow = train_transformer_pipeline_1n_2p.run(
inputs={
"dataset": "multi30k",
"source_language": "de",
"target_language": "en",
"batch_size": 32,
"num_epochs": 1,
"accum_iter": 10,
"base_lr": 1.0,
"max_padding": 72,
"warmup": 3000,
},
wait=True,
)

assert train_transformer_flow.status.phase == "Succeeded"


@pytest.mark.train_transformer
@pytest.mark.multi_node_single_gpu
@pytest.mark.order(13)
def test_train_transformer_pipeline_2n_1p_decorator_and_register_and_run(
test_output_dir, test_namespace
):
Expand Down Expand Up @@ -92,7 +135,7 @@ def test_train_transformer_pipeline_2n_1p_decorator_and_register_and_run(

@pytest.mark.train_transformer
@pytest.mark.multi_node_multi_gpu
@pytest.mark.order(13)
@pytest.mark.order(14)
def test_train_transformer_pipeline_2n_2p_decorator_and_register_and_run(
test_output_dir, test_namespace
):
Expand Down
2 changes: 1 addition & 1 deletion sdk/test/k8s/pipelines/pipeline/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_get_standard_flow(test_namespace):
@pytest.mark.ddp
@pytest.mark.train_transformer
@pytest.mark.delete_flows
@pytest.mark.order(14)
@pytest.mark.order(15)
def test_delete(test_namespace):
"""Test the delete_flow function"""

Expand Down
2 changes: 1 addition & 1 deletion sdk/test/k8s/pipelines/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def test_run_dpp_registered_pipelines_from_registry(
@pytest.mark.ddp
@pytest.mark.train_transformer
@pytest.mark.delete_pipelines
@pytest.mark.order(15)
@pytest.mark.order(16)
def test_delete_registered_pipeline(test_namespace):
"""Test the delete_registered_pipeline function"""

Expand Down

0 comments on commit 9db62a7

Please sign in to comment.