From de5f40616b8111b2177e880b9965e02a6822a4f0 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 23 Oct 2024 11:34:46 +0000 Subject: [PATCH] 2024-10-23 nightly release (1a57ce124b5a4d508776dc1f0ff25bd8c5466fb2) --- .github/scripts/validate_binaries.sh | 11 ++-- .github/workflows/release_build.yml | 36 +++++------ .github/workflows/unittest_ci.yml | 16 +++++ .github/workflows/validate-binaries.yml | 2 +- README.MD | 9 +++ docs/source/index.rst | 2 +- docs/source/setup-torchrec.rst | 6 +- torchrec/distributed/comm_ops.py | 20 +++--- torchrec/distributed/planner/enumerators.py | 9 ++- .../planner/tests/test_proposers.py | 22 +++++++ .../planner/tests/test_shard_estimators.py | 32 +++++++++- .../test_utils/test_model_parallel.py | 64 ++++++++++++++----- torchrec/distributed/train_pipeline/utils.py | 10 ++- 13 files changed, 176 insertions(+), 63 deletions(-) diff --git a/.github/scripts/validate_binaries.sh b/.github/scripts/validate_binaries.sh index cbd7eed88..22009dac3 100755 --- a/.github/scripts/validate_binaries.sh +++ b/.github/scripts/validate_binaries.sh @@ -21,13 +21,8 @@ if [[ ${MATRIX_GPU_ARCH_TYPE} = 'rocm' ]]; then exit 0 fi -if [[ ${MATRIX_PYTHON_VERSION} = '3.12' ]]; then - echo "Temporarily disable validation for Python 3.12" - exit 0 -fi - if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then - export CUDA_VERSION="cu118" + export CUDA_VERSION="cu124" else export CUDA_VERSION="cpu" fi @@ -36,8 +31,10 @@ fi if [[ ${MATRIX_GPU_ARCH_TYPE} = 'cuda' ]]; then if [[ ${MATRIX_GPU_ARCH_VERSION} = '11.8' ]]; then export CUDA_VERSION="cu118" - else + elif [[ ${MATRIX_GPU_ARCH_VERSION} = '12.1' ]]; then export CUDA_VERSION="cu121" + else + export CUDA_VERSION="cu124" fi else export CUDA_VERSION="cpu" diff --git a/.github/workflows/release_build.yml b/.github/workflows/release_build.yml index 4dd841c02..8be8ea628 100644 --- a/.github/workflows/release_build.yml +++ b/.github/workflows/release_build.yml @@ -18,37 +18,32 @@ jobs: strategy: matrix: include: - - os: linux.2xlarge - python-version: 3.8 - python-tag: "py38" - cuda-tag: "cu121" - os: linux.2xlarge python-version: 3.9 python-tag: "py39" - cuda-tag: "cu121" + cuda-tag: "cu124" - os: linux.2xlarge python-version: '3.10' python-tag: "py310" - cuda-tag: "cu121" + cuda-tag: "cu124" - os: linux.2xlarge python-version: '3.11' python-tag: "py311" - cuda-tag: "cu121" + cuda-tag: "cu124" - os: linux.2xlarge python-version: '3.12' python-tag: "py312" - cuda-tag: "cu121" + cuda-tag: "cu124" steps: # Checkout the repository to the GitHub Actions runner - name: Check ldd --version run: ldd --version - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Update pip run: | sudo yum update -y sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - name: Setup conda run: | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh @@ -73,12 +68,12 @@ jobs: - name: Install PyTorch and CUDA shell: bash run: | - conda run -n build_binary pip install torch --index-url https://download.pytorch.org/whl/test/cu121 + conda run -n build_binary pip install torch - name: Install fbgemm shell: bash run: | conda run -n build_binary pip install numpy - conda run -n build_binary pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/test/cu121 + conda run -n build_binary pip install fbgemm-gpu - name: Install Dependencies shell: bash run: | @@ -102,7 +97,7 @@ jobs: python setup.py bdist_wheel \ --python-tag=${{ matrix.python-tag }} - name: Upload wheel as GHA artifact - uses: 
actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl path: dist/torchrec-*.whl @@ -112,9 +107,9 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [linux.4xlarge.nvidia.gpu] - python-version: [3.8, 3.9, "3.10", "3.11", "3.12"] - cuda-tag: ["cu121"] + os: [linux.g5.12xlarge.nvidia.gpu] + python-version: [3.9, "3.10", "3.11", "3.12"] + cuda-tag: ["cu124"] needs: build_on_cpu # the glibc version should match the version of the one we used to build the binary # for this case, it's 2.26 @@ -149,12 +144,11 @@ jobs: sudo lshw -C display # Checkout the repository to the GitHub Actions runner - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Update pip run: | sudo yum update -y sudo yum -y install git python3-pip - sudo pip3 install --upgrade pip - name: Setup conda run: | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh @@ -179,19 +173,19 @@ jobs: - name: Install PyTorch and CUDA shell: bash run: | - conda run -n build_binary pip install torch --index-url https://download.pytorch.org/whl/test/cu121 + conda run -n build_binary pip install torch # download wheel from GHA - name: Install fbgemm shell: bash run: | conda run -n build_binary pip install numpy - conda run -n build_binary pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/test/cu121 + conda run -n build_binary pip install fbgemm-gpu - name: Install torchmetrics shell: bash run: | conda run -n build_binary pip install torchmetrics==1.0.3 - name: Download wheel - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v4 with: name: torchrec_${{ matrix.python-version }}_${{ matrix.cuda-tag }}.whl - name: Display structure of downloaded files diff --git a/.github/workflows/unittest_ci.yml b/.github/workflows/unittest_ci.yml index 2e120457a..8865acee4 100644 --- a/.github/workflows/unittest_ci.yml +++ b/.github/workflows/unittest_ci.yml @@ -23,6 +23,10 @@ jobs: python-version: 3.9 python-tag: "py39" cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: 3.9 + python-tag: "py39" + cuda-tag: "cu124" - os: linux.g5.12xlarge.nvidia.gpu python-version: '3.10' python-tag: "py310" @@ -31,6 +35,10 @@ jobs: python-version: '3.10' python-tag: "py310" cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.10' + python-tag: "py310" + cuda-tag: "cu124" - os: linux.g5.12xlarge.nvidia.gpu python-version: '3.11' python-tag: "py311" @@ -39,6 +47,10 @@ jobs: python-version: '3.11' python-tag: "py311" cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.11' + python-tag: "py311" + cuda-tag: "cu124" - os: linux.g5.12xlarge.nvidia.gpu python-version: '3.12' python-tag: "py312" @@ -47,6 +59,10 @@ jobs: python-version: '3.12' python-tag: "py312" cuda-tag: "cu121" + - os: linux.g5.12xlarge.nvidia.gpu + python-version: '3.12' + python-tag: "py312" + cuda-tag: "cu124" uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: runner: ${{ matrix.os }} diff --git a/.github/workflows/validate-binaries.yml b/.github/workflows/validate-binaries.yml index 7f3cd8f69..248857214 100644 --- a/.github/workflows/validate-binaries.yml +++ b/.github/workflows/validate-binaries.yml @@ -16,7 +16,7 @@ on: workflow_dispatch: inputs: channel: - description: "Channel to use (nightly, release, test)" + description: "Channel to use (nightly, release, test, pypi)" required: true type: choice options: diff --git a/README.MD 
b/README.MD index 1f90acdf0..44fc026f6 100644 --- a/README.MD +++ b/README.MD @@ -8,6 +8,7 @@ TorchRec has been used to accelerate advancements in recommendation systems, som * [Disaggregated Multi-Tower: Topology-aware Modeling Technique for Efficient Large-Scale Recommendation](https://arxiv.org/abs/2403.00877) paper * [The Algorithm ML](https://github.com/twitter/the-algorithm-ml) from Twitter * [Training Recommendation Models with Databricks](https://docs.databricks.com/en/machine-learning/train-recommender-models.html) +* [Toward 100TB model with Embedding Offloading Paper](https://dl.acm.org/doi/10.1145/3640457.3688037) ## Introduction @@ -39,6 +40,10 @@ Check out the [Getting Started](https://pytorch.org/torchrec/setup-torchrec.html 1. Install pytorch. See [pytorch documentation](https://pytorch.org/get-started/locally/). ``` + CUDA 12.4 + + pip install torch --index-url https://download.pytorch.org/whl/nightly/cu124 + CUDA 12.1 pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 @@ -60,6 +65,10 @@ Check out the [Getting Started](https://pytorch.org/torchrec/setup-torchrec.html 3. Install FBGEMM. ``` + CUDA 12.4 + + pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu124 + CUDA 12.1 pip install fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu121 diff --git a/docs/source/index.rst b/docs/source/index.rst index bcc8ed67e..c6fa49282 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -61,7 +61,7 @@ If you are interested in helping improve the TorchRec project, here is how you can contribute: 1. **Visit Our** `GitHub Repository `__: - There yoou can find the source code, issues, and ongoing projects. + There you can find the source code, issues, and ongoing projects. 1. **Submit Feedback or Issues**: If you encounter any bugs or have suggestions for improvements, please submit an issue through the diff --git a/docs/source/setup-torchrec.rst b/docs/source/setup-torchrec.rst index 65d6ff530..7a2cbf969 100644 --- a/docs/source/setup-torchrec.rst +++ b/docs/source/setup-torchrec.rst @@ -23,7 +23,7 @@ Below demonstrates the compatability matrix that is currently tested: * - Python Version - 3.9, 3.10, 3.11, 3.12 * - Compute Platform - - CPU, CUDA 11.8, CUDA 12.1 + - CPU, CUDA 11.8, CUDA 12.1, CUDA 12.4 Aside from those requirements, TorchRec's core dependencies are PyTorch and FBGEMM. If your system is compatible with both libraries generally, then it should be sufficient for TorchRec. @@ -50,7 +50,7 @@ Therefore, specific versions of TorchRec and FBGEMM should correspond to a speci Installation ------------ -Below we show installations for CUDA 12.1 as an example. For CPU or CUDA 11.8, swap ``cu121`` for ``cpu`` or ``cu118``. +Below we show installations for CUDA 12.1 as an example. For CPU, CUDA 11.8, or CUDA 12.4, swap ``cu121`` for ``cpu``, ``cu118``, or ``cu124`` respectively. .. tab-set:: @@ -63,7 +63,7 @@ Below we show installations for CUDA 12.1 as an example. For CPU or CUDA 11.8, s pip install torchmetrics==1.0.3 pip install torchrec --index-url https://download.pytorch.org/whl/cu121 - .. tab-item:: **Stable via PyPI (Only for CUDA 12.1)** + .. tab-item:: **Stable via PyPI (Only for CUDA 12.4)** .. 
code-block:: bash diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 34f2091e8..031804937 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -2153,8 +2153,8 @@ def forward( if rsi.codecs is not None: inputs = rsi.codecs.forward.encode(inputs) output = inputs.new_empty((inputs.size(0) // my_size, inputs.size(1))) - with record_function("## reduce_scatter_base ##"): - req = dist._reduce_scatter_base( + with record_function("## reduce_scatter_tensor ##"): + req = dist.reduce_scatter_tensor( output, inputs, group=pg, @@ -2222,7 +2222,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: grad_output = rsi.codecs.backward.encode(grad_output) grad_inputs = grad_output.new_empty(rsi.input_sizes) with record_function("## reduce_scatter_base_bw (all_gather) ##"): - req = dist._all_gather_base( + req = dist.all_gather_into_tensor( grad_inputs, grad_output.contiguous(), group=ctx.pg, @@ -2250,8 +2250,8 @@ def forward( input = agi.codecs.forward.encode(input) outputs = input.new_empty((input.size(0) * my_size, input.size(1))) - with record_function("## all_gather_base ##"): - req = dist._all_gather_base( + with record_function("## all_gather_into_tensor ##"): + req = dist.all_gather_into_tensor( outputs, input, group=pg, @@ -2319,7 +2319,7 @@ def backward(ctx, grad_outputs: Tensor) -> Tuple[None, None, Tensor]: grad_outputs = agi.codecs.backward.encode(grad_outputs) grad_input = grad_outputs.new_empty(agi.input_size) with record_function("## all_gather_base_bw (reduce_scatter) ##"): - req = dist._reduce_scatter_base( + req = dist.reduce_scatter_tensor( grad_input, grad_outputs.contiguous(), group=ctx.pg, @@ -2349,11 +2349,11 @@ def forward( output = input.new_empty(rsi.input_sizes[my_rank]) - # Use dist._reduce_scatter_base when a vector reduce-scatter is not needed + # Use dist.reduce_scatter_tensor when a vector reduce-scatter is not needed # else use dist.reduce_scatter which internally supports vector reduce-scatter if rsi.equal_splits: - with record_function("## reduce_scatter_base ##"): - req = dist._reduce_scatter_base( + with record_function("## reduce_scatter_tensor ##"): + req = dist.reduce_scatter_tensor( output, input, group=pg, @@ -2434,7 +2434,7 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]: if rsi.equal_splits: with record_function("## reduce_scatter_base_bw (all_gather) ##"): - req = dist._all_gather_base( + req = dist.all_gather_into_tensor( grad_input, grad_output.contiguous(), group=ctx.pg, diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 4845a52a8..66ea9ee2d 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -235,11 +235,16 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None: def _filter_sharding_types( self, name: str, allowed_sharding_types: List[str] ) -> List[str]: + # GRID_SHARD is only supported if specified by the user in parameter constraints if not self._constraints or not self._constraints.get(name): - return allowed_sharding_types + return [ + t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value + ] constraints: ParameterConstraints = self._constraints[name] if not constraints.sharding_types: - return allowed_sharding_types + return [ + t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value + ] constrained_sharding_types: List[str] = constraints.sharding_types 
filtered_sharding_types = list( diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index e9ca17905..ca0612e0f 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -111,6 +111,7 @@ def setUp(self) -> None: self.uniform_proposer = UniformProposer() self.grid_search_proposer = GridSearchProposer() self.dynamic_programming_proposer = DynamicProgrammingProposer() + self._sharding_types = [x.value for x in ShardingType] def test_greedy_two_table(self) -> None: tables = [ @@ -127,6 +128,17 @@ def test_greedy_two_table(self) -> None: feature_names=["feature_1"], ), ] + """
+        GRID_SHARD is only available if specified by the user in parameter constraints; however,
+        adding parameter constraints does not work here because of the non-deterministic nature of
+        the _filter_sharding_types (set & set) operation when constraints are present. This means
+        the greedy proposer will have a different order of sharding types on each test invocation,
+        which we cannot have a hardcoded "correct" answer for. We mock the call to _filter_sharding_types
+        to ensure the order of the sharding types list is always the same.
+        """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) search_space = self.enumerator.enumerate( @@ -335,6 +347,16 @@ def test_grid_search_three_table(self) -> None: for i in range(1, 4) ] model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) + """
+        GRID_SHARD is only available if specified by the user in parameter constraints; however,
+        adding parameter constraints does not work here because of the non-deterministic nature of
+        the _filter_sharding_types (set & set) operation when constraints are present, so we mock the
+        call to _filter_sharding_types to ensure the order of the sharding types list is always
+        the same.
+        """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) search_space = self.enumerator.enumerate( module=model, sharders=[ diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index ce52702cb..de97f9d35 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -11,7 +11,7 @@ import unittest from typing import cast, Dict, List, Tuple -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import torch import torchrec.optim as trec_optim @@ -59,6 +59,7 @@ def setUp(self) -> None: self.enumerator = EmbeddingEnumerator( topology=self.topology, batch_size=BATCH_SIZE, estimator=self.estimator ) + self._sharding_types = [x.value for x in ShardingType] def test_1_table_perf(self) -> None: tables = [ @@ -70,6 +71,16 @@ def test_1_table_perf(self) -> None: ) ] model = TestSparseNN(tables=tables, weighted_tables=[]) + """
+        GRID_SHARD is only available if specified by the user in parameter constraints; however,
+        adding parameter constraints does not work here because of the non-deterministic nature of
+        the _filter_sharding_types (set & set) operation when constraints are present, so we mock the
+        call to _filter_sharding_types to ensure the order of the sharding types list is always
+        the same. 
+ """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) sharding_options = self.enumerator.enumerate( module=model, sharders=[ @@ -321,6 +332,17 @@ def test_1_table_perf_with_fp8_comm(self) -> None: ) ) + """
+        GRID_SHARD is only available if specified by the user in parameter constraints; however,
+        adding parameter constraints does not work here because of the non-deterministic nature of
+        the _filter_sharding_types (set & set) operation when constraints are present, so we mock the
+        call to _filter_sharding_types to ensure the order of the sharding types list is always
+        the same.
+        """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) + sharding_options = self.enumerator.enumerate( module=model, sharders=[ @@ -530,6 +552,14 @@ def cacheability(self) -> float: estimator=self.estimator, constraints=constraints, ) + """
+        GRID_SHARD is only available if specified by the user in parameter constraints; however,
+        adding parameter constraints does not work here because of the non-deterministic nature of
+        the _filter_sharding_types (set & set) operation when constraints are present, so we mock the
+        call to _filter_sharding_types to ensure the order of the sharding types list is always
+        the same.
+        """ + enumerator._filter_sharding_types = MagicMock(return_value=self._sharding_types) model = TestSparseNN(tables=tables, weighted_tables=[]) sharding_options = enumerator.enumerate( module=model, diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index c9bc22ab3..372eb6c75 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -725,14 +725,30 @@ def test_sharding_grid( backend=self.backend, qcomms_config=qcomms_config, constraints={ - "table_0": ParameterConstraints(min_partition=8), - "table_1": ParameterConstraints(min_partition=12), - "table_2": ParameterConstraints(min_partition=16), - "table_3": ParameterConstraints(min_partition=20), - "table_4": ParameterConstraints(min_partition=8), - "table_5": ParameterConstraints(min_partition=12), - "weighted_table_0": ParameterConstraints(min_partition=8), - "weighted_table_1": ParameterConstraints(min_partition=12), + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=16, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=20, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), }, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, pooling=pooling, @@ -800,14 +816,30 @@ def test_sharding_grid_8gpu( backend=self.backend, qcomms_config=qcomms_config, constraints={ - "table_0": ParameterConstraints(min_partition=8), - "table_1": ParameterConstraints(min_partition=12), - "table_2": ParameterConstraints(min_partition=8), - "table_3": 
ParameterConstraints(min_partition=10), - "table_4": ParameterConstraints(min_partition=4), - "table_5": ParameterConstraints(min_partition=6), - "weighted_table_0": ParameterConstraints(min_partition=2), - "weighted_table_1": ParameterConstraints(min_partition=3), + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=10, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value] + ), }, apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, pooling=pooling, diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index db6182caa..69058b0f6 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -31,9 +31,17 @@ Union, ) +import torch from torch import distributed as dist -from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +if not torch._running_with_deploy(): + from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +else: + + class FSDP2: + pass + + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.fx.immutable_collections import ( immutable_dict as fx_immutable_dict,
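---

Note on the comm_ops.py hunks above: dist._reduce_scatter_base and dist._all_gather_base
are the deprecated private spellings of the public collectives, and the patch moves to
dist.reduce_scatter_tensor and dist.all_gather_into_tensor with unchanged call shapes.
Below is a minimal sketch of the new calls, assuming an already-initialized process
group; the demo function name, tensor shapes, and async usage are illustrative only and
not part of the patch:

    import torch
    import torch.distributed as dist

    def demo_collectives() -> None:
        world_size = dist.get_world_size()
        inputs = torch.ones(4 * world_size, 8)

        # reduce_scatter_tensor: every rank contributes `inputs` and receives one
        # reduced shard of inputs.size(0) // world_size rows, matching the output
        # allocation in the forward paths above.
        output = inputs.new_empty((inputs.size(0) // world_size, inputs.size(1)))
        work = dist.reduce_scatter_tensor(output, inputs, async_op=True)
        work.wait()

        # all_gather_into_tensor: the inverse shape transform, which the patch
        # uses in the reduce-scatter backward pass.
        gathered = output.new_empty((output.size(0) * world_size, output.size(1)))
        dist.all_gather_into_tensor(gathered, output)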
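Note on the train_pipeline/utils.py hunk above: the patch guards the FSDP2 import behind
torch._running_with_deploy(), presumably because torch.distributed._composable is not
importable inside the torch::deploy interpreter, and substitutes an empty placeholder
class there. A short sketch of why the placeholder is sufficient; uses_fsdp2 is a
hypothetical helper, not part of the patch:

    import torch

    # Unavailable under torch::deploy, so fall back to an empty stand-in class.
    if not torch._running_with_deploy():
        from torch.distributed._composable.fsdp.fully_shard import (
            FSDPModule as FSDP2,
        )
    else:

        class FSDP2:
            pass

    def uses_fsdp2(module: torch.nn.Module) -> bool:
        # isinstance against the placeholder is always False under deploy, so
        # FSDP2-specific branches simply become no-ops there.
        return isinstance(module, FSDP2)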