Skip to content

Commit 28951de

Browse files
[python] add aot config for nxdi with vllm (deepjavalibrary#2691)
1 parent: c209230 · commit: 28951de

File tree

4 files changed

+45
-0
lines changed

4 files changed

+45
-0
lines changed

engines/python/setup/djl_python/transformers_neuronx.py

+9
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
OPTIMUM_CAUSALLM_MODEL_TYPES = {"gpt2", "opt", "bloom", "llama", "mistral"}
3737
OPTIMUM_CAUSALLM_CONTINUOUS_BATCHING_MODELS = {"llama", "mistral"}
3838
VLLM_CONTINUOUS_BATCHING_MODELS = {"llama"}
39+
NXDI_COMPILED_MODEL_FILE_NAME = "model.pt"
3940

4041

4142
class TransformersNeuronXService(object):
@@ -141,6 +142,14 @@ def set_model_loader_class(self) -> None:
141142
if self.config.model_loader == "nxdi":
142143
os.environ[
143144
'VLLM_NEURON_FRAMEWORK'] = "neuronx-distributed-inference"
145+
if self.config.save_mp_checkpoint_path:
146+
os.environ[
147+
"NEURON_COMPILED_ARTIFACTS"] = self.config.save_mp_checkpoint_path
148+
nxdi_compiled_model_path = os.path.join(
149+
self.config.model_id_or_path, NXDI_COMPILED_MODEL_FILE_NAME)
150+
if os.path.isfile(nxdi_compiled_model_path):
151+
os.environ[
152+
"NEURON_COMPILED_ARTIFACTS"] = self.config.model_id_or_path
144153
return
145154

146155
if self.config.model_loader == "vllm":

tests/integration/llm/client.py

+4
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,10 @@ def get_model_name():
172172
"llama-3-1-8b-instruct-vllm-nxdi": {
173173
"batch_size": [1, 2],
174174
"seq_length": [256],
175+
},
176+
"llama-3-2-1b-instruct-vllm-nxdi-aot": {
177+
"batch_size": [1],
178+
"seq_length": [128],
175179
}
176180
}
177181

tests/integration/llm/prepare.py

+16
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,22 @@
271271
"deterministic": False
272272
}
273273
}
274+
},
275+
"llama-3-2-1b-instruct-vllm-nxdi-aot": {
276+
"option.model_id": "s3://djl-llm/llama-3-2-1b-instruct/",
277+
"option.tensor_parallel_degree": 2,
278+
"option.rolling_batch": "vllm",
279+
"option.model_loading_timeout": 1200,
280+
"option.model_loader": "nxdi",
281+
"option.override_neuron_config": {
282+
"on_device_sampling_config": {
283+
"global_topk": 64,
284+
"dynamic": True,
285+
"deterministic": False
286+
}
287+
},
288+
"option.n_positions": 128,
289+
"option.max_rolling_batch_size": 1,
274290
}
275291
}
276292

tests/integration/tests.py

+16
Original file line number | Diff line number | Diff line change
@@ -900,6 +900,22 @@ def test_llama_vllm_nxdi(self):
900900
"transformers_neuronx_rolling_batch llama-3-1-8b-instruct-vllm-nxdi"
901901
)
902902

903+
def test_llama_vllm_nxdi_aot(self):
904+
with Runner('pytorch-inf2',
905+
'llama-3-2-1b-instruct-vllm-nxdi-aot') as r:
906+
prepare.build_transformers_neuronx_handler_model(
907+
"llama-3-2-1b-instruct-vllm-nxdi-aot")
908+
r.launch(
909+
container="pytorch-inf2-1",
910+
cmd=
911+
"partition --model-dir /opt/ml/input/data/training --save-mp-checkpoint-path /opt/ml/input/data/training/aot --skip-copy"
912+
)
913+
r.launch(container="pytorch-inf2-1",
914+
cmd="serve -m test=file:/opt/ml/model/test/aot")
915+
client.run(
916+
"transformers_neuronx_rolling_batch llama-3-2-1b-instruct-vllm-nxdi-aot"
917+
)
918+
903919

904920
@pytest.mark.correctness
905921
@pytest.mark.trtllm

0 commit comments

Comments
 (0)