Refactor scaler to util (#3142)
* Refactor scaler to util

* Document

* Use the distributed_type directly
muellerzr authored Oct 8, 2024
1 parent 506d732 commit fb68cb9
Showing 7 changed files with 47 additions and 38 deletions.
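
In short, the per-call-site torch-version and backend branching used to create a GradScaler is replaced by one shared helper, get_grad_scaler(), exported from accelerate.utils. A minimal sketch of the before/after pattern (the pick_scaler_old_style helper is illustrative only, not code from this repository; the get_grad_scaler import path matches the diffs below):

    # Before (the duplicated pattern being deleted): each call site branched on the torch version.
    from packaging import version
    import torch

    def pick_scaler_old_style():
        # Hypothetical helper, shown only to illustrate the repeated check.
        if version.parse(torch.__version__) > version.parse("2.3"):
            return torch.amp.GradScaler("cuda")
        return torch.cuda.amp.GradScaler()

    # After: call sites import the shared utility and let it pick the right GradScaler.
    from accelerate.utils import get_grad_scaler

    scaler = get_grad_scaler()
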
8 changes: 2 additions & 6 deletions benchmarks/fp8/ms_amp/ddp.py
@@ -22,12 +22,11 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -36,10 +35,7 @@
 
 def train_baseline(opt_level="O2"):
     set_seed(42)
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
     accelerator = Accelerator()
     device = accelerator.device
8 changes: 2 additions & 6 deletions benchmarks/fp8/ms_amp/non_distributed.py
@@ -22,11 +22,10 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -42,10 +41,7 @@ def train_baseline(opt_level="O2"):
 
     base_model_results = evaluate_model(model, eval_dataloader, METRIC)
     model.train()
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
 
     for batch in train_dataloader:
         batch = batch.to("cuda")
4 changes: 4 additions & 0 deletions docs/source/package_reference/utilities.md
@@ -126,6 +126,10 @@ These include data operations that mimic the same `torch` ops but can be used on
 
 [[autodoc]] utils.gather_object
 
+[[autodoc]] utils.get_grad_scaler
+
+[[autodoc]] utils.get_mixed_precision_context_manager
+
 [[autodoc]] utils.listify
 
 [[autodoc]] utils.pad_across_processes
28 changes: 3 additions & 25 deletions src/accelerate/accelerator.py
@@ -32,7 +32,6 @@
 import torch
 import torch.utils.hooks as hooks
 from huggingface_hub import split_torch_state_dict_into_shards
-from packaging import version
 
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
@@ -78,6 +77,7 @@
     extract_model_from_parallel,
     gather,
     gather_object,
+    get_grad_scaler,
     get_mixed_precision_context_manager,
     get_pretty_name,
     is_bf16_available,
@@ -136,7 +136,6 @@
 
 
 if is_torch_xla_available():
-    import torch_xla.amp as xamp
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
@@ -484,25 +483,7 @@ def __init__(
             ):
                 raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
-            if self.distributed_type == DistributedType.FSDP:
-                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
-                self.scaler = ShardedGradScaler(**kwargs)
-            elif is_torch_xla_available(check_is_gpu=True):
-                self.scaler = xamp.GradScaler(**kwargs)
-            elif is_mlu_available():
-                self.scaler = torch.mlu.amp.GradScaler(**kwargs)
-            elif is_musa_available():
-                self.scaler = torch.musa.amp.GradScaler(**kwargs)
-            elif is_npu_available():
-                self.scaler = torch.npu.amp.GradScaler(**kwargs)
-            elif is_xpu_available():
-                self.scaler = torch.amp.GradScaler("xpu", **kwargs)
-            else:
-                if version.parse(torch.__version__) > version.parse("2.3"):
-                    self.scaler = torch.amp.GradScaler("cuda", **kwargs)
-                else:
-                    self.scaler = torch.cuda.amp.GradScaler(**kwargs)
+            self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
 
         elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
             DistributedType.DEEPSPEED,
@@ -526,10 +507,7 @@ def __init__(
                     )
                 elif self.distributed_type != DistributedType.DEEPSPEED:
                     # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
-                    if version.parse(torch.__version__) > version.parse("2.3"):
-                        self.scaler = torch.amp.GradScaler("cuda")
-                    else:
-                        self.scaler = torch.cuda.amp.GradScaler()
+                    self.scaler = get_grad_scaler(**kwargs)
 
         # Start of internal step tracking
         self.step = 0
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -130,6 +130,7 @@
     dtype_byte_size,
     find_tied_parameters,
     get_balanced_memory,
+    get_grad_scaler,
     get_max_layer_size,
     get_max_memory,
     get_mixed_precision_context_manager,
34 changes: 34 additions & 0 deletions src/accelerate/utils/modeling.py
@@ -1871,3 +1871,37 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
         return torch.autocast(device_type=device_type, **autocast_kwargs)
     else:
         return contextlib.nullcontext()
+
+
+def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
+    """
+    A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
+    it.
+
+    Args:
+        distributed_type (`DistributedType`, *optional*, defaults to None):
+            The type of distributed environment.
+        kwargs:
+            Additional arguments for the utilized `GradScaler` constructor.
+    """
+    if distributed_type == DistributedType.FSDP:
+        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+        return ShardedGradScaler(**kwargs)
+    if is_torch_xla_available(check_is_gpu=True):
+        import torch_xla.amp as xamp
+
+        return xamp.GradScaler(**kwargs)
+    elif is_mlu_available():
+        return torch.mlu.amp.GradScaler(**kwargs)
+    elif is_musa_available():
+        return torch.musa.amp.GradScaler(**kwargs)
+    elif is_npu_available():
+        return torch.npu.amp.GradScaler(**kwargs)
+    elif is_xpu_available():
+        return torch.amp.GradScaler("xpu", **kwargs)
+    else:
+        if is_torch_version(">=", "2.3"):
+            return torch.amp.GradScaler("cuda", **kwargs)
+        else:
+            return torch.cuda.amp.GradScaler(**kwargs)
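
A quick usage sketch of the new helper (illustrative, not part of the commit): keyword arguments are forwarded to whichever GradScaler constructor is chosen, and passing a DistributedType explicitly, as accelerator.py now does, selects the FSDP ShardedGradScaler.

    from accelerate.utils import DistributedType, get_grad_scaler

    # Default: backend and torch-version detection happens inside the helper.
    scaler = get_grad_scaler(init_scale=2.0**16)  # init_scale is forwarded to the chosen GradScaler

    # Explicit distributed type, mirroring the Accelerator.__init__ change above:
    fsdp_scaler = get_grad_scaler(DistributedType.FSDP)  # returns torch's ShardedGradScaler
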
2 changes: 1 addition & 1 deletion src/accelerate/utils/operations.py
@@ -29,10 +29,10 @@
 from .imports import (
     is_npu_available,
     is_torch_distributed_available,
-    is_torch_version,
     is_torch_xla_available,
     is_xpu_available,
 )
+from .versions import is_torch_version
 
 
 if is_torch_xla_available():
