diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml
deleted file mode 100644
index 27385e7c5..000000000
--- a/.github/workflows/integration_tests.yml
+++ /dev/null
@@ -1,154 +0,0 @@
-name: Integration tests
-
-on:
- workflow_dispatch:
- inputs:
- test:
- description: the integration test to run
- default: fairscale_benchmarks
- required: true
- type: choice
- options:
- - fairscale_benchmarks
- cluster:
- description: the beaker cluster to run the test on
- default: ai2/tango-integration-tests
- required: true
- type: choice
- options:
- - ai2/tango-integration-tests
- - ai2/allennlp-cirrascale
- # Uncomment this trigger to test changes on a pull request.
- # You also have to uncomment the lines below that mention 'for pull request checks'
- # pull_request:
- # branches:
- # - '*'
-
-jobs:
- run_test:
- name: ${{ github.event.inputs.test }}
- # name: fairscale_benchmarks # for pull request checks
- runs-on: [ubuntu-latest]
- timeout-minutes: 60
- env:
- TEST_NAME: ${{ github.event.inputs.test }}
- # TEST_NAME: fairscale_benchmarks # for pull request checks
- BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
- BEAKER_WORKSPACE: ai2/tango-integration-tests
- BEAKER_CLUSTER: ${{ github.event.inputs.cluster }}
- # BEAKER_CLUSTER: ai2/allennlp-cirrascale # for pull request checks
- IMAGE_NAME: petew/tango-testing
- steps:
- - uses: actions/checkout@v3
-
- - name: Validate inputs
- run: |
- # The 'test' input should be a directory in `integration_tests/`
- test -d "integration_tests/${TEST_NAME}"
-
- - name: Determine current commit SHA (pull request)
- if: github.event_name == 'pull_request'
- run: |
- echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV
-
- - name: Determine current commit SHA (push)
- if: github.event_name != 'pull_request'
- run: |
- echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV
-
- - name: Install beaker client
- shell: bash
- run: |
- mkdir -p "$HOME/bin"
-
- # Download and install from latest GitHub release.
- curl -s https://api.github.com/repos/allenai/beaker/releases/latest \
- | grep 'browser_download_url.*linux' \
- | cut -d '"' -f 4 \
- | wget -qi - \
- && tar -xvzf beaker_linux.tar.gz -C "$HOME/bin"
-
- # Add to path.
- echo "$HOME/bin" >> "$GITHUB_PATH"
-
- - name: Verify beaker install
- run: |
- beaker account whoami
-
- - name: Create beaker experiment config
- run: |
- cat >beaker_config.yml << EOL
- version: v2-alpha
- description: ${{ env.TEST_NAME }}
- tasks:
- - name: test
- image:
- beaker: ${{ env.IMAGE_NAME }}
- command: ["/entrypoint.sh", "integration_tests/${{ env.TEST_NAME }}/run.sh"]
- envVars:
- - name: COMMIT_SHA
- value: $COMMIT_SHA
- - name: WANDB_API_KEY
- secret: WANDB_API_KEY
- - name: FILE_FRIENDLY_LOGGING
- value: "true"
- - name: TOKENIZERS_PARALLELISM # set this to avoid warnings
- value: "true"
- - name: PYTHONUNBUFFERED
- value: "true"
- result:
- path: '/results'
- resources:
- gpuCount: 4
- context:
- cluster: ${{ env.BEAKER_CLUSTER }}
- priority: normal
- EOL
- cat beaker_config.yml
-
- - name: Submit beaker job
- run: |
- TIMESTAMP=$(date +%H%M%S)
- EXPERIMENT=$(beaker experiment create beaker_config.yml --workspace $BEAKER_WORKSPACE --name "${TEST_NAME}-${{ github.run_number }}-${TIMESTAMP}" | awk '{print $2}')
- if [ -z "$EXPERIMENT" ]; then
- exit 1
- else
- echo "EXPERIMENT=$EXPERIMENT" >> $GITHUB_ENV
- echo "Experiment $EXPERIMENT submitted. See progress at https://beaker.org/ex/$EXPERIMENT"
- fi
-
- - name: Wait for job to finish
- run: |
- beaker experiment await $EXPERIMENT test finalized --timeout 60m
- # Check the job's exit code.
- test $(beaker experiment get $EXPERIMENT --format=json | jq '.[0].jobs[0].status.exitCode') -eq 0
-
- - name: Get logs
- if: always()
- run: |
- # EXPERIMENT could be empty if the submission step failed.
- # We'll exit right away if that's the case.
- if [ -z "$EXPERIMENT" ]; then
- echo "No logs to show"
- exit 0
- fi
-
- # Download logs from beaker.
- beaker experiment results $EXPERIMENT --prefix out.log --output results
-
- # If the experiment failed during startup, there might not be any logs.
- if [ -f results/test/out.log ]; then
- echo ""
- echo ">>> Logs:"
- echo ""
- cat results/test/out.log
- else
- echo "No logs to show"
- fi
-
- - name: Stop job
- if: cancelled()
- run: |
- if [ ! -z "$EXPERIMENT" ]; then
- beaker experiment stop $EXPERIMENT
- fi
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 9174a0021..fe75389c9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -20,7 +20,7 @@ env:
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
BEAKER_WORKSPACE: ai2/tango-testing
- BEAKER_DEFAULT_CLUSTER: ai2/tango-gpu-tests
+ BEAKER_DEFAULT_CLUSTER: ai2/canary
BEAKER_IMAGE: petew/tango-testing
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -93,12 +93,6 @@ jobs:
run: |
pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers
- - name: FairScale integration
- extras: dev,fairscale
- requires_torch: true
- run: |
- pytest -v --color=yes --doctest-modules tango/integrations/fairscale tests/integrations/fairscale
-
- name: W&B integration
extras: dev,torch,flax,wandb
requires_torch: true
@@ -298,7 +292,7 @@ jobs:
path: /unused
token: ${{ secrets.BEAKER_TOKEN }}
workspace: ${{ env.BEAKER_WORKSPACE }}
- clusters: ai2/general-cirrascale,ai2/allennlp-cirrascale,ai2/aristo-cirrascale,ai2/mosaic-cirrascale,ai2/s2-cirrascale
+ clusters: ai2/general-cirrascale,ai2/allennlp-cirrascale,ai2/aristo-cirrascale,ai2/mosaic-cirrascale,ai2/s2-cirrascale,ai2/mosaic-cirrascale-a100,ai2/prior-cirrascale,ai2/general-cirrascale-a100-80g-ib
release:
name: Release
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 917f69fb9..58e6f0ff8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixes a bug where `FromParams` would fail to parse when an object takes a `Step` argument directly.
- Changed a name so we don't override the built-in name `set`.
- Fixed a bug that would cause O(n^2) memory consumption in dense step graphs.
+- Fixed how we find learning rate schedulers in Torch 2.
## [v1.2.0](https://github.com/allenai/tango/releases/tag/v1.2.0) - 2023-02-10
diff --git a/docs/source/api/integrations/fairscale.rst b/docs/source/api/integrations/fairscale.rst
deleted file mode 100644
index d7890f909..000000000
--- a/docs/source/api/integrations/fairscale.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-🔥 FairScale
-============
-
-.. automodule:: tango.integrations.fairscale
-
-Reference
----------
-
-.. autoclass:: tango.integrations.fairscale.FairScaleTrainingEngine
-
-.. autoclass:: tango.integrations.fairscale.FSDPConfig
- :members:
-
-.. autofunction:: tango.integrations.fairscale.with_wrapped_modules
diff --git a/docs/source/api/integrations/index.rst b/docs/source/api/integrations/index.rst
index 91ddf4fd9..ab5f40be3 100644
--- a/docs/source/api/integrations/index.rst
+++ b/docs/source/api/integrations/index.rst
@@ -8,7 +8,6 @@ Integrations
:caption: Integrations
torch
- fairscale
datasets
transformers
wandb
diff --git a/docs/source/api/integrations/torch.rst b/docs/source/api/integrations/torch.rst
index 5996bc793..2c64f5d86 100644
--- a/docs/source/api/integrations/torch.rst
+++ b/docs/source/api/integrations/torch.rst
@@ -32,6 +32,8 @@ Model
.. autoclass:: tango.integrations.torch.Model
:members:
+.. autofunction:: tango.integrations.torch.with_wrapped_modules
+
TrainingEngine
~~~~~~~~~~~~~~
@@ -40,6 +42,11 @@ TrainingEngine
.. autoclass:: tango.integrations.torch.TorchTrainingEngine
+.. autoclass:: tango.integrations.torch.FSDPTrainingEngine
+
+.. autoclass:: tango.integrations.torch.FSDPConfig
+ :members:
+
Optim
~~~~~
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7002ebf81..a5c814b86 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -62,7 +62,6 @@
"rich": ("https://rich.readthedocs.io/en/latest", None),
"torch": ("https://pytorch.org/docs/stable", None),
"flax": ("https://flax.readthedocs.io/en/latest", None),
- "fairscale": ("https://fairscale.readthedocs.io/en/latest/", None),
"datasets": ("https://huggingface.co/docs/datasets/master/en", None),
"transformers": ("https://huggingface.co/docs/transformers/master/en", None),
"beaker": ("https://beaker-py.readthedocs.io/en/latest/", None),
diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet
index 485739742..71c385d9b 100644
--- a/examples/finetune/config.jsonnet
+++ b/examples/finetune/config.jsonnet
@@ -23,7 +23,7 @@ local batch_size = 2;
local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2)
local amp = false; # use PyTorch's native automatic mixed precision
-local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2)
+local fsdp = false; # Use Torch's FullyShardedDataParallel (probably needed for GPT-J 6B, but not gpt2)
local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.
######################
@@ -38,14 +38,13 @@ assert fsdp == true || cpu_offloading == false : "cpu_offloading only available
# FullyShardedDataParallel config:
local fsdp_config = if fsdp then {
- reshard_after_forward: true,
move_params_to_cpu: cpu_offloading,
move_grads_to_cpu: cpu_offloading,
mixed_precision: amp,
} else null;
local training_engine = {
- type: if fsdp then "fairscale" else "torch",
+ type: if fsdp then "torch::fsdp" else "torch",
optimizer: {
type: "torch::AdamW",
lr: learning_rate,
@@ -95,13 +94,13 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device
trained_model: {
type: "transformers::finetune",
model: {
- type: "fairscale::with_wrapped_modules",
+ type: "torch::with_wrapped_modules",
model: {
type: "transformers::finetune::from_pretrained",
pretrained_model_name_or_path: pretrained_model,
low_cpu_mem_usage: load_with_low_cpu_mem_usage,
},
- modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually
+ modules_to_wrap: modules_to_wrap, # tell torch to wrap the transformer's blocks individually
fsdp_config: fsdp_config,
activation_checkpointing: activation_checkpointing,
},
diff --git a/examples/train_lm/README.md b/examples/train_lm/README.md
index 59c347794..e35f4d002 100644
--- a/examples/train_lm/README.md
+++ b/examples/train_lm/README.md
@@ -6,7 +6,7 @@ This Tango example showcases how you could train or fine-tune a causal language
or [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) from [transformers](https://github.com/huggingface/transformers) on WikiText2 or a similar dataset.
It's best that you run this experiment on a machine with a GPU and PyTorch [properly installed](https://pytorch.org/get-started/locally/#start-locally), otherwise Tango will fall back to CPU-only and it will be extremely slow.
-This example also depends on [FairScale](https://fairscale.readthedocs.io/en/latest/), which allows you to leverage [`FullyShardedDataParallel`](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html) (FSDP) and [activation checkpointing](https://fairscale.readthedocs.io/en/latest/api/nn/checkpoint/checkpoint_activations.html) to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B) or a similar-sized model. Just set the constants `fsdp` and `activation_checkpointing` in the config to `true`.
+This example also supports PyTorch's [`FullyShardedDataParallel`](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) (FSDP) and [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html), which make it possible to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B) or a similar-sized model. Just set the constants `fsdp` and `activation_checkpointing` in the config to `true`.
Without using CPU offloading you'll need at least 4 x 40GiB A100 GPUs, or a different configuration with a comparable amount of total GPU memory.
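As a rough sketch of what the `fsdp` and `activation_checkpointing` toggles correspond to in raw PyTorch (this assumes a GPT-2-style `transformer.h` block list, which is what the `modules_to_wrap` regex in this example's config matches, plus an already-initialized distributed process group):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def shard_transformer_blocks(model: nn.Module) -> nn.Module:
    # Wrap each transformer block for activation checkpointing, then
    # shard it individually with FSDP.
    for i, block in enumerate(model.transformer.h):
        model.transformer.h[i] = FSDP(checkpoint_wrapper(block))
    # Shard whatever parameters remain at the top level.
    return FSDP(model)
```

In the actual experiment you don't do any of this by hand: the `torch::with_wrapped_modules` step wraps the modules matched by `modules_to_wrap` for you.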
diff --git a/examples/train_lm/config.jsonnet b/examples/train_lm/config.jsonnet
index b99c8e098..e1e35bfeb 100644
--- a/examples/train_lm/config.jsonnet
+++ b/examples/train_lm/config.jsonnet
@@ -44,14 +44,13 @@ assert fsdp == true || cpu_offloading == false : "cpu_offloading only available
# FullyShardedDataParallel config:
local fsdp_config = if fsdp then {
- reshard_after_forward: true,
move_params_to_cpu: cpu_offloading,
move_grads_to_cpu: cpu_offloading,
mixed_precision: amp,
} else null;
local training_engine = {
- type: if fsdp then "fairscale" else "torch",
+ type: if fsdp then "torch::fsdp" else "torch",
optimizer: {
type: "torch::AdamW",
lr: learning_rate,
@@ -100,13 +99,13 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device
trained_model: {
type: "torch::train",
model: {
- type: "fairscale::with_wrapped_modules",
+ type: "torch::with_wrapped_modules",
model: {
type: "transformers::AutoModelForCausalLM::from_pretrained",
pretrained_model_name_or_path: pretrained_model,
low_cpu_mem_usage: load_with_low_cpu_mem_usage,
},
- modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually
+ modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell torch to wrap the transformer's blocks individually
fsdp_config: fsdp_config,
activation_checkpointing: activation_checkpointing,
},
diff --git a/integration_tests/README.md b/integration_tests/README.md
deleted file mode 100644
index 333170109..000000000
--- a/integration_tests/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-# Integration tests
-
-These are a collection of longer running end-to-end tests of various parts of the Tango library.
-
-The easiest way to run any of these integration tests is by triggering the [**Integration tests**](https://github.com/allenai/tango/actions/workflows/integration_tests.yml)
-workflow on GitHub Actions. Just select the "Run workflow" dropdown, then pick the test to run and the Beaker cluster to run it on,
-and finally hit the "Run workflow" button.
-
-Each test should have a `run.sh` file in its folder that will run the relevant tango command.
-This is what the **Integration tests** workflow will call, and you can also use it to run the test manually.
diff --git a/integration_tests/fairscale_benchmarks/README.md b/integration_tests/fairscale_benchmarks/README.md
deleted file mode 100644
index d59d7caa8..000000000
--- a/integration_tests/fairscale_benchmarks/README.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# FairScale Benchmarks
-
-This integration test is for checking the performance of the `FairScaleTrainingEngine` with various configurations.
-
-**When to run it:** It should be ran every time there is a major PyTorch or FairScale upgrade.
-
-**Where to run it:** A server with 4 A100 GPUs. Make sure you set your `WANDB_API_KEY` environment variable.
-
-**How to run it:** From the root directory of this repository, run:
-```
-integration_tests/fairscale_benchmarks/run.sh
-```
-
-By default, not all configurations are run. If you want to run change which configurations are run, open `config.jsonnet`
-are search for "enabled". Then toggle this `enabled` field to `true` or `false` for each configuration.
-
-**What to look for:** The training jobs shouldn't fail, for one. After `tango run` completes, check the corresponding Weights & Biases
-dashboard and inspect the results. Compare the various "fsdp" training runs with the baseline to ensure you see memory savings.
diff --git a/integration_tests/fairscale_benchmarks/config.jsonnet b/integration_tests/fairscale_benchmarks/config.jsonnet
deleted file mode 100644
index 109010fe8..000000000
--- a/integration_tests/fairscale_benchmarks/config.jsonnet
+++ /dev/null
@@ -1,296 +0,0 @@
-##################
-# Model settings #
-##################
-
-local pretrained_model = "gpt2";
-# local pretrained_model = "EleutherAI/gpt-j-6B";
-# This doesn't seem to work with gpt2, but works fine with gpt-j-6B.
-local load_with_low_cpu_mem_usage = pretrained_model == "EleutherAI/gpt-j-6B";
-
-####################
-# Trainer settings #
-####################
-
-# Trainer settings, adjust to your use-case.
-local training_steps = 100; # total number of optimization steps to train for
-local validate_every = 20; # how often to validate and save checkpoints
-
-local devices = 4;
-local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size)
-# This is the batch size per GPU, ignoring gradient accumulation:
-local batch_size = 8;
-# So the effective batch size is `batch_size * grad_accum * devices`
-
-######################
-# Optimizer settings #
-######################
-
-local warmup_steps = 20;
-local learning_rate = if pretrained_model == "EleutherAI/gpt-j-6B" then 0.00001 else 0.0001;
-
-
-# <----- you probably don't need to edit below this line ----> #
-
-
-local distributed_dataloader = {
- batch_size: batch_size,
- collate_fn: { type: "transformers::DefaultDataCollator" },
- sampler: {
- type: "torch::DistributedSampler",
- shuffle: true,
- drop_last: true,
- },
-};
-
-local single_device_dataloader = {
- shuffle: true,
- batch_size: batch_size,
- collate_fn: { type: "transformers::DefaultDataCollator" },
-};
-
-local TrainStep(options) =
- local training_engine = {
- type: if options.fsdp_config != null then "fairscale" else "torch",
- optimizer: {
- type: "torch::AdamW",
- lr: learning_rate,
- betas: [0.9, 0.95],
- eps: 1e-6,
- },
- lr_scheduler: {
- type: "transformers::linear",
- num_warmup_steps: warmup_steps,
- num_training_steps: training_steps,
- },
- amp: options.amp,
- [if options.fsdp_config != null then "fsdp_config" else null]: options.fsdp_config,
- };
-
- {
- type: "torch::train",
- model: {
- type: "fairscale::with_wrapped_modules",
- model: {
- type: "transformers::AutoModelForCausalLM::from_pretrained",
- pretrained_model_name_or_path: pretrained_model,
- low_cpu_mem_usage: load_with_low_cpu_mem_usage,
- },
- modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually
- fsdp_config: options.fsdp_config,
- activation_checkpointing: options.activation_checkpointing,
- },
- dataset_dict: { type: "ref", ref: "tokenized_data" },
- train_dataloader: distributed_dataloader,
- validation_split: "validation",
- grad_accum: grad_accum,
- train_steps: training_steps,
- validate_every: validate_every,
- checkpoint_every: validate_every,
- log_every: 1,
- device_count: devices,
- training_engine: training_engine,
- callbacks: [
- {
- type: "wandb::log",
- entity: "allennlp",
- project: "tango-fairscale-benchmarks",
- wandb_config: options + {
- effective_batch_size: batch_size * devices * grad_accum,
- model: pretrained_model,
- },
- },
- ],
- };
-
-{
- steps: {
- raw_data: {
- type: "datasets::load",
- path: "wikitext",
- name: "wikitext-2-raw-v1",
- },
- tokenized_data: {
- type: "tokenize_data",
- dataset: { type: "ref", ref: "raw_data" },
- tokenizer: { pretrained_model_name_or_path: pretrained_model }
- },
- } + {
- ["trained_model_" + options.name]: TrainStep(options)
- for options in [
- # NOTE: With 6B model, baseline and many others will fail with CUDA OOM.
- # FSDP and activation checkpointing will be required for a 6B model.
- {
- name: "baseline",
- enabled: false,
- amp: false,
- fsdp_config: null,
- activation_checkpointing: false,
- },
- {
- name: "amp",
- enabled: false,
- amp: true,
- fsdp_config: null,
- activation_checkpointing: false,
- },
- {
- name: "checkpointing",
- enabled: false,
- amp: false,
- fsdp_config: null,
- activation_checkpointing: true,
- },
- {
- name: "amp_and_checkpointing",
- enabled: false,
- amp: true,
- fsdp_config: null,
- activation_checkpointing: true,
- },
- {
- name: "fsdp",
- enabled: false,
- amp: false,
- activation_checkpointing: false,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: false,
- },
- },
- {
- name: "fsdp_no_reshard",
- enabled: false,
- amp: false,
- activation_checkpointing: false,
- fsdp_config: {
- reshard_after_forward: false,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: false,
- },
- },
- {
- name: "amp_and_fsdp",
- enabled: false,
- amp: true,
- activation_checkpointing: false,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: false,
- },
- },
- {
- name: "amp_and_fsdp_no_reshard",
- enabled: false,
- amp: true,
- activation_checkpointing: false,
- fsdp_config: {
- reshard_after_forward: false,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: false,
- },
- },
- {
- name: "amp_and_fsdp_mp",
- enabled: false,
- amp: true,
- activation_checkpointing: false,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: true,
- },
- },
- {
- name: "amp_and_fsdp_mp_no_reshard",
- enabled: false,
- amp: true,
- activation_checkpointing: false,
- fsdp_config: {
- reshard_after_forward: false,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: true,
- },
- },
- {
- name: "checkpointing_and_fsdp",
- enabled: false,
- amp: false,
- activation_checkpointing: true,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: false,
- },
- },
- {
- name: "amp_and_checkpointing_and_fsdp",
- enabled: false,
- amp: true,
- activation_checkpointing: true,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: false,
- },
- },
- {
- name: "amp_and_checkpointing_and_fsdp_mp",
- enabled: true,
- amp: true,
- activation_checkpointing: true,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: true,
- },
- },
- {
- name: "checkpointing_and_fsdp_mp",
- enabled: false,
- amp: false,
- activation_checkpointing: true,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: false,
- move_grads_to_cpu: false,
- mixed_precision: true,
- },
- },
- { # This configuration currently does not work. Tracking https://github.com/facebookresearch/fairscale/issues/918
- name: "amp_and_checkpointing_and_fsdp_mp_with_partial_offloading",
- enabled: false,
- amp: true,
- activation_checkpointing: true,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: true,
- move_grads_to_cpu: false,
- mixed_precision: true,
- },
- },
- {
- name: "amp_and_checkpointing_and_fsdp_mp_with_full_offloading",
- enabled: false,
- amp: true,
- activation_checkpointing: true,
- fsdp_config: {
- reshard_after_forward: true,
- move_params_to_cpu: true,
- move_grads_to_cpu: true,
- mixed_precision: true,
- },
- },
- ] if options.enabled
- }
-}
diff --git a/integration_tests/fairscale_benchmarks/run.sh b/integration_tests/fairscale_benchmarks/run.sh
deleted file mode 100755
index 0b9ae1906..000000000
--- a/integration_tests/fairscale_benchmarks/run.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/bin/sh
-
-tango run integration_tests/fairscale_benchmarks/config.jsonnet -i examples/train_lm/tokenize_step.py
diff --git a/pyproject.toml b/pyproject.toml
index 748795065..e545e6089 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,17 +76,12 @@ transformers = [
"numpy",
"datasets>=1.12,<3.0",
"transformers>=4.12.3",
- "sentencepiece==0.1.98",
+ "sentencepiece==0.1.99",
"sacremoses"
]
datasets = [
"datasets>=1.12,<3.0"
]
-fairscale = [
- "torch>=1.9,<2.1",
- "numpy",
- "fairscale>=0.4.6,<0.5"
-]
flax = [
"datasets>=1.12,<3.0",
"jax>=0.3.13",
@@ -106,7 +101,7 @@ gs = [
"google-cloud-datastore>=2.12.0"
]
all = [
- "ai2-tango[examples,torch,transformers,datasets,fairscale,flax,wandb,beaker,gs]"
+ "ai2-tango[examples,torch,transformers,datasets,flax,wandb,beaker,gs]"
]
[project.scripts]
diff --git a/tango/integrations/fairscale/__init__.py b/tango/integrations/fairscale/__init__.py
deleted file mode 100644
index dc14b2d40..000000000
--- a/tango/integrations/fairscale/__init__.py
+++ /dev/null
@@ -1,47 +0,0 @@
-"""
-.. important::
- To use this integration you should install ``tango`` with the "fairscale" extra
- (e.g. ``pip install tango[fairscale]``) or just install FairScale after the fact.
-
- This integration also depends on `PyTorch `_, so make sure you
- install the correct version of torch *first* given your operating system and supported
- CUDA version. Check `pytorch.org/get-started/locally/ `_
- for more details.
-
-Components for Tango integration with `FairScale `_.
-
-Overview
---------
-
-FairScale is a PyTorch library for large scale training. Among other things, it implements
-the main memory-savings techniques for distributed data-parallel training (DDP) that came from the paper
-`ZeRO: Memory Optimization Towards Training A Trillion Parameter Models
-`_.
-
-The main part of this Tango integration is the :class:`FairScaleTrainingEngine`.
-This is a :class:`~tango.integrations.torch.TrainingEngine` implementation that utilizes
-FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` (FSDP) for substantial memory savings
-during distributed training.
-
-For the best performance you should also use :func:`with_wrapped_modules()` to wrap the inner modules
-of your :class:`~tango.integrations.torch.Model`. When used with FSDP this will dramatically reduce
-the memory required to load your model.
-
-"""
-
-from tango.common.exceptions import IntegrationMissingError
-
-try:
- import fairscale
-except ModuleNotFoundError:
- raise IntegrationMissingError("fairscale")
-
-__all__ = [
- "FairScaleTrainingEngine",
- "FSDPConfig",
- "with_wrapped_modules",
-]
-
-from .fsdp_config import FSDPConfig
-from .module_wrapper import with_wrapped_modules
-from .training_engine import FairScaleTrainingEngine
diff --git a/tango/integrations/fairscale/fsdp_config.py b/tango/integrations/fairscale/fsdp_config.py
deleted file mode 100644
index 50d47a3e4..000000000
--- a/tango/integrations/fairscale/fsdp_config.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from dataclasses import asdict, dataclass
-from typing import Any, Dict, Optional
-
-import torch
-from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-
-from tango.common import FromParams
-
-
-@dataclass
-class FSDPConfig(FromParams):
- """
- Defines all of the configurable options for FairScale's :class:`~fairscale.nn.FullyShardedDataParallel`.
-
- .. seealso::
- `Best practices for FullyShardedDataParallel `_
- from the FairScale docs.
-
- """ # noqa: E501
-
- reshard_after_forward: bool = True
- """
- See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
- """
-
- move_params_to_cpu: bool = False
- """
- See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
- """
-
- move_grads_to_cpu: Optional[bool] = None
- """
- See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
-
- .. seealso::
- :data:`move_params_to_cpu`
-
- .. warning::
- At the moment we recommend that you don't mess with this parameter, or only explicitly
- set it to the same value as :data:`move_params_to_cpu`. If you leave it as ``None``
- (the default), it will automatically be set to match :data:`move_params_to_cpu` by FairScale.
-
- Currently training seems to crash if you set this ``False`` while :data:`move_params_to_cpu` is ``True``.
- We're tracking `fairscale#918 `_,
- which may be related.
- """
-
- mixed_precision: bool = False
- """
- See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.
-
- .. important::
- We recommend setting this to the same value as the ``amp`` parameter in
- :class:`FairScaleTrainingEngine`.
-
- Based on our experiments, if you're training with AMP enabled (``amp=True``)
- you might see a small additional speedup in training time along with a small
- additional decrease in GPU memory utilization without any performance penalty
- (with respect to convergence) by setting this to ``True``.
- But if you're *not* training with AMP, setting this ``True`` could impact the
- model's ability to converge.
-
- """
-
- def as_kwargs(self) -> Dict[str, Any]:
- """
- Convert to the appropriate ``kwargs`` for :class:`~fairscale.nn.FullyShardedDataParallel`.
- """
- return asdict(self)
-
- def wrap(self, module: torch.nn.Module):
- """
- A convenience method for wrapping a module in :class:`~fairscale.nn.FullyShardedDataParallel`
- with all of the options defined in this class.
-
- .. seealso::
- Internally this is what :func:`with_wrapped_modules()` calls.
-
- """
- return FSDP(module, **self.as_kwargs())
diff --git a/tango/integrations/fairscale/training_engine.py b/tango/integrations/fairscale/training_engine.py
deleted file mode 100644
index f88d16ac8..000000000
--- a/tango/integrations/fairscale/training_engine.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
-
-import torch
-from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.optim.grad_scaler import ShardedGradScaler
-
-from tango.common import Lazy
-from tango.common.exceptions import ConfigurationError
-from tango.integrations.torch import (
- LRScheduler,
- Model,
- Optimizer,
- TorchTrainingEngine,
- TrainConfig,
- TrainingEngine,
-)
-
-from .fsdp_config import FSDPConfig
-
-
-@TrainingEngine.register("fairscale")
-class FairScaleTrainingEngine(TorchTrainingEngine):
- """
- A :class:`~tango.integrations.torch.TrainingEngine` that leverages FairScale's
- :class:`~fairscale.nn.FullyShardedDataParallel` for use within
- :class:`~tango.integrations.torch.TorchTrainStep`.
-
- .. tip::
- Registered as an :class:`~tango.integrations.torch.TrainingEngine` under the name
- "fairscale".
-
- .. tip::
- To get the best performance out of :class:`FairScaleTrainingEngine` you should
- wrap individual layers of your model with :class:`~fairscale.nn.FullyShardedDataParallel`
- and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`
- while instantiating them. You can use :class:`with_wrapped_modules()` to accomplish this.
-
- .. important::
- Only the parameters listed below should be defined in a configuration
- file. The other parameters will be automatically passed to the constructor
- within :class:`~tango.integrations.torch.TorchTrainStep`.
-
- .. warning::
- :class:`~FairScaleTrainingEngine` can only be used in distributed training, i.e.
- when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`.
-
- For maximum memory savings, we recommend training with AMP enabled and the following
- :class:`FSDPConfig`:
-
- .. testcode::
-
- from tango.integrations.fairscale import FSDPConfig
-
- fsdp_config = FSDPConfig(
- reshard_after_forward=True,
- move_params_to_cpu=True,
- move_grads_to_cpu=True,
- mixed_precision=True,
- )
-
- For maximum training *speed*, we recommend training with AMP enabled and the following
- :class:`FSDPConfig`:
-
- .. testcode::
-
- from tango.integrations.fairscale import FSDPConfig
-
- fsdp_config = FSDPConfig(
- reshard_after_forward=False,
- move_params_to_cpu=False,
- move_grads_to_cpu=False,
- mixed_precision=True,
- )
-
- :param amp:
- Use automatic mixed precision (AMP). Default is ``False``.
- :param max_grad_norm:
- If set, gradients will be clipped to have this max norm. Default is ``None``.
- :param amp_use_bfloat16:
- Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training.
- Only applicable when ``amp=True``. If not specified, the default behavior will be
- to use ``bfloat16`` when training with AMP on CPU, otherwise not.
- :param fsdp_config:
- The options for :class:`~fairscale.nn.FullyShardedDataParallel`.
- If not specified, the default options will be used.
-
- """
-
- def __init__(
- self,
- train_config: TrainConfig,
- model: Lazy[Model],
- optimizer: Lazy[Optimizer],
- *,
- lr_scheduler: Optional[Lazy[LRScheduler]] = None,
- amp: bool = False,
- max_grad_norm: Optional[float] = None,
- amp_use_bfloat16: Optional[bool] = None,
- fsdp_config: Optional[FSDPConfig] = None,
- ) -> None:
- if not train_config.is_distributed:
- raise ConfigurationError(
- f"{self.__class__.__name__} can only be used with distributed training"
- )
-
- self.fsdp_config = fsdp_config or FSDPConfig()
- self.logger = logging.getLogger(self.__class__.__name__)
-
- super().__init__(
- train_config,
- model,
- optimizer,
- lr_scheduler=lr_scheduler,
- amp=amp,
- max_grad_norm=max_grad_norm,
- amp_use_bfloat16=amp_use_bfloat16,
- )
- if amp:
- self.grad_scaler = ShardedGradScaler()
-
- def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
- if isinstance(model, Lazy):
- model = model.construct()
- if not self.fsdp_config.move_params_to_cpu:
- model.to(self.train_config.worker_local_default_device)
- return FSDP(model, **self.fsdp_config.as_kwargs())
-
- def clip_grad_norm(self) -> None:
- if self.max_grad_norm is not None:
- self.model.clip_grad_norm_(self.max_grad_norm) # type: ignore
-
- def get_model_state(self) -> Dict[str, torch.Tensor]:
- return {
- "weights": self.model.local_state_dict(), # type: ignore
- "metadata": self.model.local_metadata_dict(), # type: ignore
- }
-
- def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None:
- self.model.load_local_state_dict(state_dict["weights"]) # type: ignore
-
- def save_complete_weights_from_checkpoint(
- self, checkpoint_dir: Path, weights_path: Path
- ) -> None:
- self.logger.info("Consolidating sharded checkpoint weights...")
- sharded_weights: List[Dict[str, torch.Tensor]] = []
- sharded_metadata: List[Dict[str, Any]] = []
- for path in checkpoint_dir.resolve().glob("worker*_model.pt"):
- sharded_state = torch.load(path, map_location="cpu")
- sharded_weights.append(sharded_state["weights"])
- sharded_metadata.append(sharded_state["metadata"])
- full_state = FSDP.consolidate_shard_weights(sharded_weights, sharded_metadata)
- del sharded_weights
- del sharded_metadata
- torch.save(full_state, weights_path)
diff --git a/tango/integrations/torch/__init__.py b/tango/integrations/torch/__init__.py
index ae5ecd21c..65044e68d 100644
--- a/tango/integrations/torch/__init__.py
+++ b/tango/integrations/torch/__init__.py
@@ -166,6 +166,9 @@ def _return_zero(self):
from .eval_callback import EvalCallback
from .exceptions import StopEarly
from .format import TorchFormat
+from .fsdp_config import FSDPConfig
+from .fsdp_module_wrapper import with_wrapped_modules
+from .fsdp_training_engine import FSDPTrainingEngine
from .model import Model
from .optim import LRScheduler, Optimizer
from .train import TorchTrainStep
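With these re-exports in place, the FSDP components that previously lived in `tango.integrations.fairscale` come straight from the torch integration:

```python
from tango.integrations.torch import (
    FSDPConfig,
    FSDPTrainingEngine,
    with_wrapped_modules,
)
```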
diff --git a/tango/integrations/torch/fsdp_config.py b/tango/integrations/torch/fsdp_config.py
new file mode 100644
index 000000000..405295c91
--- /dev/null
+++ b/tango/integrations/torch/fsdp_config.py
@@ -0,0 +1,75 @@
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+from torch.distributed.fsdp import CPUOffload
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision
+
+from tango.common import FromParams
+
+
+@dataclass
+class FSDPConfig(FromParams):
+    """
+    Defines all of the configurable options for Torch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`.
+    """  # noqa: E501
+
+    move_params_to_cpu: bool = False
+    """
+    Offload parameters (and their gradients) to CPU via :class:`~torch.distributed.fsdp.CPUOffload`.
+    """
+
+    move_grads_to_cpu: Optional[bool] = None
+    """
+    Kept for backwards compatibility with the old FairScale config. Torch's FSDP offloads
+    gradients together with parameters, so when set, this must match :data:`move_params_to_cpu`.
+    """
+
+    sync_module_states: bool = False
+
+    forward_prefetch: bool = False
+
+    limit_all_gathers: bool = False
+
+    use_orig_params: bool = False
+
+    mixed_precision: bool = False
+    """
+    Train with mixed precision via :class:`~torch.distributed.fsdp.MixedPrecision`.
+
+    .. important::
+        We recommend setting this to the same value as the ``amp`` parameter in
+        :class:`FSDPTrainingEngine`.
+
+    """
+
+    def as_kwargs(self) -> Dict[str, Any]:
+        """
+        Convert to the appropriate ``kwargs`` for :class:`~torch.distributed.fsdp.FullyShardedDataParallel`.
+        """
+        if self.move_grads_to_cpu is not None and self.move_grads_to_cpu != self.move_params_to_cpu:
+            raise ValueError("'move_grads_to_cpu' must match 'move_params_to_cpu' with Torch's FSDP")
+        kwargs: Dict[str, Any] = {
+            "cpu_offload": CPUOffload(offload_params=self.move_params_to_cpu),
+            "sync_module_states": self.sync_module_states,
+            "forward_prefetch": self.forward_prefetch,
+            "limit_all_gathers": self.limit_all_gathers,
+            "use_orig_params": self.use_orig_params,
+        }
+        if self.mixed_precision:
+            kwargs["mixed_precision"] = MixedPrecision(
+                param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+            )
+        return kwargs
+
+    def wrap(self, module: torch.nn.Module):
+        """
+        A convenience method for wrapping a module in :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
+        with all the options defined in this class.
+
+        .. seealso::
+            Internally this is what :func:`with_wrapped_modules()` calls.
+
+        """
+        return FSDP(module, **self.as_kwargs())
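A minimal usage sketch for the config class above (field values are illustrative; in an experiment they would normally come from the `fsdp_config` block of a jsonnet file):

```python
from tango.integrations.torch import FSDPConfig

# Offload params/grads to CPU and train in mixed precision.
fsdp_config = FSDPConfig(move_params_to_cpu=True, mixed_precision=True)

# Inspect the kwargs that will be forwarded to torch's FSDP, e.g.
# cpu_offload=CPUOffload(offload_params=True).
print(fsdp_config.as_kwargs())
```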
diff --git a/tango/integrations/fairscale/module_wrapper.py b/tango/integrations/torch/fsdp_module_wrapper.py
similarity index 83%
rename from tango/integrations/fairscale/module_wrapper.py
rename to tango/integrations/torch/fsdp_module_wrapper.py
index 6c580cbc5..164bcdcf3 100644
--- a/tango/integrations/fairscale/module_wrapper.py
+++ b/tango/integrations/torch/fsdp_module_wrapper.py
@@ -3,14 +3,15 @@
import torch
import torch.nn as nn
-from fairscale.nn.checkpoint import checkpoint_wrapper
-
-from tango.integrations.torch import Model
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+ checkpoint_wrapper,
+)
from .fsdp_config import FSDPConfig
+from .model import Model
-@Model.register("fairscale::with_wrapped_modules") # type: ignore[arg-type]
+@Model.register("torch::with_wrapped_modules") # type: ignore[arg-type]
def with_wrapped_modules(
model: Model,
modules_to_wrap: Set[str],
@@ -19,15 +20,15 @@ def with_wrapped_modules(
) -> Model:
"""
A :class:`~tango.integrations.torch.Model` wrapper that can be used to easily wrap
- inner modules of a model with FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` wrapper
- and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.
+ inner modules of a model with Torch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper
+ and/or :class:`~torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper`.
.. tip::
Registered as a :class:`~tango.integrations.torch.Model` constructor under the name
- "fairscale::with_wrapped_modules".
+ "torch::with_wrapped_modules".
.. important::
- This is meant to be used with the :class:`FairScaleTrainingEngine`.
+ This is meant to be used with the :class:`FSDPTrainingEngine`.
:param model:
The model to wrap.
@@ -37,8 +38,8 @@ def with_wrapped_modules(
The ``FullyShardedDataParallel`` configuration to use when wrapping the modules.
If not specified, the modules will NOT be wrapped with FSDP.
:param activation_checkpointing:
- Whether to wrap the modules with FairScale's
- :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.
+ Whether to wrap the modules with Torch's
+ :class:`~torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper`.
Examples
--------
@@ -77,7 +78,7 @@ def forward(self, x, y):
model = Model.from_params({
- "type": "fairscale::with_wrapped_modules",
+ "type": "torch::with_wrapped_modules",
"model": {
"type": "simple_regression_model",
},
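A hedged sketch of calling the renamed constructor directly (`TinyModel` is hypothetical; with a non-`None` `fsdp_config` this additionally requires an initialized distributed process group):

```python
import torch.nn as nn

from tango.integrations.torch import Model, with_wrapped_modules


class TinyModel(Model):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.decoder = nn.Linear(8, 1)

    def forward(self, x):
        return self.decoder(self.encoder(x))


# Wrap the two inner modules individually. With fsdp_config=None they are
# only wrapped for activation checkpointing; pass an FSDPConfig under
# distributed training to shard them as well.
model = with_wrapped_modules(
    TinyModel(),
    modules_to_wrap={"encoder", "decoder"},
    fsdp_config=None,
    activation_checkpointing=True,
)
```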
diff --git a/tango/integrations/torch/fsdp_training_engine.py b/tango/integrations/torch/fsdp_training_engine.py
new file mode 100644
index 000000000..9f8509acc
--- /dev/null
+++ b/tango/integrations/torch/fsdp_training_engine.py
@@ -0,0 +1,213 @@
+import logging
+import os
+import tempfile
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import torch
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType
+from torch.distributed.fsdp.api import FullOptimStateDictConfig
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+from tango.common import Lazy, Tqdm
+from tango.common.exceptions import ConfigurationError
+
+from .fsdp_config import FSDPConfig
+from .model import Model
+from .optim import LRScheduler, Optimizer
+from .train_config import TrainConfig
+from .training_engine import TorchTrainingEngine, TrainingEngine
+
+
+@TrainingEngine.register("torch::fsdp")
+class FSDPTrainingEngine(TorchTrainingEngine):
+ """
+ A :class:`~tango.integrations.torch.TrainingEngine` that leverages Torch's
+ :class:`~torch.distributed.fsdp.FullyShardedDataParallel` for use within
+ :class:`~tango.integrations.torch.TorchTrainStep`.
+
+ .. tip::
+ Registered as an :class:`~tango.integrations.torch.TrainingEngine` under the name
+ "torch::fsdp".
+
+ .. tip::
+ To get the best performance out of :class:`FSDPTrainingEngine` you should
+ wrap individual layers of your model with :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
+ and/or :class:`~torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper`
+ while instantiating them. You can use :class:`with_wrapped_modules()` to accomplish this.
+
+ .. important::
+ Only the parameters listed below should be defined in a configuration
+ file. The other parameters will be automatically passed to the constructor
+ within :class:`~tango.integrations.torch.TorchTrainStep`.
+
+ .. warning::
+ :class:`~FSDPTrainingEngine` can only be used in distributed training, i.e.
+ when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`.
+
+ For maximum memory savings, we recommend training with AMP enabled and the following
+ :class:`FSDPConfig`:
+
+ .. testcode::
+
+ from tango.integrations.torch import FSDPConfig
+
+ fsdp_config = FSDPConfig(
+ move_params_to_cpu=True,
+ move_grads_to_cpu=True,
+ mixed_precision=True,
+ )
+
+ For maximum training *speed*, we recommend training with AMP enabled and the following
+ :class:`FSDPConfig`:
+
+ .. testcode::
+
+ from tango.integrations.torch import FSDPConfig
+
+ fsdp_config = FSDPConfig(
+ move_params_to_cpu=False,
+ move_grads_to_cpu=False,
+ mixed_precision=True,
+ )
+
+ :param amp:
+ Use automatic mixed precision (AMP). Default is ``False``.
+ :param max_grad_norm:
+ If set, gradients will be clipped to have this max norm. Default is ``None``.
+ :param amp_use_bfloat16:
+ Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training.
+ Only applicable when ``amp=True``. If not specified, the default behavior will be
+ to use ``bfloat16`` when training with AMP on CPU, otherwise not.
+ :param fsdp_config:
+ The options for :class:`~torch.distributed.fsdp.FullyShardedDataParallel`.
+ If not specified, the default options will be used.
+
+ """
+
+ def __init__(
+ self,
+ train_config: TrainConfig,
+ model: Lazy[Model],
+ optimizer: Lazy[Optimizer],
+ *,
+ lr_scheduler: Optional[Lazy[LRScheduler]] = None,
+ amp: bool = False,
+ max_grad_norm: Optional[float] = None,
+ amp_use_bfloat16: Optional[bool] = None,
+ fsdp_config: Optional[FSDPConfig] = None,
+ ) -> None:
+ if not train_config.is_distributed:
+ raise ConfigurationError(
+ f"{self.__class__.__name__} can only be used with distributed training"
+ )
+
+ self.fsdp_config = fsdp_config or FSDPConfig()
+ self.logger = logging.getLogger(self.__class__.__name__)
+
+ super().__init__(
+ train_config,
+ model,
+ optimizer,
+ lr_scheduler=lr_scheduler,
+ amp=amp,
+ max_grad_norm=max_grad_norm,
+ amp_use_bfloat16=amp_use_bfloat16,
+ )
+ if amp:
+ self.grad_scaler = ShardedGradScaler()
+
+ def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
+ if isinstance(model, Lazy):
+ model = model.construct()
+ if not self.fsdp_config.move_params_to_cpu:
+ model.to(self.train_config.worker_local_default_device)
+ return FSDP(model, **self.fsdp_config.as_kwargs()) # type: ignore
+
+ def clip_grad_norm(self) -> None:
+ if self.max_grad_norm is not None:
+ self.model.clip_grad_norm_(self.max_grad_norm) # type: ignore
+
+ def get_model_state(self) -> Dict[str, torch.Tensor]:
+ return self.model.state_dict()
+
+ def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None:
+ self.model.load_state_dict(state_dict) # type: ignore
+
+ def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:
+ checkpoint_dir.mkdir(exist_ok=True)
+
+ def save_state(state: Dict[str, Any], name: str):
+ # only rank 0 writes any files
+ if self.train_config.worker_id != 0:
+ return
+
+ temp_state_file = tempfile.NamedTemporaryFile(
+ "w+b", dir=checkpoint_dir, delete=False, suffix=".pt"
+ )
+ try:
+ with Tqdm.wrapattr(
+ temp_state_file,
+ "write",
+ desc=f"Saving {name} state",
+ leave=False,
+ disable=not self.train_config.is_local_main_process,
+ ) as f:
+ torch.save(state, f)
+ temp_state_file.close()
+ os.replace(
+ temp_state_file.name,
+ checkpoint_dir / f"worker0_{name}.pt",
+ )
+ finally:
+ if os.path.exists(temp_state_file.name):
+ os.remove(temp_state_file.name)
+
+ with FSDP.state_dict_type(
+ self.model,
+ StateDictType.FULL_STATE_DICT,
+ FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
+ FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
+ ):
+ save_state(self.get_model_state(), "model")
+ save_state(FSDP.optim_state_dict(self.model, self.optimizer), "optimizer")
+ if self.lr_scheduler is not None:
+ save_state(self.lr_scheduler.state_dict(), "lr_scheduler")
+ if self.grad_scaler is not None:
+ save_state(self.grad_scaler.state_dict(), "grad_scaler")
+ save_state(client_state, "trainer")
+
+ def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]:
+ with FSDP.state_dict_type(
+ self.model,
+ StateDictType.FULL_STATE_DICT,
+ FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
+ FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
+ ):
+ if self.train_config.worker_id == 0:
+ model_state_dict = torch.load(checkpoint_dir / "worker0_model.pt")
+ optimizer_state_dict = torch.load(checkpoint_dir / "worker0_optimizer.pt")
+ else:
+ model_state_dict = {}
+ optimizer_state_dict = {}
+
+ self.load_model_state(model_state_dict)
+ optimizer_state_dict = FSDP.optim_state_dict_to_load(
+ optimizer_state_dict, self.model, self.optimizer
+ )
+ self.optimizer.load_state_dict(optimizer_state_dict)
+
+ # The states for LR scheduler, grad scaler, and trainer are identical on all workers, so we have
+ # all of them load the same files.
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.load_state_dict(
+ torch.load(checkpoint_dir / "worker0_lr_scheduler.pt")
+ )
+ if self.grad_scaler is not None:
+ self.grad_scaler.load_state_dict(
+ torch.load(checkpoint_dir / "worker0_grad_scaler.pt")
+ )
+
+ return torch.load(checkpoint_dir / "worker0_trainer.pt")
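Tying it together, a sketch of the `training_engine` portion of an experiment config for the new engine, written here as a Python dict (the jsonnet examples in this repo express the same structure; values are illustrative):

```python
training_engine = {
    "type": "torch::fsdp",
    "optimizer": {"type": "torch::AdamW", "lr": 1e-4},
    "amp": True,
    "fsdp_config": {
        # Recommended to match the `amp` setting above.
        "mixed_precision": True,
    },
}
```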
diff --git a/tango/integrations/torch/optim.py b/tango/integrations/torch/optim.py
index 54cc2c6b0..56ac7c133 100644
--- a/tango/integrations/torch/optim.py
+++ b/tango/integrations/torch/optim.py
@@ -73,11 +73,29 @@ class LRScheduler(torch.optim.lr_scheduler._LRScheduler, Registrable):
):
Optimizer.register("torch::" + name)(cls)
+if torch.__version__.startswith("1."):
+
+ def isLrSchedulerClass(cls):
+ return (
+ isinstance(cls, type)
+ and issubclass(cls, torch.optim.lr_scheduler._LRScheduler)
+ and not cls == torch.optim.lr_scheduler._LRScheduler
+ )
+
+elif torch.__version__.startswith("2."):
+
+ def isLrSchedulerClass(cls):
+ return (
+ isinstance(cls, type)
+ and issubclass(cls, torch.optim.lr_scheduler.LRScheduler)
+ and not cls == torch.optim.lr_scheduler.LRScheduler
+ and not cls == torch.optim.lr_scheduler._LRScheduler
+ )
+
+else:
+ raise Exception("Unknown version of PyTorch")
+
# Register all learning rate schedulers.
for name, cls in torch.optim.lr_scheduler.__dict__.items():
- if (
- isinstance(cls, type)
- and issubclass(cls, torch.optim.lr_scheduler._LRScheduler)
- and not cls == torch.optim.lr_scheduler._LRScheduler
- ):
+    if is_lr_scheduler_class(cls):
LRScheduler.register("torch::" + name)(cls)
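A small sanity check of the registration logic above: once the version-appropriate base class is used, schedulers resolve by name on both torch 1.x and 2.x (`torch::StepLR` is one of the names registered by this loop):

```python
import torch

from tango.integrations.torch import LRScheduler

# StepLR subclasses _LRScheduler on torch 1.x and LRScheduler on torch 2.x,
# so the registration loop picks it up either way.
scheduler_cls = LRScheduler.by_name("torch::StepLR")

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = scheduler_cls(optimizer, step_size=10)
```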
diff --git a/tango/integrations/transformers/__init__.py b/tango/integrations/transformers/__init__.py
index 6d9526717..449a25ce7 100644
--- a/tango/integrations/transformers/__init__.py
+++ b/tango/integrations/transformers/__init__.py
@@ -72,7 +72,8 @@
transformers::AutoModelForImageSegmentation::from_pretrained
transformers::AutoModelForInstanceSegmentation::from_config
transformers::AutoModelForInstanceSegmentation::from_pretrained
- transformers::AutoModelForMaskedImageModeling::from_config
+ transformers::AutoModelForMaskGeneration::from_config
+ transformers::AutoModelForMaskGeneration::from_pretrained
transformers::AutoModelForMaskedImageModeling::from_pretrained
transformers::AutoModelForMaskedLM::from_config
transformers::AutoModelForMaskedLM::from_pretrained
diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py
index 2afe4e667..90bd2d75c 100644
--- a/tango/integrations/transformers/finetune.py
+++ b/tango/integrations/transformers/finetune.py
@@ -424,7 +424,7 @@ def run( # type: ignore[override]
# Hacky way to deal with resizing the model embeddings.
model_params_dict = model._params.as_dict()
- if "fairscale" in model_params_dict["type"]:
+ if "fsdp" in model_params_dict["type"]:
model_params_dict["model"]["num_tokens"] = len(tokenizer) # type: ignore
else:
model_params_dict["num_tokens"] = len(tokenizer) # type: ignore
diff --git a/test_fixtures/integrations/fairscale/__init__.py b/test_fixtures/integrations/fairscale/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/test_fixtures/integrations/fairscale/components.py b/test_fixtures/integrations/torch/components.py
similarity index 100%
rename from test_fixtures/integrations/fairscale/components.py
rename to test_fixtures/integrations/torch/components.py
diff --git a/test_fixtures/integrations/fairscale/config.jsonnet b/test_fixtures/integrations/torch/fsdp_config.jsonnet
similarity index 94%
rename from test_fixtures/integrations/fairscale/config.jsonnet
rename to test_fixtures/integrations/torch/fsdp_config.jsonnet
index 468716b09..0e51e1544 100644
--- a/test_fixtures/integrations/fairscale/config.jsonnet
+++ b/test_fixtures/integrations/torch/fsdp_config.jsonnet
@@ -25,14 +25,13 @@ local learning_rate = 0.005;
local fsdp_config = {
- reshard_after_forward: true,
move_params_to_cpu: cpu_offloading,
move_grads_to_cpu: cpu_offloading,
mixed_precision: amp,
};
local training_engine = {
- type: "fairscale",
+ type: "forch::fsdp",
optimizer: {
type: "torch::AdamW",
lr: learning_rate,
@@ -60,7 +59,7 @@ local dataloader = {
trained_model: {
type: "torch::train",
model: {
- type: "fairscale::with_wrapped_modules",
+ type: "torch::with_wrapped_modules",
model: {
type: "simple_regression_model",
},
diff --git a/tests/integrations/fairscale/__init__.py b/tests/integrations/fairscale/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/integrations/fairscale/train_test.py b/tests/integrations/torch/fsdp_train_test.py
similarity index 64%
rename from tests/integrations/fairscale/train_test.py
rename to tests/integrations/torch/fsdp_train_test.py
index 522bbcec2..ca36347f9 100644
--- a/tests/integrations/fairscale/train_test.py
+++ b/tests/integrations/torch/fsdp_train_test.py
@@ -7,7 +7,7 @@
from tango.common.testing import TangoTestCase
-class TestFairScaleTrain(TangoTestCase):
+class TestFSDPTrain(TangoTestCase):
def setup_method(self):
super().setup_method()
initialize_logging(log_level="info")
@@ -16,11 +16,12 @@ def teardown_method(self):
teardown_logging()
@pytest.mark.parametrize(
- "fsdp",
- (
+ "fsdp,activation_checkpoint",
+ [
pytest.param(
True,
- id="fsdp=True",
+ False,
+ id="fsdp=True-checkpointing=False",
marks=[
pytest.mark.gpu,
pytest.mark.skipif(
@@ -28,15 +29,22 @@ def teardown_method(self):
),
],
),
- pytest.param(False, id="fsdp=False"),
- ),
- )
- @pytest.mark.parametrize(
- "activation_checkpoint",
- (
- pytest.param(True, id="checkpointing=True"),
- pytest.param(False, id="checkpointing=False"),
- ),
+ pytest.param(
+ True,
+ True,
+ id="fsdp=True-checkpointing=True",
+ marks=[
+ pytest.mark.gpu,
+ pytest.mark.skipif(
+ torch.cuda.device_count() < 2, reason="Requires CUDA devices"
+ ),
+ ],
+ ),
+ pytest.param(False, False, id="fsdp=False-checkpointing=False"),
+ # This last configuration will try to use DDP with checkpointing, which is not supported by torch.
+ # TODO: remove DDP and recommend just using FSDP for everything
+ # pytest.param(False, True, id="fsdp=False-checkpointing=True"),
+ ],
)
@pytest.mark.parametrize(
"amp",
@@ -68,8 +76,8 @@ def test_train_tiny_gpt2(self, fsdp: bool, activation_checkpoint: bool, amp: boo
},
}
if fsdp:
- training_engine["type"] = "fairscale"
- fsdp_config = {"reshard_after_forward": True, "mixed_precision": amp}
+ training_engine["type"] = "torch::fsdp"
+ fsdp_config = {"mixed_precision": amp}
training_engine["fsdp_config"] = fsdp_config
overrides["steps.trained_model.model.fsdp_config"] = fsdp_config
else:
@@ -77,8 +85,8 @@ def test_train_tiny_gpt2(self, fsdp: bool, activation_checkpoint: bool, amp: boo
overrides["steps.trained_model.model.fsdp_config"] = None
overrides["steps.trained_model.training_engine"] = training_engine
run_dir = self.run(
- self.FIXTURES_ROOT / "integrations" / "fairscale" / "config.jsonnet",
- include_package=["test_fixtures.integrations.fairscale.components"],
+ self.FIXTURES_ROOT / "integrations" / "torch" / "fsdp_config.jsonnet",
+ include_package=["test_fixtures.integrations.torch.components"],
overrides=overrides,
)
assert (run_dir / "trained_model").is_dir()