diff --git a/tests/pytorch/nightly/llama2-model.libsonnet b/tests/pytorch/nightly/llama2-model.libsonnet index 1086cb80c..18d135d9b 100644 --- a/tests/pytorch/nightly/llama2-model.libsonnet +++ b/tests/pytorch/nightly/llama2-model.libsonnet @@ -45,23 +45,15 @@ local utils = import 'templates/utils.libsonnet'; }, command: self.paramsOverride.trainCommand, }, + local pjrt = self.pjrt, + pjrt:: common.PyTorchTpuVmMixin { + modelName: 'llama2-pjrt', + }, local infer = self.infer, - infer:: common.PyTorchTpuVmMixin { + infer:: common.PyTorchTpuVmMixin + pjrt { modelName+: '-infer', tpuSettings+: { tpuVmExtraSetup: ||| - pip3 uninstall torch torch_xla torchvision libtpu-nightly -y - sudo apt-get update -y - sudo apt-get install libomp5 -y - pip3 install mkl mkl-include - pip3 install tf-nightly tb-nightly tbp-nightly - pip3 install numpy - sudo apt-get install numactl -y - sudo apt-get install libopenblas-dev -y - pip3 install --user --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl - pip3 install torch_xla[tpuvm] - # install tokenizer model wget https://storage.googleapis.com/tpu-pytorch/lsiyuan-experiment/llama/spiece.model @@ -93,7 +85,7 @@ local utils = import 'templates/utils.libsonnet'; }, }, local spmd = self.spmd, - spmd:: common.PyTorchTpuVmMixin { + spmd:: common.PyTorchTpuVmMixin + pjrt { modelName+: '-train-spmd', tpuSettings+: { tpuVmExports+: ||| @@ -110,19 +102,6 @@ local utils = import 'templates/utils.libsonnet'; export TPU_MEGACORE=megacore_dense |||, tpuVmExtraSetup: ||| - pip3 uninstall torch torch_xla torchvision libtpu-nightly -y - sudo apt update -y - sudo apt-get update -y - pip install accelerate -U - sudo apt-get install libomp5 -y - pip3 install mkl mkl-include - pip3 install numpy - sudo apt-get install numactl -y - sudo apt-get install libopenblas-dev -y - pip3 install --user --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu - pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl - pip3 install torch_xla[tpuvm] - # install tokenizer model wget https://storage.googleapis.com/tpu-pytorch/lsiyuan-experiment/llama/spiece.model