diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml deleted file mode 100644 index 27385e7c5..000000000 --- a/.github/workflows/integration_tests.yml +++ /dev/null @@ -1,154 +0,0 @@ -name: Integration tests - -on: - workflow_dispatch: - inputs: - test: - description: the integration test to run - default: fairscale_benchmarks - required: true - type: choice - options: - - fairscale_benchmarks - cluster: - description: the beaker cluster to run the test on - default: ai2/tango-integration-tests - required: true - type: choice - options: - - ai2/tango-integration-tests - - ai2/allennlp-cirrascale - # Uncomment this trigger to test changes on a pull request. - # You also have to uncomment the lines below that mention 'for pull request checks' - # pull_request: - # branches: - # - '*' - -jobs: - run_test: - name: ${{ github.event.inputs.test }} - # name: fairscale_benchmarks # for pull request checks - runs-on: [ubuntu-latest] - timeout-minutes: 60 - env: - TEST_NAME: ${{ github.event.inputs.test }} - # TEST_NAME: fairscale_benchmarks # for pull request checks - BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} - BEAKER_WORKSPACE: ai2/tango-integration-tests - BEAKER_CLUSTER: ${{ github.event.inputs.cluster }} - # BEAKER_CLUSTER: ai2/allennlp-cirrascale # for pull request checks - IMAGE_NAME: petew/tango-testing - steps: - - uses: actions/checkout@v3 - - - name: Validate inputs - run: | - # The 'test' input should be a directory in `integration_tests/` - test -d "integration_tests/${TEST_NAME}" - - - name: Determine current commit SHA (pull request) - if: github.event_name == 'pull_request' - run: | - echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV - - - name: Determine current commit SHA (push) - if: github.event_name != 'pull_request' - run: | - echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV - - - name: Install beaker client - shell: bash - run: | - mkdir -p "$HOME/bin" - - # Download and install from latest GitHub release. - curl -s https://api.github.com/repos/allenai/beaker/releases/latest \ - | grep 'browser_download_url.*linux' \ - | cut -d '"' -f 4 \ - | wget -qi - \ - && tar -xvzf beaker_linux.tar.gz -C "$HOME/bin" - - # Add to path. - echo "$HOME/bin" >> "$GITHUB_PATH" - - - name: Verify beaker install - run: | - beaker account whoami - - - name: Create beaker experiment config - run: | - cat >beaker_config.yml << EOL - version: v2-alpha - description: ${{ env.TEST_NAME }} - tasks: - - name: test - image: - beaker: ${{ env.IMAGE_NAME }} - command: ["/entrypoint.sh", "integration_tests/${{ env.TEST_NAME }}/run.sh"] - envVars: - - name: COMMIT_SHA - value: $COMMIT_SHA - - name: WANDB_API_KEY - secret: WANDB_API_KEY - - name: FILE_FRIENDLY_LOGGING - value: "true" - - name: TOKENIZERS_PARALLELISM # set this to avoid warnings - value: "true" - - name: PYTHONUNBUFFERED - value: "true" - result: - path: '/results' - resources: - gpuCount: 4 - context: - cluster: ${{ env.BEAKER_CLUSTER }} - priority: normal - EOL - cat beaker_config.yml - - - name: Submit beaker job - run: | - TIMESTAMP=$(date +%H%M%S) - EXPERIMENT=$(beaker experiment create beaker_config.yml --workspace $BEAKER_WORKSPACE --name "${TEST_NAME}-${{ github.run_number }}-${TIMESTAMP}" | awk '{print $2}') - if [ -z "$EXPERIMENT" ]; then - exit 1 - else - echo "EXPERIMENT=$EXPERIMENT" >> $GITHUB_ENV - echo "Experiment $EXPERIMENT submitted. 
See progress at https://beaker.org/ex/$EXPERIMENT" - fi - - - name: Wait for job to finish - run: | - beaker experiment await $EXPERIMENT test finalized --timeout 60m - # Check the job's exit code. - test $(beaker experiment get $EXPERIMENT --format=json | jq '.[0].jobs[0].status.exitCode') -eq 0 - - - name: Get logs - if: always() - run: | - # EXPERIMENT could be empty if the submission step failed. - # We'll exit right away if that's the case. - if [ -z "$EXPERIMENT" ]; then - echo "No logs to show" - exit 0 - fi - - # Download logs from beaker. - beaker experiment results $EXPERIMENT --prefix out.log --output results - - # If the experiment failed during startup, there might not be any logs. - if [ -f results/test/out.log ]; then - echo "" - echo ">>> Logs:" - echo "" - cat results/test/out.log - else - echo "No logs to show" - fi - - - name: Stop job - if: cancelled() - run: | - if [ ! -z "$EXPERIMENT" ]; then - beaker experiment stop $EXPERIMENT - fi diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9174a0021..fe75389c9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,7 +20,7 @@ env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} BEAKER_WORKSPACE: ai2/tango-testing - BEAKER_DEFAULT_CLUSTER: ai2/tango-gpu-tests + BEAKER_DEFAULT_CLUSTER: ai2/canary BEAKER_IMAGE: petew/tango-testing GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -93,12 +93,6 @@ jobs: run: | pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers - - name: FairScale integration - extras: dev,fairscale - requires_torch: true - run: | - pytest -v --color=yes --doctest-modules tango/integrations/fairscale tests/integrations/fairscale - - name: W&B integration extras: dev,torch,flax,wandb requires_torch: true @@ -298,7 +292,7 @@ jobs: path: /unused token: ${{ secrets.BEAKER_TOKEN }} workspace: ${{ env.BEAKER_WORKSPACE }} - clusters: ai2/general-cirrascale,ai2/allennlp-cirrascale,ai2/aristo-cirrascale,ai2/mosaic-cirrascale,ai2/s2-cirrascale + clusters: ai2/general-cirrascale,ai2/allennlp-cirrascale,ai2/aristo-cirrascale,ai2/mosaic-cirrascale,ai2/s2-cirrascale,ai2/mosaic-cirrascale-a100,ai2/prior-cirrascale,ai2/general-cirrascale-a100-80g-ib release: name: Release diff --git a/CHANGELOG.md b/CHANGELOG.md index 917f69fb9..58e6f0ff8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixes a bug where `FromParams` would fail to parse when an object takes a `Step` argument directly. - Changed a name so we don't override the built-in name `set`. - Fixed a bug that would cause O(n^2) memory consumption in dense step graphs. +- Fixed how we find learning rate schedulers in Torch 2. ## [v1.2.0](https://github.com/allenai/tango/releases/tag/v1.2.0) - 2023-02-10 diff --git a/README.md b/README.md index 84ca7a569..80eb4395f 100644 --- a/README.md +++ b/README.md @@ -230,7 +230,7 @@ The motivation behind this library is that we can make research easier by compos You can run the `tango` command through [pdb](https://docs.python.org/3/library/pdb.html). For example: ```bash -python -m pdb -m tango run config.jsonnet +python -m pdb -m tango run fsdp_config.jsonnet ``` ### How is Tango different from [Metaflow](https://metaflow.org), [Airflow](https://airflow.apache.org), or [redun](https://github.com/insitro/redun)? 
diff --git a/docs/source/api/integrations/fairscale.rst b/docs/source/api/integrations/fairscale.rst
deleted file mode 100644
index d7890f909..000000000
--- a/docs/source/api/integrations/fairscale.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-🔥 FairScale
-============
-
-.. automodule:: tango.integrations.fairscale
-
-Reference
----------
-
-.. autoclass:: tango.integrations.fairscale.FairScaleTrainingEngine
-
-.. autoclass:: tango.integrations.fairscale.FSDPConfig
-   :members:
-
-.. autofunction:: tango.integrations.fairscale.with_wrapped_modules
diff --git a/docs/source/api/integrations/index.rst b/docs/source/api/integrations/index.rst
index 91ddf4fd9..ab5f40be3 100644
--- a/docs/source/api/integrations/index.rst
+++ b/docs/source/api/integrations/index.rst
@@ -8,7 +8,6 @@ Integrations
    :caption: Integrations

    torch
-   fairscale
    datasets
    transformers
    wandb
diff --git a/docs/source/api/integrations/torch.rst b/docs/source/api/integrations/torch.rst
index 5996bc793..2c64f5d86 100644
--- a/docs/source/api/integrations/torch.rst
+++ b/docs/source/api/integrations/torch.rst
@@ -32,6 +32,8 @@ Model
 .. autoclass:: tango.integrations.torch.Model
    :members:

+.. autofunction:: tango.integrations.torch.with_wrapped_modules
+
 TrainingEngine
 ~~~~~~~~~~~~~~

@@ -40,6 +42,11 @@ TrainingEngine

 .. autoclass:: tango.integrations.torch.TorchTrainingEngine

+.. autoclass:: tango.integrations.torch.FSDPTrainingEngine
+
+.. autoclass:: tango.integrations.torch.FSDPConfig
+   :members:
+
 Optim
 ~~~~~
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7002ebf81..a5c814b86 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -62,7 +62,6 @@
     "rich": ("https://rich.readthedocs.io/en/latest", None),
     "torch": ("https://pytorch.org/docs/stable", None),
     "flax": ("https://flax.readthedocs.io/en/latest", None),
-    "fairscale": ("https://fairscale.readthedocs.io/en/latest/", None),
     "datasets": ("https://huggingface.co/docs/datasets/master/en", None),
     "transformers": ("https://huggingface.co/docs/transformers/master/en", None),
     "beaker": ("https://beaker-py.readthedocs.io/en/latest/", None),
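For anyone migrating configs, the docs changes above boil down to a handful of registered names moving from the removed `fairscale` integration into the `torch` integration. A quick sketch of how the renamed components are looked up after this change (assuming `tango` is installed with the `torch` extra; `by_name` is the standard `Registrable` lookup):

```python
# Registrable names that used to live under the "fairscale" integration now
# resolve through the torch integration.
from tango.integrations.torch import Model, TrainingEngine

engine_cls = TrainingEngine.by_name("torch::fsdp")  # was: "fairscale"
model_ctor = Model.by_name("torch::with_wrapped_modules")  # was: "fairscale::with_wrapped_modules"
```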
diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet
index 485739742..71c385d9b 100644
--- a/examples/finetune/config.jsonnet
+++ b/examples/finetune/config.jsonnet
@@ -23,7 +23,7 @@ local batch_size = 2;
 local activation_checkpointing = false;  # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2)
 local amp = false;  # use PyTorch's native automatic mixed precision
-local fsdp = false;  # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2)
+local fsdp = false;  # Use Torch's FullyShardedDataParallel (probably need this for GPT-J 6B, but not gpt2)
 local cpu_offloading = false;  # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.

 ######################
@@ -38,14 +38,13 @@ assert fsdp == true || cpu_offloading == false : "cpu_offloading only available

 # FullyShardedDataParallel config:
 local fsdp_config = if fsdp then {
-  reshard_after_forward: true,
   move_params_to_cpu: cpu_offloading,
   move_grads_to_cpu: cpu_offloading,
   mixed_precision: amp,
 } else null;

 local training_engine = {
-  type: if fsdp then "fairscale" else "torch",
+  type: if fsdp then "torch::fsdp" else "torch",
   optimizer: {
     type: "torch::AdamW",
     lr: learning_rate,
@@ -95,13 +94,13 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device
     trained_model: {
       type: "transformers::finetune",
       model: {
-        type: "fairscale::with_wrapped_modules",
+        type: "torch::with_wrapped_modules",
         model: {
           type: "transformers::finetune::from_pretrained",
           pretrained_model_name_or_path: pretrained_model,
           low_cpu_mem_usage: load_with_low_cpu_mem_usage,
         },
-        modules_to_wrap: modules_to_wrap,  # tell FairScale to wrap the transformer's blocks individually
+        modules_to_wrap: modules_to_wrap,  # tell torch to wrap the transformer's blocks individually
         fsdp_config: fsdp_config,
         activation_checkpointing: activation_checkpointing,
       },
diff --git a/examples/train_lm/README.md b/examples/train_lm/README.md
index 59c347794..e35f4d002 100644
--- a/examples/train_lm/README.md
+++ b/examples/train_lm/README.md
@@ -6,7 +6,7 @@ This Tango example showcases how you could train or fine-tune a causal language
 or [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) from [transformers](https://github.com/huggingface/transformers) on WikiText2 or a similar dataset.
 It's best that you run this experiment on a machine with a GPU and PyTorch [properly installed](https://pytorch.org/get-started/locally/#start-locally), otherwise Tango will fall back to CPU-only and it will be extremely slow.
-This example also depends on [FairScale](https://fairscale.readthedocs.io/en/latest/), which allows you to leverage [`FullyShardedDataParallel`](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html) (FSDP) and [activation checkpointing](https://fairscale.readthedocs.io/en/latest/api/nn/checkpoint/checkpoint_activations.html) to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B) or a similar-sized model. Just set the constants `fsdp` and `activation_checkpointing` in the config to `true`. +This example also uses [`FullyShardedDataParallel`](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) (FSDP) and [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html) to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B) or a similar-sized model. Just set the constants `fsdp` and `activation_checkpointing` in the config to `true`. Without using CPU offloading you'll need at least 4 x 40GiB A100 GPUs, or a different configuration with a comparable amount of total GPU memory. diff --git a/examples/train_lm/config.jsonnet b/examples/train_lm/config.jsonnet index b99c8e098..e1e35bfeb 100644 --- a/examples/train_lm/config.jsonnet +++ b/examples/train_lm/config.jsonnet @@ -44,14 +44,13 @@ assert fsdp == true || cpu_offloading == false : "cpu_offloading only available # FullyShardedDataParallel config: local fsdp_config = if fsdp then { - reshard_after_forward: true, move_params_to_cpu: cpu_offloading, move_grads_to_cpu: cpu_offloading, mixed_precision: amp, } else null; local training_engine = { - type: if fsdp then "fairscale" else "torch", + type: if fsdp then "torch::fsdp" else "torch", optimizer: { type: "torch::AdamW", lr: learning_rate, @@ -100,13 +99,13 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device trained_model: { type: "torch::train", model: { - type: "fairscale::with_wrapped_modules", + type: "torch::with_wrapped_modules", model: { type: "transformers::AutoModelForCausalLM::from_pretrained", pretrained_model_name_or_path: pretrained_model, low_cpu_mem_usage: load_with_low_cpu_mem_usage, }, - modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually + modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell torch to wrap the transformer's blocks individually fsdp_config: fsdp_config, activation_checkpointing: activation_checkpointing, }, diff --git a/integration_tests/README.md b/integration_tests/README.md deleted file mode 100644 index 333170109..000000000 --- a/integration_tests/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Integration tests - -These are a collection of longer running end-to-end tests of various parts of the Tango library. - -The easiest way to run any of these integration tests is by triggering the [**Integration tests**](https://github.com/allenai/tango/actions/workflows/integration_tests.yml) -workflow on GitHub Actions. Just select the "Run workflow" dropdown, then pick the test to run and the Beaker cluster to run it on, -and finally hit the "Run workflow" button. - -Each test should have a `run.sh` file in its folder that will run the relevant tango command. -This is what the **Integration tests** workflow will call, and you can also use it to run the test manually. 
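To make the `torch::with_wrapped_modules` blocks in the two example configs above concrete, here is roughly what they resolve to in Python. This is a sketch, assuming the toy `simple_regression_model` from this repo's test fixtures is registered (e.g. via `include_package`), and the `modules_to_wrap` regex is illustrative; match it to your own model's submodule names. With `fsdp_config: null` only activation checkpointing is applied, so this runs without a distributed process group; passing an `fsdp_config` additionally wraps each matched submodule in `FullyShardedDataParallel`, which requires `torch.distributed` to be initialized:

```python
from tango.integrations.torch import Model

# Build the inner model first, then wrap every submodule whose name matches one
# of the regexes in `modules_to_wrap`.
model = Model.from_params(
    {
        "type": "torch::with_wrapped_modules",
        "model": {"type": "simple_regression_model"},
        "modules_to_wrap": ["blocks\\.[0-9]+"],  # regexes over submodule names
        "fsdp_config": None,
        "activation_checkpointing": True,
    }
)
```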
diff --git a/integration_tests/fairscale_benchmarks/README.md b/integration_tests/fairscale_benchmarks/README.md deleted file mode 100644 index d59d7caa8..000000000 --- a/integration_tests/fairscale_benchmarks/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# FairScale Benchmarks - -This integration test is for checking the performance of the `FairScaleTrainingEngine` with various configurations. - -**When to run it:** It should be ran every time there is a major PyTorch or FairScale upgrade. - -**Where to run it:** A server with 4 A100 GPUs. Make sure you set your `WANDB_API_KEY` environment variable. - -**How to run it:** From the root directory of this repository, run: -``` -integration_tests/fairscale_benchmarks/run.sh -``` - -By default, not all configurations are run. If you want to run change which configurations are run, open `config.jsonnet` -are search for "enabled". Then toggle this `enabled` field to `true` or `false` for each configuration. - -**What to look for:** The training jobs shouldn't fail, for one. After `tango run` completes, check the corresponding Weights & Biases -dashboard and inspect the results. Compare the various "fsdp" training runs with the baseline to ensure you see memory savings. diff --git a/integration_tests/fairscale_benchmarks/config.jsonnet b/integration_tests/fairscale_benchmarks/config.jsonnet deleted file mode 100644 index 109010fe8..000000000 --- a/integration_tests/fairscale_benchmarks/config.jsonnet +++ /dev/null @@ -1,296 +0,0 @@ -################## -# Model settings # -################## - -local pretrained_model = "gpt2"; -# local pretrained_model = "EleutherAI/gpt-j-6B"; -# This doesn't seem to work with gpt2, but works fine with gpt-j-6B. -local load_with_low_cpu_mem_usage = pretrained_model == "EleutherAI/gpt-j-6B"; - -#################### -# Trainer settings # -#################### - -# Trainer settings, adjust to your use-case. 
-local training_steps = 100; # total number of optimization steps to train for -local validate_every = 20; # how often to validate and save checkpoints - -local devices = 4; -local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) -# This is the batch size per GPU, ignoring gradient accumulation: -local batch_size = 8; -# So the effective batch size is `batch_size * grad_accum * devices` - -###################### -# Optimizer settings # -###################### - -local warmup_steps = 20; -local learning_rate = if pretrained_model == "EleutherAI/gpt-j-6B" then 0.00001 else 0.0001; - - -# <----- you probably don't need to edit below this line ----> # - - -local distributed_dataloader = { - batch_size: batch_size, - collate_fn: { type: "transformers::DefaultDataCollator" }, - sampler: { - type: "torch::DistributedSampler", - shuffle: true, - drop_last: true, - }, -}; - -local single_device_dataloader = { - shuffle: true, - batch_size: batch_size, - collate_fn: { type: "transformers::DefaultDataCollator" }, -}; - -local TrainStep(options) = - local training_engine = { - type: if options.fsdp_config != null then "fairscale" else "torch", - optimizer: { - type: "torch::AdamW", - lr: learning_rate, - betas: [0.9, 0.95], - eps: 1e-6, - }, - lr_scheduler: { - type: "transformers::linear", - num_warmup_steps: warmup_steps, - num_training_steps: training_steps, - }, - amp: options.amp, - [if options.fsdp_config != null then "fsdp_config" else null]: options.fsdp_config, - }; - - { - type: "torch::train", - model: { - type: "fairscale::with_wrapped_modules", - model: { - type: "transformers::AutoModelForCausalLM::from_pretrained", - pretrained_model_name_or_path: pretrained_model, - low_cpu_mem_usage: load_with_low_cpu_mem_usage, - }, - modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually - fsdp_config: options.fsdp_config, - activation_checkpointing: options.activation_checkpointing, - }, - dataset_dict: { type: "ref", ref: "tokenized_data" }, - train_dataloader: distributed_dataloader, - validation_split: "validation", - grad_accum: grad_accum, - train_steps: training_steps, - validate_every: validate_every, - checkpoint_every: validate_every, - log_every: 1, - device_count: devices, - training_engine: training_engine, - callbacks: [ - { - type: "wandb::log", - entity: "allennlp", - project: "tango-fairscale-benchmarks", - wandb_config: options + { - effective_batch_size: batch_size * devices * grad_accum, - model: pretrained_model, - }, - }, - ], - }; - -{ - steps: { - raw_data: { - type: "datasets::load", - path: "wikitext", - name: "wikitext-2-raw-v1", - }, - tokenized_data: { - type: "tokenize_data", - dataset: { type: "ref", ref: "raw_data" }, - tokenizer: { pretrained_model_name_or_path: pretrained_model } - }, - } + { - ["trained_model_" + options.name]: TrainStep(options) - for options in [ - # NOTE: With 6B model, baseline and many others will fail with CUDA OOM. - # FSDP and activation checkpointing will be required for a 6B model. 
- { - name: "baseline", - enabled: false, - amp: false, - fsdp_config: null, - activation_checkpointing: false, - }, - { - name: "amp", - enabled: false, - amp: true, - fsdp_config: null, - activation_checkpointing: false, - }, - { - name: "checkpointing", - enabled: false, - amp: false, - fsdp_config: null, - activation_checkpointing: true, - }, - { - name: "amp_and_checkpointing", - enabled: false, - amp: true, - fsdp_config: null, - activation_checkpointing: true, - }, - { - name: "fsdp", - enabled: false, - amp: false, - activation_checkpointing: false, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: false, - }, - }, - { - name: "fsdp_no_reshard", - enabled: false, - amp: false, - activation_checkpointing: false, - fsdp_config: { - reshard_after_forward: false, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: false, - }, - }, - { - name: "amp_and_fsdp", - enabled: false, - amp: true, - activation_checkpointing: false, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: false, - }, - }, - { - name: "amp_and_fsdp_no_reshard", - enabled: false, - amp: true, - activation_checkpointing: false, - fsdp_config: { - reshard_after_forward: false, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: false, - }, - }, - { - name: "amp_and_fsdp_mp", - enabled: false, - amp: true, - activation_checkpointing: false, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: true, - }, - }, - { - name: "amp_and_fsdp_mp_no_reshard", - enabled: false, - amp: true, - activation_checkpointing: false, - fsdp_config: { - reshard_after_forward: false, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: true, - }, - }, - { - name: "checkpointing_and_fsdp", - enabled: false, - amp: false, - activation_checkpointing: true, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: false, - }, - }, - { - name: "amp_and_checkpointing_and_fsdp", - enabled: false, - amp: true, - activation_checkpointing: true, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: false, - }, - }, - { - name: "amp_and_checkpointing_and_fsdp_mp", - enabled: true, - amp: true, - activation_checkpointing: true, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: true, - }, - }, - { - name: "checkpointing_and_fsdp_mp", - enabled: false, - amp: false, - activation_checkpointing: true, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: false, - move_grads_to_cpu: false, - mixed_precision: true, - }, - }, - { # This configuration currently does not work. 
Tracking https://github.com/facebookresearch/fairscale/issues/918 - name: "amp_and_checkpointing_and_fsdp_mp_with_partial_offloading", - enabled: false, - amp: true, - activation_checkpointing: true, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: true, - move_grads_to_cpu: false, - mixed_precision: true, - }, - }, - { - name: "amp_and_checkpointing_and_fsdp_mp_with_full_offloading", - enabled: false, - amp: true, - activation_checkpointing: true, - fsdp_config: { - reshard_after_forward: true, - move_params_to_cpu: true, - move_grads_to_cpu: true, - mixed_precision: true, - }, - }, - ] if options.enabled - } -} diff --git a/integration_tests/fairscale_benchmarks/run.sh b/integration_tests/fairscale_benchmarks/run.sh deleted file mode 100755 index 0b9ae1906..000000000 --- a/integration_tests/fairscale_benchmarks/run.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh - -tango run integration_tests/fairscale_benchmarks/config.jsonnet -i examples/train_lm/tokenize_step.py diff --git a/pyproject.toml b/pyproject.toml index 748795065..e545e6089 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,17 +76,12 @@ transformers = [ "numpy", "datasets>=1.12,<3.0", "transformers>=4.12.3", - "sentencepiece==0.1.98", + "sentencepiece==0.1.99", "sacremoses" ] datasets = [ "datasets>=1.12,<3.0" ] -fairscale = [ - "torch>=1.9,<2.1", - "numpy", - "fairscale>=0.4.6,<0.5" -] flax = [ "datasets>=1.12,<3.0", "jax>=0.3.13", @@ -106,7 +101,7 @@ gs = [ "google-cloud-datastore>=2.12.0" ] all = [ - "ai2-tango[examples,torch,transformers,datasets,fairscale,flax,wandb,beaker,gs]" + "ai2-tango[examples,torch,transformers,datasets,flax,wandb,beaker,gs]" ] [project.scripts] diff --git a/tango/integrations/fairscale/__init__.py b/tango/integrations/fairscale/__init__.py deleted file mode 100644 index dc14b2d40..000000000 --- a/tango/integrations/fairscale/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -.. important:: - To use this integration you should install ``tango`` with the "fairscale" extra - (e.g. ``pip install tango[fairscale]``) or just install FairScale after the fact. - - This integration also depends on `PyTorch `_, so make sure you - install the correct version of torch *first* given your operating system and supported - CUDA version. Check `pytorch.org/get-started/locally/ `_ - for more details. - -Components for Tango integration with `FairScale `_. - -Overview --------- - -FairScale is a PyTorch library for large scale training. Among other things, it implements -the main memory-savings techniques for distributed data-parallel training (DDP) that came from the paper -`ZeRO: Memory Optimization Towards Training A Trillion Parameter Models -`_. - -The main part of this Tango integration is the :class:`FairScaleTrainingEngine`. -This is a :class:`~tango.integrations.torch.TrainingEngine` implementation that utilizes -FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` (FSDP) for substantial memory savings -during distributed training. - -For the best performance you should also use :func:`with_wrapped_modules()` to wrap the inner modules -of your :class:`~tango.integrations.torch.Model`. When used with FSDP this will dramatically reduce -the memory required to load your model. 
- -""" - -from tango.common.exceptions import IntegrationMissingError - -try: - import fairscale -except ModuleNotFoundError: - raise IntegrationMissingError("fairscale") - -__all__ = [ - "FairScaleTrainingEngine", - "FSDPConfig", - "with_wrapped_modules", -] - -from .fsdp_config import FSDPConfig -from .module_wrapper import with_wrapped_modules -from .training_engine import FairScaleTrainingEngine diff --git a/tango/integrations/fairscale/fsdp_config.py b/tango/integrations/fairscale/fsdp_config.py deleted file mode 100644 index 50d47a3e4..000000000 --- a/tango/integrations/fairscale/fsdp_config.py +++ /dev/null @@ -1,80 +0,0 @@ -from dataclasses import asdict, dataclass -from typing import Any, Dict, Optional - -import torch -from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP - -from tango.common import FromParams - - -@dataclass -class FSDPConfig(FromParams): - """ - Defines all of the configurable options for FairScale's :class:`~fairscale.nn.FullyShardedDataParallel`. - - .. seealso:: - `Best practices for FullyShardedDataParallel `_ - from the FairScale docs. - - """ # noqa: E501 - - reshard_after_forward: bool = True - """ - See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. - """ - - move_params_to_cpu: bool = False - """ - See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. - """ - - move_grads_to_cpu: Optional[bool] = None - """ - See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. - - .. seealso:: - :data:`move_params_to_cpu` - - .. warning:: - At the moment we recommend that you don't mess with this parameter, or only explicitly - set it to the same value as :data:`move_params_to_cpu`. If you leave it as ``None`` - (the default), it will automatically be set to match :data:`move_params_to_cpu` by FairScale. - - Currently training seems to crash if you set this ``False`` while :data:`move_params_to_cpu` is ``True``. - We're tracking `fairscale#918 `_, - which may be related. - """ - - mixed_precision: bool = False - """ - See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. - - .. important:: - We recommend setting this to the same value as the ``amp`` parameter in - :class:`FairScaleTrainingEngine`. - - Based on our experiments, if you're training with AMP enabled (``amp=True``) - you might see a small additional speedup in training time along with a small - additional decrease in GPU memory utilization without any performance penalty - (with respect to convergence) by setting this to ``True``. - But if you're *not* training with AMP, setting this ``True`` could impact the - model's ability to converge. - - """ - - def as_kwargs(self) -> Dict[str, Any]: - """ - Convert to the appropriate ``kwargs`` for :class:`~fairscale.nn.FullyShardedDataParallel`. - """ - return asdict(self) - - def wrap(self, module: torch.nn.Module): - """ - A convenience method for wrapping a module in :class:`~fairscale.nn.FullyShardedDataParallel` - with all of the options defined in this class. - - .. seealso:: - Internally this is what :func:`with_wrapped_modules()` calls. 
- - """ - return FSDP(module, **self.as_kwargs()) diff --git a/tango/integrations/fairscale/training_engine.py b/tango/integrations/fairscale/training_engine.py deleted file mode 100644 index f88d16ac8..000000000 --- a/tango/integrations/fairscale/training_engine.py +++ /dev/null @@ -1,156 +0,0 @@ -import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import torch -from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP -from fairscale.optim.grad_scaler import ShardedGradScaler - -from tango.common import Lazy -from tango.common.exceptions import ConfigurationError -from tango.integrations.torch import ( - LRScheduler, - Model, - Optimizer, - TorchTrainingEngine, - TrainConfig, - TrainingEngine, -) - -from .fsdp_config import FSDPConfig - - -@TrainingEngine.register("fairscale") -class FairScaleTrainingEngine(TorchTrainingEngine): - """ - A :class:`~tango.integrations.torch.TrainingEngine` that leverages FairScale's - :class:`~fairscale.nn.FullyShardedDataParallel` for use within - :class:`~tango.integrations.torch.TorchTrainStep`. - - .. tip:: - Registered as an :class:`~tango.integrations.torch.TrainingEngine` under the name - "fairscale". - - .. tip:: - To get the best performance out of :class:`FairScaleTrainingEngine` you should - wrap individual layers of your model with :class:`~fairscale.nn.FullyShardedDataParallel` - and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper` - while instantiating them. You can use :class:`with_wrapped_modules()` to accomplish this. - - .. important:: - Only the parameters listed below should be defined in a configuration - file. The other parameters will be automatically passed to the constructor - within :class:`~tango.integrations.torch.TorchTrainStep`. - - .. warning:: - :class:`~FairScaleTrainingEngine` can only be used in distributed training, i.e. - when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`. - - For maximum memory savings, we recommend training with AMP enabled and the following - :class:`FSDPConfig`: - - .. testcode:: - - from tango.integrations.fairscale import FSDPConfig - - fsdp_config = FSDPConfig( - reshard_after_forward=True, - move_params_to_cpu=True, - move_grads_to_cpu=True, - mixed_precision=True, - ) - - For maximum training *speed*, we recommend training with AMP enabled and the following - :class:`FSDPConfig`: - - .. testcode:: - - from tango.integrations.fairscale import FSDPConfig - - fsdp_config = FSDPConfig( - reshard_after_forward=False, - move_params_to_cpu=False, - move_grads_to_cpu=False, - mixed_precision=True, - ) - - :param amp: - Use automatic mixed precision (AMP). Default is ``False``. - :param max_grad_norm: - If set, gradients will be clipped to have this max norm. Default is ``None``. - :param amp_use_bfloat16: - Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training. - Only applicable when ``amp=True``. If not specified, the default behavior will be - to use ``bfloat16`` when training with AMP on CPU, otherwise not. - :param fsdp_config: - The options for :class:`~fairscale.nn.FullyShardedDataParallel`. - If not specified, the default options will be used. 
- - """ - - def __init__( - self, - train_config: TrainConfig, - model: Lazy[Model], - optimizer: Lazy[Optimizer], - *, - lr_scheduler: Optional[Lazy[LRScheduler]] = None, - amp: bool = False, - max_grad_norm: Optional[float] = None, - amp_use_bfloat16: Optional[bool] = None, - fsdp_config: Optional[FSDPConfig] = None, - ) -> None: - if not train_config.is_distributed: - raise ConfigurationError( - f"{self.__class__.__name__} can only be used with distributed training" - ) - - self.fsdp_config = fsdp_config or FSDPConfig() - self.logger = logging.getLogger(self.__class__.__name__) - - super().__init__( - train_config, - model, - optimizer, - lr_scheduler=lr_scheduler, - amp=amp, - max_grad_norm=max_grad_norm, - amp_use_bfloat16=amp_use_bfloat16, - ) - if amp: - self.grad_scaler = ShardedGradScaler() - - def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model: - if isinstance(model, Lazy): - model = model.construct() - if not self.fsdp_config.move_params_to_cpu: - model.to(self.train_config.worker_local_default_device) - return FSDP(model, **self.fsdp_config.as_kwargs()) - - def clip_grad_norm(self) -> None: - if self.max_grad_norm is not None: - self.model.clip_grad_norm_(self.max_grad_norm) # type: ignore - - def get_model_state(self) -> Dict[str, torch.Tensor]: - return { - "weights": self.model.local_state_dict(), # type: ignore - "metadata": self.model.local_metadata_dict(), # type: ignore - } - - def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None: - self.model.load_local_state_dict(state_dict["weights"]) # type: ignore - - def save_complete_weights_from_checkpoint( - self, checkpoint_dir: Path, weights_path: Path - ) -> None: - self.logger.info("Consolidating sharded checkpoint weights...") - sharded_weights: List[Dict[str, torch.Tensor]] = [] - sharded_metadata: List[Dict[str, Any]] = [] - for path in checkpoint_dir.resolve().glob("worker*_model.pt"): - sharded_state = torch.load(path, map_location="cpu") - sharded_weights.append(sharded_state["weights"]) - sharded_metadata.append(sharded_state["metadata"]) - full_state = FSDP.consolidate_shard_weights(sharded_weights, sharded_metadata) - del sharded_weights - del sharded_metadata - torch.save(full_state, weights_path) diff --git a/tango/integrations/torch/__init__.py b/tango/integrations/torch/__init__.py index ae5ecd21c..65044e68d 100644 --- a/tango/integrations/torch/__init__.py +++ b/tango/integrations/torch/__init__.py @@ -166,6 +166,9 @@ def _return_zero(self): from .eval_callback import EvalCallback from .exceptions import StopEarly from .format import TorchFormat +from .fsdp_config import FSDPConfig +from .fsdp_module_wrapper import with_wrapped_modules +from .fsdp_training_engine import FSDPTrainingEngine from .model import Model from .optim import LRScheduler, Optimizer from .train import TorchTrainStep diff --git a/tango/integrations/torch/fsdp_config.py b/tango/integrations/torch/fsdp_config.py new file mode 100644 index 000000000..405295c91 --- /dev/null +++ b/tango/integrations/torch/fsdp_config.py @@ -0,0 +1,55 @@ +from dataclasses import asdict, dataclass +from typing import Any, Dict, Optional + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from tango.common import FromParams + + +@dataclass +class FSDPConfig(FromParams): + """ + Defines all of the configurable options for Torch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`. 
+ """ # noqa: E501 + + sharding_strategy: Optional[str] = None + + cpu_offload: Optional[str] = None + + backward_prefetch: Optional[str] = None + + sync_module_states: bool = False + + forward_prefetch: bool = False + + limit_all_gathers: bool = False + + use_orig_params: bool = False + + mixed_precision: Optional[str] = None + """ + See the docstring for :class:`~torch.distributed.fsdp.FullyShardedDataParallel`. + + .. important:: + We recommend setting this to the same value as the ``amp`` parameter in + :class:`FSDPTrainingEngine`. + + """ + + def as_kwargs(self) -> Dict[str, Any]: + """ + Convert to the appropriate ``kwargs`` for :class:`~torch.distributed.fsdp.FullyShardedDataParallel`. + """ + return asdict(self) + + def wrap(self, module: torch.nn.Module): + """ + A convenience method for wrapping a module in :class:`~torch.distributed.fsdp.FullyShardedDataParallel` + with all the options defined in this class. + + .. seealso:: + Internally this is what :func:`with_wrapped_modules()` calls. + + """ + return FSDP(module, **self.as_kwargs()) diff --git a/tango/integrations/fairscale/module_wrapper.py b/tango/integrations/torch/fsdp_module_wrapper.py similarity index 83% rename from tango/integrations/fairscale/module_wrapper.py rename to tango/integrations/torch/fsdp_module_wrapper.py index 6c580cbc5..164bcdcf3 100644 --- a/tango/integrations/fairscale/module_wrapper.py +++ b/tango/integrations/torch/fsdp_module_wrapper.py @@ -3,14 +3,15 @@ import torch import torch.nn as nn -from fairscale.nn.checkpoint import checkpoint_wrapper - -from tango.integrations.torch import Model +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, +) from .fsdp_config import FSDPConfig +from .model import Model -@Model.register("fairscale::with_wrapped_modules") # type: ignore[arg-type] +@Model.register("torch::with_wrapped_modules") # type: ignore[arg-type] def with_wrapped_modules( model: Model, modules_to_wrap: Set[str], @@ -19,15 +20,15 @@ def with_wrapped_modules( ) -> Model: """ A :class:`~tango.integrations.torch.Model` wrapper that can be used to easily wrap - inner modules of a model with FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` wrapper - and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`. + inner modules of a model with Torch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper + and/or :class:`~torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper`. .. tip:: Registered as a :class:`~tango.integrations.torch.Model` constructor under the name - "fairscale::with_wrapped_modules". + "torch::with_wrapped_modules". .. important:: - This is meant to be used with the :class:`FairScaleTrainingEngine`. + This is meant to be used with the :class:`FSDPTrainingEngine`. :param model: The model to wrap. @@ -37,8 +38,8 @@ def with_wrapped_modules( The ``FullyShardedDataParallel`` configuration to use when wrapping the modules. If not specified, the modules will NOT be wrapped with FSDP. :param activation_checkpointing: - Whether to wrap the modules with FairScale's - :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`. + Whether to wrap the modules with Torch's + :class:`~torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper`. 
Examples -------- @@ -77,7 +78,7 @@ def forward(self, x, y): model = Model.from_params({ - "type": "fairscale::with_wrapped_modules", + "type": "torch::with_wrapped_modules", "model": { "type": "simple_regression_model", }, diff --git a/tango/integrations/torch/fsdp_training_engine.py b/tango/integrations/torch/fsdp_training_engine.py new file mode 100644 index 000000000..9f8509acc --- /dev/null +++ b/tango/integrations/torch/fsdp_training_engine.py @@ -0,0 +1,213 @@ +import logging +import os +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp.api import FullOptimStateDictConfig +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + +from tango.common import Lazy, Tqdm +from tango.common.exceptions import ConfigurationError + +from .fsdp_config import FSDPConfig +from .model import Model +from .optim import LRScheduler, Optimizer +from .train_config import TrainConfig +from .training_engine import TorchTrainingEngine, TrainingEngine + + +@TrainingEngine.register("torch::fsdp") +class FSDPTrainingEngine(TorchTrainingEngine): + """ + A :class:`~tango.integrations.torch.TrainingEngine` that leverages Torch's + :class:`~torch.distributed.fsdp.FullyShardedDataParallel` for use within + :class:`~tango.integrations.torch.TorchTrainStep`. + + .. tip:: + Registered as an :class:`~tango.integrations.torch.TrainingEngine` under the name + "torch::fsdp". + + .. tip:: + To get the best performance out of :class:`FSDPTrainingEngine` you should + wrap individual layers of your model with :class:`~torch.distributed.fsdp.FullyShardedDataParallel` + and/or :class:`~torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper` + while instantiating them. You can use :class:`with_wrapped_modules()` to accomplish this. + + .. important:: + Only the parameters listed below should be defined in a configuration + file. The other parameters will be automatically passed to the constructor + within :class:`~tango.integrations.torch.TorchTrainStep`. + + .. warning:: + :class:`~FSDPTrainingEngine` can only be used in distributed training, i.e. + when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`. + + For maximum memory savings, we recommend training with AMP enabled and the following + :class:`FSDPConfig`: + + .. testcode:: + + from tango.integrations.torch import FSDPConfig + + fsdp_config = FSDPConfig( + move_params_to_cpu=True, + move_grads_to_cpu=True, + mixed_precision=True, + ) + + For maximum training *speed*, we recommend training with AMP enabled and the following + :class:`FSDPConfig`: + + .. testcode:: + + from tango.integrations.torch import FSDPConfig + + fsdp_config = FSDPConfig( + move_params_to_cpu=False, + move_grads_to_cpu=False, + mixed_precision=True, + ) + + :param amp: + Use automatic mixed precision (AMP). Default is ``False``. + :param max_grad_norm: + If set, gradients will be clipped to have this max norm. Default is ``None``. + :param amp_use_bfloat16: + Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training. + Only applicable when ``amp=True``. If not specified, the default behavior will be + to use ``bfloat16`` when training with AMP on CPU, otherwise not. 
+ :param fsdp_config: + The options for :class:`~torch.distributed.fsdp.FullyShardedDataParallel`. + If not specified, the default options will be used. + + """ + + def __init__( + self, + train_config: TrainConfig, + model: Lazy[Model], + optimizer: Lazy[Optimizer], + *, + lr_scheduler: Optional[Lazy[LRScheduler]] = None, + amp: bool = False, + max_grad_norm: Optional[float] = None, + amp_use_bfloat16: Optional[bool] = None, + fsdp_config: Optional[FSDPConfig] = None, + ) -> None: + if not train_config.is_distributed: + raise ConfigurationError( + f"{self.__class__.__name__} can only be used with distributed training" + ) + + self.fsdp_config = fsdp_config or FSDPConfig() + self.logger = logging.getLogger(self.__class__.__name__) + + super().__init__( + train_config, + model, + optimizer, + lr_scheduler=lr_scheduler, + amp=amp, + max_grad_norm=max_grad_norm, + amp_use_bfloat16=amp_use_bfloat16, + ) + if amp: + self.grad_scaler = ShardedGradScaler() + + def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model: + if isinstance(model, Lazy): + model = model.construct() + if not self.fsdp_config.move_params_to_cpu: + model.to(self.train_config.worker_local_default_device) + return FSDP(model, **self.fsdp_config.as_kwargs()) # type: ignore + + def clip_grad_norm(self) -> None: + if self.max_grad_norm is not None: + self.model.clip_grad_norm_(self.max_grad_norm) # type: ignore + + def get_model_state(self) -> Dict[str, torch.Tensor]: + return self.model.state_dict() + + def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None: + self.model.load_state_dict(state_dict) # type: ignore + + def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None: + checkpoint_dir.mkdir(exist_ok=True) + + def save_state(state: Dict[str, Any], name: str): + # only rank 0 writes any files + if self.train_config.worker_id != 0: + return + + temp_state_file = tempfile.NamedTemporaryFile( + "w+b", dir=checkpoint_dir, delete=False, suffix=".pt" + ) + try: + with Tqdm.wrapattr( + temp_state_file, + "write", + desc=f"Saving {name} state", + leave=False, + disable=not self.train_config.is_local_main_process, + ) as f: + torch.save(state, f) + temp_state_file.close() + os.replace( + temp_state_file.name, + checkpoint_dir / f"worker0_{name}.pt", + ) + finally: + if os.path.exists(temp_state_file.name): + os.remove(temp_state_file.name) + + with FSDP.state_dict_type( + self.model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=True, offload_to_cpu=True), + FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True), + ): + save_state(self.get_model_state(), "model") + save_state(FSDP.optim_state_dict(self.model, self.optimizer), "optimizer") + if self.lr_scheduler is not None: + save_state(self.lr_scheduler.state_dict(), "lr_scheduler") + if self.grad_scaler is not None: + save_state(self.grad_scaler.state_dict(), "grad_scaler") + save_state(client_state, "trainer") + + def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]: + with FSDP.state_dict_type( + self.model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=True, offload_to_cpu=True), + FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True), + ): + if self.train_config.worker_id == 0: + model_state_dict = torch.load(checkpoint_dir / "worker0_model.pt") + optimizer_state_dict = torch.load(checkpoint_dir / "worker0_optimizer.pt") + else: + model_state_dict = {} + optimizer_state_dict = {} + + self.load_model_state(model_state_dict) + optimizer_state_dict = 
FSDP.optim_state_dict_to_load(
+                optimizer_state_dict, self.model, self.optimizer
+            )
+            self.optimizer.load_state_dict(optimizer_state_dict)
+
+        # The states for LR scheduler, grad scaler, and trainer are identical on all workers, so we have
+        # all of them load the same files.
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.load_state_dict(
+                torch.load(checkpoint_dir / "worker0_lr_scheduler.pt")
+            )
+        if self.grad_scaler is not None:
+            self.grad_scaler.load_state_dict(
+                torch.load(checkpoint_dir / "worker0_grad_scaler.pt")
+            )
+
+        return torch.load(checkpoint_dir / "worker0_trainer.pt")
diff --git a/tango/integrations/torch/optim.py b/tango/integrations/torch/optim.py
index 54cc2c6b0..56ac7c133 100644
--- a/tango/integrations/torch/optim.py
+++ b/tango/integrations/torch/optim.py
@@ -73,11 +73,29 @@ class LRScheduler(torch.optim.lr_scheduler._LRScheduler, Registrable):
 ):
     Optimizer.register("torch::" + name)(cls)

+if torch.__version__.startswith("1."):
+
+    def isLrSchedulerClass(cls):
+        return (
+            isinstance(cls, type)
+            and issubclass(cls, torch.optim.lr_scheduler._LRScheduler)
+            and not cls == torch.optim.lr_scheduler._LRScheduler
+        )
+
+elif torch.__version__.startswith("2."):
+
+    def isLrSchedulerClass(cls):
+        return (
+            isinstance(cls, type)
+            and issubclass(cls, torch.optim.lr_scheduler.LRScheduler)
+            and not cls == torch.optim.lr_scheduler.LRScheduler
+            and not cls == torch.optim.lr_scheduler._LRScheduler
+        )
+
+else:
+    raise Exception("Unknown version of PyTorch")
+
 # Register all learning rate schedulers.
 for name, cls in torch.optim.lr_scheduler.__dict__.items():
-    if (
-        isinstance(cls, type)
-        and issubclass(cls, torch.optim.lr_scheduler._LRScheduler)
-        and not cls == torch.optim.lr_scheduler._LRScheduler
-    ):
+    if isLrSchedulerClass(cls):
         LRScheduler.register("torch::" + name)(cls)
diff --git a/tango/integrations/transformers/__init__.py b/tango/integrations/transformers/__init__.py
index 6d9526717..449a25ce7 100644
--- a/tango/integrations/transformers/__init__.py
+++ b/tango/integrations/transformers/__init__.py
@@ -72,7 +72,9 @@
     transformers::AutoModelForImageSegmentation::from_pretrained
     transformers::AutoModelForInstanceSegmentation::from_config
     transformers::AutoModelForInstanceSegmentation::from_pretrained
+    transformers::AutoModelForMaskGeneration::from_config
+    transformers::AutoModelForMaskGeneration::from_pretrained
     transformers::AutoModelForMaskedImageModeling::from_config
     transformers::AutoModelForMaskedImageModeling::from_pretrained
     transformers::AutoModelForMaskedLM::from_config
     transformers::AutoModelForMaskedLM::from_pretrained
diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py
index 2afe4e667..90bd2d75c 100644
--- a/tango/integrations/transformers/finetune.py
+++ b/tango/integrations/transformers/finetune.py
@@ -424,7 +424,7 @@ def run(  # type: ignore[override]
         # Hacky way to deal with resizing the model embeddings.
         model_params_dict = model._params.as_dict()
-        if "fairscale" in model_params_dict["type"]:
+        if "with_wrapped_modules" in model_params_dict["type"]:
             model_params_dict["model"]["num_tokens"] = len(tokenizer)  # type: ignore
         else:
             model_params_dict["num_tokens"] = len(tokenizer)  # type: ignore
diff --git a/test_fixtures/integrations/fairscale/__init__.py b/test_fixtures/integrations/fairscale/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/test_fixtures/integrations/fairscale/components.py b/test_fixtures/integrations/torch/components.py
similarity index 100%
rename from test_fixtures/integrations/fairscale/components.py
rename to test_fixtures/integrations/torch/components.py
diff --git a/test_fixtures/integrations/fairscale/config.jsonnet b/test_fixtures/integrations/torch/fsdp_config.jsonnet
similarity index 94%
rename from test_fixtures/integrations/fairscale/config.jsonnet
rename to test_fixtures/integrations/torch/fsdp_config.jsonnet
index 468716b09..0e51e1544 100644
--- a/test_fixtures/integrations/fairscale/config.jsonnet
+++ b/test_fixtures/integrations/torch/fsdp_config.jsonnet
@@ -25,14 +25,13 @@ local learning_rate = 0.005;

 local fsdp_config = {
-  reshard_after_forward: true,
   move_params_to_cpu: cpu_offloading,
   move_grads_to_cpu: cpu_offloading,
   mixed_precision: amp,
 };

 local training_engine = {
-  type: "fairscale",
+  type: "torch::fsdp",
   optimizer: {
     type: "torch::AdamW",
     lr: learning_rate,
@@ -60,7 +59,7 @@ local dataloader = {
     trained_model: {
       type: "torch::train",
       model: {
-        type: "fairscale::with_wrapped_modules",
+        type: "torch::with_wrapped_modules",
         model: {
           type: "simple_regression_model",
         },
diff --git a/tests/integrations/fairscale/__init__.py b/tests/integrations/fairscale/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/integrations/fairscale/train_test.py b/tests/integrations/torch/fsdp_train_test.py
similarity index 64%
rename from tests/integrations/fairscale/train_test.py
rename to tests/integrations/torch/fsdp_train_test.py
index 522bbcec2..ca36347f9 100644
--- a/tests/integrations/fairscale/train_test.py
+++ b/tests/integrations/torch/fsdp_train_test.py
@@ -7,7 +7,7 @@
 from tango.common.testing import TangoTestCase


-class TestFairScaleTrain(TangoTestCase):
+class TestFSDPTrain(TangoTestCase):
     def setup_method(self):
         super().setup_method()
         initialize_logging(log_level="info")
@@ -16,11 +16,12 @@ def teardown_method(self):
         teardown_logging()

     @pytest.mark.parametrize(
-        "fsdp",
-        (
+        "fsdp,activation_checkpoint",
+        [
             pytest.param(
                 True,
-                id="fsdp=True",
+                False,
+                id="fsdp=True-checkpointing=False",
                 marks=[
                     pytest.mark.gpu,
                     pytest.mark.skipif(
                         torch.cuda.device_count() < 2, reason="Requires CUDA devices"
                     ),
                 ],
             ),
-            pytest.param(False, id="fsdp=False"),
-        ),
-    )
-    @pytest.mark.parametrize(
-        "activation_checkpoint",
-        (
-            pytest.param(True, id="checkpointing=True"),
-            pytest.param(False, id="checkpointing=False"),
-        ),
+            pytest.param(
+                True,
+                True,
+                id="fsdp=True-checkpointing=True",
+                marks=[
+                    pytest.mark.gpu,
+                    pytest.mark.skipif(
+                        torch.cuda.device_count() < 2, reason="Requires CUDA devices"
+                    ),
+                ],
+            ),
+            pytest.param(False, False, id="fsdp=False-checkpointing=False"),
+            # This last configuration will try to use DDP with checkpointing, which is not supported by torch.
+            # TODO: remove DDP and recommend just using FSDP for everything
+            # pytest.param(False, True, id="fsdp=False-checkpointing=True"),
+        ],
     )
     @pytest.mark.parametrize(
         "amp",
@@ -68,8 +76,8 @@ def test_train_tiny_gpt2(self, fsdp: bool, activation_checkpoint: bool, amp: boo
             },
         }
         if fsdp:
-            training_engine["type"] = "fairscale"
-            fsdp_config = {"reshard_after_forward": True, "mixed_precision": amp}
+            training_engine["type"] = "torch::fsdp"
+            fsdp_config = {"mixed_precision": amp}
             training_engine["fsdp_config"] = fsdp_config
             overrides["steps.trained_model.model.fsdp_config"] = fsdp_config
         else:
@@ -77,8 +85,8 @@ def test_train_tiny_gpt2(self, fsdp: bool, activation_checkpoint: bool, amp: boo
             overrides["steps.trained_model.model.fsdp_config"] = None
         overrides["steps.trained_model.training_engine"] = training_engine
         run_dir = self.run(
-            self.FIXTURES_ROOT / "integrations" / "fairscale" / "config.jsonnet",
-            include_package=["test_fixtures.integrations.fairscale.components"],
+            self.FIXTURES_ROOT / "integrations" / "torch" / "fsdp_config.jsonnet",
+            include_package=["test_fixtures.integrations.torch.components"],
             overrides=overrides,
         )
         assert (run_dir / "trained_model").is_dir()
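The trickiest part of the new engine is the checkpoint round-trip in `fsdp_training_engine.py`. Reduced to its core, the save side looks like the sketch below (not the engine code; the function name and file names are illustrative). The key point is that the `FSDP.state_dict_type` context makes `state_dict()` and `FSDP.optim_state_dict()` gather the full, unsharded state, materialized on rank 0's CPU only, which is why only worker 0 writes files in `save_checkpoint`:

```python
import torch
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import FullOptimStateDictConfig


def save_full_checkpoint(model: FSDP, optimizer: torch.optim.Optimizer, rank: int) -> None:
    """Gather a full (unsharded) checkpoint onto rank 0 and write it to disk."""
    # Every rank must enter this context and make the (collective) state-dict
    # calls, even though only rank 0 receives the gathered tensors.
    with FSDP.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
    ):
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)

    # Only rank 0 holds the full state, so only it writes the files.
    if rank == 0:
        torch.save(model_state, "worker0_model.pt")
        torch.save(optim_state, "worker0_optimizer.pt")
```

Loading reverses the pattern: rank 0 reads the files, and `FSDP.optim_state_dict_to_load` (inside the same kind of context) redistributes the optimizer state to all ranks, as `load_checkpoint` does above.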