Bye bye torch <2 (#3331)
* Bye bye torch <1

* Add 2.6.0 dl args

* Rm require fsdp

* Adjust imports + 2.0 specific modeling code

* Bring back is_bf16
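
In practice, the commit deletes runtime version guards of the kind sketched below; with `torch>=2.0.0` pinned in `setup.py`, the guard is always satisfied and becomes dead code (a minimal illustration, not a line taken verbatim from the diff):

```python
import torch
import torch.nn as nn

from accelerate.utils import is_torch_version

model = nn.Linear(2, 2)

# Before this commit: torch 2.0-only features were guarded at runtime.
if not is_torch_version(">=", "2.0"):
    raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")

# After: torch >= 2.0 is guaranteed at install time, so the guard above is simply deleted.
model = torch.compile(model, backend="eager")  # "eager" backend keeps the example runnable anywhere
```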
muellerzr authored Jan 9, 2025
1 parent 58f1436 commit b13aadc
Showing 16 changed files with 43 additions and 119 deletions.
31 changes: 22 additions & 9 deletions docs/source/basic_tutorials/install.md
@@ -79,23 +79,36 @@ accelerate env

An example output is shown below, which describes two GPUs on a single machine with no mixed precision being used:


```bash
- `Accelerate` version: 0.11.0.dev0
- Platform: Linux-5.10.0-15-cloud-amd64-x86_64-with-debian-11.3
- Python version: 3.7.12
- Numpy version: 1.19.5
- PyTorch version (GPU?): 1.12.0+cu102 (True)
- `Accelerate` version: 1.2.0.dev0
- Platform: Linux-6.8.0-47-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/zach/miniconda3/envs/accelerate/bin/accelerate
- Python version: 3.10.13
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 187.91 GB
- GPU type: NVIDIA GeForce RTX 4090
- `Accelerate` default config:
- compute_environment: LOCAL_MACHINE
- distributed_type: MULTI_GPU
- mixed_precision: no
- use_cpu: False
- debug: False
- num_processes: 2
- machine_rank: 0
- num_machines: 1
- main_process_ip: None
- main_process_port: None
- gpu_ids: all
- rdzv_backend: static
- same_network: True
- main_training_function: main
- deepspeed_config: {}
- fsdp_config: {}
- enable_cpu_affinity: False
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
```
5 changes: 1 addition & 4 deletions docs/source/usage_guides/mps.md
@@ -44,10 +44,7 @@ accelerate launch /examples/cv_example.py --data_dir images

## A few caveats to be aware of

1. We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing) on your MacOS machine.
It has major fixes related to model correctness and performance improvements for transformer based models.
Please refer to https://github.com/pytorch/pytorch/issues/82707 for more details.
2. Distributed setups `gloo` and `nccl` are not working with `mps` device.
1. Distributed setups `gloo` and `nccl` do not work with the `mps` device.
This means that currently only a single GPU of the `mps` device type can be used.

Finally, please remember that `Accelerate` only integrates the MPS backend, therefore if you
4 changes: 2 additions & 2 deletions docs/source/usage_guides/sagemaker.md
@@ -145,8 +145,8 @@ image_uri: null
mixed_precision: fp16
num_machines: 1
profile: xxxxx
py_version: py38
pytorch_version: 1.10.2
py_version: py10
pytorch_version: 2.5.0
region: us-east-1
transformers_version: 4.17.0
use_cpu: false
2 changes: 1 addition & 1 deletion setup.py
@@ -75,7 +75,7 @@
"packaging>=20.0",
"psutil",
"pyyaml",
"torch>=1.10.0",
"torch>=2.0.0",
"huggingface_hub>=0.21.0",
"safetensors>=0.4.3",
],
2 changes: 0 additions & 2 deletions src/accelerate/accelerator.py
@@ -1608,8 +1608,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
if not is_torch_version(">=", "2.0"):
raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model

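For reference, a minimal sketch (outside of Accelerate) of what the now-unguarded call amounts to; the `backend` value is illustrative, and in the real code `self.state.dynamo_plugin.to_kwargs()` supplies these kwargs:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# Roughly what prepare_model now does unconditionally: torch.compile is assumed
# to exist, and the dynamo plugin's kwargs are forwarded to it.
dynamo_kwargs = {"backend": "eager"}  # illustrative stand-in for dynamo_plugin.to_kwargs()
compiled = torch.compile(model, **dynamo_kwargs)

print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 1])
```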
4 changes: 1 addition & 3 deletions src/accelerate/big_modeling.py
@@ -41,7 +41,6 @@
is_mlu_available,
is_musa_available,
is_npu_available,
is_torch_version,
is_xpu_available,
load_checkpoint_in_model,
offload_state_dict,
@@ -114,8 +113,7 @@ def init_on_device(device: torch.device, include_buffers: bool = None):
if include_buffers is None:
include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)

# TODO(shingjan): remove the torch version check once older versions are deprecated
if is_torch_version(">=", "2.0") and include_buffers:
if include_buffers:
with device:
yield
return
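For context, a minimal standalone sketch of the torch >= 2.0 behavior the simplified branch relies on: a `torch.device` works as a context manager, so parameters and buffers created inside the block land on that device.

```python
import torch
import torch.nn as nn

# torch >= 2.0: tensor and module creation inside the block targets the given device.
with torch.device("meta"):
    module = nn.BatchNorm1d(4)

print(module.weight.device)        # meta (parameter)
print(module.running_mean.device)  # meta (buffer, covered because include_buffers uses this path)
```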
3 changes: 1 addition & 2 deletions src/accelerate/commands/launch.py
@@ -44,7 +44,6 @@
is_npu_available,
is_rich_available,
is_sagemaker_available,
is_torch_version,
is_torch_xla_available,
is_xpu_available,
patch_environment,
@@ -1055,7 +1054,7 @@ def _validate_launch_command(args):
mp_from_config_flag = True
else:
if args.use_cpu or (args.use_xpu and torch.xpu.is_available()):
native_amp = is_torch_version(">=", "1.10")
native_amp = True
else:
native_amp = is_bf16_available(True)
if (
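A small illustration of what `native_amp = True` enables here: torch's native AMP has shipped with every release Accelerate now supports, so only bf16 support still needs probing (the device and dtype below are illustrative):

```python
import torch

model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

# Native AMP is always available on torch >= 2.0; on CPU, bf16 is the supported autocast dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)  # torch.bfloat16
```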
14 changes: 7 additions & 7 deletions src/accelerate/data_loader.py
@@ -39,7 +39,7 @@

logger = get_logger(__name__)

# kwargs of the DataLoader in min version 1.4.0.
# kwargs of the DataLoader in min version 2.0
_PYTORCH_DATALOADER_KWARGS = {
"batch_size": 1,
"shuffle": False,
@@ -55,10 +55,11 @@
"generator": None,
"prefetch_factor": 2,
"persistent_workers": False,
"pin_memory_device": "",
}

# kwargs added after by version
_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {}
_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {"2.6.0": {"in_order": True}}

for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items():
if is_torch_version(">=", v):
@@ -718,12 +719,11 @@ def __init__(
**kwargs,
):
shuffle = False
if is_torch_version(">=", "1.11.0"):
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe

# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.split_batches = split_batches
if shuffle:
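A minimal sketch of how the version-gated defaults above are folded into the base kwargs: `in_order` only becomes a known default once the installed torch is at least 2.6.0 (the standalone dicts below are trimmed stand-ins for the module-level ones):

```python
from accelerate.utils import is_torch_version

# Trimmed stand-in for _PYTORCH_DATALOADER_KWARGS (defaults valid on every torch >= 2.0).
dataloader_kwargs = {"batch_size": 1, "shuffle": False, "pin_memory_device": ""}

# Stand-in for _PYTORCH_DATALOADER_ADDITIONAL_KWARGS: defaults that exist only from a given torch version.
additional_kwargs = {"2.6.0": {"in_order": True}}

# Same merge loop as in data_loader.py: adopt each block once the installed torch is new enough.
for version, extra in additional_kwargs.items():
    if is_torch_version(">=", version):
        dataloader_kwargs.update(extra)

print("in_order" in dataloader_kwargs)  # True only on torch >= 2.6.0
```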
7 changes: 0 additions & 7 deletions src/accelerate/test_utils/testing.py
@@ -332,13 +332,6 @@ def require_deepspeed(test_case):
return unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")(test_case)


def require_fsdp(test_case):
"""
Decorator marking a test that requires FSDP installed. These tests are skipped when FSDP isn't installed
"""
return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case)


def require_torch_min_version(test_case=None, version=None):
"""
Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an
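For reference, a small sketch of the `unittest.skipUnless` pattern these decorators share; with torch >= 2.0 required at install time, the old FSDP gate (torch >= 1.12) could never skip anything, hence its removal. The decorator name below is hypothetical:

```python
import unittest

from accelerate.utils import is_torch_version


def require_torch_2_6(test_case):
    # Hypothetical decorator, same shape as the ones kept in testing.py:
    # skip the test unless the installed torch is new enough.
    return unittest.skipUnless(is_torch_version(">=", "2.6.0"), "test requires torch >= 2.6.0")(test_case)


class ExampleTest(unittest.TestCase):
    @require_torch_2_6
    def test_in_order_default(self):
        self.assertTrue(True)
```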
3 changes: 0 additions & 3 deletions src/accelerate/utils/imports.py
@@ -388,9 +388,6 @@ def is_xpu_available(check_device=False):
return False

if is_ipex_available():
if is_torch_version("<=", "1.12"):
return False

import intel_extension_for_pytorch # noqa: F401
else:
if is_torch_version("<=", "2.3"):
66 changes: 3 additions & 63 deletions src/accelerate/utils/modeling.py
@@ -23,7 +23,7 @@
import tempfile
import warnings
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -544,64 +544,6 @@ def check_tied_parameters_on_same_device(tied_params, device_map):
)


def _get_named_modules(
module: torch.nn.Module,
memo: Optional[Set[torch.nn.Module]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
"""
Return an iterator over all modules in the network, yielding both the name of the module as well as the module
itself. Copied from PyTorch `torch.nn.Module.named_modules` for compatability with torch < 2.0 versions with
`remove_duplicate` option added.
Args:
memo (set of `torch.nn.Module`, *optional*):
A memo to store the set of modules already added to the result
prefix (`str`, *optional*):
A prefix that will be added to the name of the module
remove_duplicate (`bool`, *optional*):
Whether to remove the duplicated module instances in the result or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following example, ``l`` will be returned only once.
"""
if memo is None:
memo = set()
if module not in memo:
if remove_duplicate:
memo.add(module)
yield prefix, module
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from _get_named_modules(sub_module, memo, submodule_prefix, remove_duplicate)


def _get_named_parameters(module: torch.nn.Module, prefix="", recurse=True, remove_duplicate: bool = True):
"""
Help yield various names + members of modules. Copied from PyTorch `torch.nn.Module.named_modules` for
compatability with torch < 2.0 versions with `remove_duplicate` option added.
"""
memo = set()
modules = (
_get_named_modules(module, prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, module)]
)
for module_prefix, module in modules:
members = module._parameters.items()
for k, v in members:
if v is None or v in memo:
continue
if remove_duplicate:
memo.add(v)
name = module_prefix + ("." if module_prefix else "") + k
yield name, v


def find_tied_parameters(model: torch.nn.Module, **kwargs):
"""
Find the tied parameters in a given model.
@@ -633,13 +575,11 @@ def find_tied_parameters(model: torch.nn.Module, **kwargs):
"""

# get ALL model parameters and their names
all_named_parameters = {name: param for name, param in _get_named_parameters(model, remove_duplicate=False)}
all_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=False)}

# get ONLY unique named parameters,
# if parameter is tied and have multiple names, it will be included only once
no_duplicate_named_parameters = {
name: param for name, param in _get_named_parameters(model, remove_duplicate=True)
}
no_duplicate_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=True)}

# the difference of the two sets will give us the tied parameters
tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
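For context, a standalone sketch of the simplification: `nn.Module.named_parameters` accepts `remove_duplicate` on torch >= 2.0, so the copied helpers are no longer needed and the tied weights fall out of a set difference (the toy model below is illustrative):

```python
import torch.nn as nn


class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.decoder = nn.Linear(4, 10, bias=False)
        self.decoder.weight = self.embed.weight  # tie the two weights


model = TiedModel()

all_names = {name for name, _ in model.named_parameters(remove_duplicate=False)}
unique_names = {name for name, _ in model.named_parameters(remove_duplicate=True)}

print(all_names - unique_names)  # {'decoder.weight'}
```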
6 changes: 1 addition & 5 deletions src/accelerate/utils/operations.py
@@ -32,7 +32,6 @@
is_torch_xla_available,
is_xpu_available,
)
from .versions import is_torch_version


if is_torch_xla_available():
@@ -320,10 +319,7 @@ def _tpu_gather_one(tensor):

def _gpu_gather(tensor):
state = PartialState()
if is_torch_version(">=", "1.13"):
gather_op = torch.distributed.all_gather_into_tensor
else:
gather_op = torch.distributed._all_gather_base
gather_op = torch.distributed.all_gather_into_tensor

def _gpu_gather_one(tensor):
if tensor.ndim == 0:
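For context, a minimal multi-process sketch of the gather primitive now used unconditionally; `torch.distributed.all_gather_into_tensor` replaced the private `_all_gather_base` in torch 1.13, so it exists on every supported release (assumes an NCCL setup with one GPU per process):

```python
# Launch with, e.g.: torchrun --nproc_per_node=2 gather_demo.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Each rank contributes a small tensor; the output holds world_size chunks stacked along dim 0.
local = torch.full((2,), float(rank), device="cuda")
output = torch.empty(dist.get_world_size() * local.numel(), device="cuda")
dist.all_gather_into_tensor(output, local)

print(f"rank {rank}: {output.tolist()}")  # e.g. [0.0, 0.0, 1.0, 1.0] on both ranks
dist.destroy_process_group()
```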
2 changes: 1 addition & 1 deletion src/accelerate/utils/other.py
@@ -54,7 +54,7 @@ def is_compiled_module(module):
"""
Check whether the module was compiled with torch.compile()
"""
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
if not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)

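A quick standalone check of what the simplified helper relies on: torch >= 2.0 always ships `torch._dynamo`, and `torch.compile` wraps a module in `OptimizedModule`, so the isinstance test alone is enough.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model, backend="eager")  # "eager" keeps this runnable without a GPU/Triton

print(isinstance(model, torch._dynamo.eval_frame.OptimizedModule))     # False
print(isinstance(compiled, torch._dynamo.eval_frame.OptimizedModule))  # True
```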
3 changes: 0 additions & 3 deletions tests/fsdp/test_fsdp.py
@@ -29,7 +29,6 @@
execute_subprocess_async,
get_launch_command,
path_in_accelerate_package,
require_fsdp,
require_multi_device,
require_non_cpu,
require_non_torch_xla,
@@ -55,7 +54,6 @@
dtypes = [FP16, BF16]


@require_fsdp
@require_non_cpu
@require_non_torch_xla
class FSDPPluginIntegration(AccelerateTestCase):
@@ -290,7 +288,6 @@ def test_cpu_ram_efficient_loading(self):

# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
@require_non_torch_xla
@require_fsdp
@require_multi_device
@slow
class FSDPIntegrationTest(TempDirTestCase):
7 changes: 3 additions & 4 deletions tests/test_big_modeling.py
@@ -43,7 +43,7 @@
slow,
torch_device,
)
from accelerate.utils import is_torch_version, offload_state_dict
from accelerate.utils import offload_state_dict


logger = logging.getLogger(__name__)
@@ -166,9 +166,8 @@ def test_init_empty_weights(self):
with init_empty_weights(include_buffers=True):
module = nn.BatchNorm1d(4)
# nn.Module.register_parameter/buffer shouldn't be changed with torch >= 2.0
if is_torch_version(">=", "2.0"):
assert register_parameter_func == nn.Module.register_parameter
assert register_buffer_func == nn.Module.register_buffer
assert register_parameter_func == nn.Module.register_parameter
assert register_buffer_func == nn.Module.register_buffer
assert module.weight.device == torch.device("meta")
assert module.running_mean.device == torch.device("meta")

3 changes: 0 additions & 3 deletions tests/test_utils.py
@@ -195,7 +195,6 @@ def test_can_undo_fp16_conversion(self):

@require_triton
@require_non_cpu
@require_torch_min_version(version="2.0")
def test_dynamo(self):
model = RegressionModel()
model._original_forward = model.forward
@@ -239,7 +238,6 @@ def nested_wrap(model):
for original_key, new_key in zip(orig_state_dict_keys, unwrapped_state_dict_keys):
assert original_key == new_key, f"Keys did not align: {original_key} != {new_key}"

@require_torch_min_version(version="2.0")
def test_dynamo_extract_model_keep_torch_compile(self):
model = RegressionModel()
compiled_model = torch.compile(model)
@@ -251,7 +249,6 @@ def test_dynamo_extract_model_keep_torch_compile(self):

assert compiled_model._orig_mod == compiled_model_unwrapped._orig_mod

@require_torch_min_version(version="2.0")
def test_dynamo_extract_model_remove_torch_compile(self):
model = RegressionModel()
compiled_model = torch.compile(model)
