# v0.34.0: StatefulDataLoader Support, FP8 Improvements, and PyTorch Updates!
## Dependency Changes
- Updated Safetensors Requirement: The library now requires `safetensors` version 0.4.3.
- Added support for Numpy 2.0: The library now fully supports `numpy` 2.0.0.
## Core
### New Script Behavior Changes
- Process Group Management: PyTorch now requires users to destroy process groups after training. The `accelerate` library will handle this automatically with `accelerator.end_training()`, or you can do it manually using `PartialState().destroy_process_group()` (see the sketch after this list).
- MLU Device Support: Added support for saving and loading RNG states on MLU devices by @huismiling
- NPU Support: Corrected backend and distributed settings when using `transfer_to_npu`, ensuring better performance and compatibility.
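For reference, here is a minimal sketch of both cleanup paths (model, optimizer, and the training loop itself are elided):

```python
from accelerate import Accelerator, PartialState

accelerator = Accelerator()
# ... prepare your objects and run the training loop ...

# Option 1: let Accelerate handle cleanup (also ends any trackers)
accelerator.end_training()

# Option 2: destroy the process group manually instead
PartialState().destroy_process_group()
```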
### DataLoader Enhancements
- Stateful DataLoader: We are excited to announce that early support has been added for the `StatefulDataLoader` from `torchdata`, allowing better handling of data loading states. Enable it by passing `use_stateful_dataloader=True` to the `DataLoaderConfiguration`; when calling `load_state()`, the `DataLoader` will automatically resume from its last step, with no more need to iterate through already-seen batches (see the sketch after this list).
- Decoupled Data Loader Preparation: The `prepare_data_loader()` function is now independent of the `Accelerator`, giving you more flexibility over which API level you would like to use.
- XLA Compatibility: Added support for skipping initial batches when using XLA.
- Improved State Management: Bug fixes and enhancements for saving/loading `DataLoader` states, ensuring smoother training sessions.
- Epoch Setting: Introduced the `set_epoch` function for `MpDeviceLoaderWrapper`.
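As a minimal sketch of the stateful dataloader flow (the toy dataset and the `"ckpt"` checkpoint directory are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Opt in to the torchdata-backed StatefulDataLoader
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

train_dataset = TensorDataset(torch.randn(64, 10))  # toy dataset
dataloader = accelerator.prepare(DataLoader(train_dataset, batch_size=8))

# ... train, saving checkpoints with accelerator.save_state("ckpt") ...

# On resume, the dataloader continues from its last step automatically
accelerator.load_state("ckpt")
```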
## FP8 Training Improvements
- Enhanced FP8 Training: Fully Sharded Data Parallelism (FSDP) and DeepSpeed support now work seamlessly with `TransformerEngine` FP8 training, including better defaults for the quantized FP8 weights (a rough configuration sketch follows this list).
- Integration baseline: We've added a new suite of examples and benchmarks to ensure that our `TransformerEngine` integration works exactly as intended. These scripts run one half using 🤗 Accelerate's integration and the other half with raw `TransformerEngine`, providing users with a nice example of what we do under the hood with accelerate, and a good sanity check to make sure nothing breaks down over time. Find them here.
- Import Fixes: Resolved issues with the import checks for `TransformerEngine` that had downstream effects.
- FP8 Docker Images: We've added new docker images for `TransformerEngine` and `accelerate` as well. Use `docker pull huggingface/accelerate@gpu-fp8-transformerengine` to quickly get an environment going.
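As a rough sketch of what FP8 training with the `TransformerEngine` backend looks like on the Accelerate side (the recipe arguments are illustrative, and actually running this requires NVIDIA hardware with FP8 support):

```python
import torch

from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Request FP8 mixed precision with the TransformerEngine backend
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[FP8RecipeKwargs(backend="TE")],
)

model = torch.nn.Linear(64, 64)  # stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)
```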
## `torchpippy` no more, long live `torch.distributed.pipelining`
- With the latest PyTorch release, `torchpippy` is now fully integrated into torch core, and as a result we are exclusively supporting the PyTorch implementation from now on.
- There are breaking examples and changes that come from this shift. Namely:
  - Tracing of inputs is done with the shape each GPU will see, rather than the size of the total batch. So for 2 GPUs, one should pass in an input of `[1, n, n]` rather than `[2, n, n]` as before (see the sketch after this list).
  - We no longer support encoder/decoder models. PyTorch tracing for `pipelining` no longer supports encoder/decoder models, so the `t5` example has been removed.
  - Computer vision model support currently does not work: there are some tracing issues regarding ResNets that we are actively looking into.
- If either of these changes is too breaking, we recommend pinning your accelerate version. If the encoder/decoder model support is actively blocking your inference using pippy, please open an issue and let us know. We can look into potentially adding back the old `torchpippy` support if needed.
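To make the shape change concrete, here is a hedged sketch using Accelerate's `prepare_pippy` on 2 GPUs with a decoder-only model (the `"gpt2"` checkpoint and the sequence length are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM

from accelerate.inference import prepare_pippy

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
model.eval()

# Trace with the shape ONE GPU will see: for 2 GPUs, a microbatch of 1,
# not the full batch of 2 as under the old torchpippy behavior.
example_input = torch.randint(0, model.config.vocab_size, (1, 128))
model = prepare_pippy(model, example_args=(example_input,))
```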
## Fully Sharded Data Parallelism (FSDP)
- Environment Flexibility: Environment variables are now fully optional for FSDP, simplifying configuration. You can now create a `FullyShardedDataParallelPlugin` yourself manually with no need for environment patching (and then hand it to the `Accelerator`, as sketched below):

  ```python
  from accelerate import FullyShardedDataParallelPlugin

  fsdp_plugin = FullyShardedDataParallelPlugin(...)
  ```
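A small sketch of passing the plugin along (the `cpu_offload` argument is just illustrative):

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(cpu_offload=False)  # illustrative argument
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```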
- FSDP RAM-efficient loading: Added a utility to enable RAM-efficient model loading (by setting the proper environment variable). This is generally needed if you are not using `accelerate launch` and need to ensure the env variables are set up properly for model loading (a usage sketch follows):

  ```python
  from accelerate.utils import enable_fsdp_ram_efficient_loading, disable_fsdp_ram_efficient_loading

  enable_fsdp_ram_efficient_loading()
  ```
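For example, a sketch of a script launched with plain `torchrun`, where the flag must be flipped before the model is loaded (the `"gpt2"` checkpoint is a placeholder):

```python
from accelerate.utils import enable_fsdp_ram_efficient_loading
from transformers import AutoModelForCausalLM

# Set the env variable BEFORE loading so the RAM-efficient path is used
enable_fsdp_ram_efficient_loading()
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
```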
- Model State Dict Management: Enhanced support for unwrapping model state dicts in FSDP, making it easier to manage distributed models.
## New Examples
- Configuration and Models: Improved configuration handling and introduced a configuration zoo for easier experimentation. You can learn more here. This was largely inspired by the `axolotl` library, so very big kudos to their wonderful work.
- FSDP + SLURM Example: Added a minimal configuration example for running jobs with SLURM and using FSDP.
## Bug Fixes
- Fix bug of clip_grad_norm_ for xla fsdp by @hanwen-sun in #2941
- Explicit check for `step` when loading the state by @muellerzr in #2992
- Fix `find_tied_params` for models with shared layers by @qubvel in #2986
- clear memory after offload by @SunMarc in #2994
- fix default value for rank size in cpu threads_per_process assignment logic by @rbrugaro in #3009
- Fix batch_sampler maybe None error by @candlewill in #3025
- Do not import `transformer_engine` on import by @oraluben in #3056
- Fix torchvision to be compatible with torch version in CI by @SunMarc in #2982
- Fix gated test by @muellerzr in #2993
- Fix typo on warning str: "on the meta device device" -> "on the meta device" by @HeAndres in #2997
- Fix deepspeed tests by @muellerzr in #3003
- Fix torch version check by @muellerzr in #3024
- Fix fp8 benchmark on single GPU by @muellerzr in #3032
- Fix typo in comment by @zmoki688 in #3045
- Speed up tests by shaving off subprocess when not needed by @muellerzr in #3042
- Remove `skip_first_batches` support for StatefulDataloader and fix all the tests by @muellerzr in #3068
## New Contributors
- @byi8220 made their first contribution in #2957
- @alex-jw-brooks made their first contribution in #2959
- @XciD made their first contribution in #2981
- @hanwen-sun made their first contribution in #2941
- @HeAndres made their first contribution in #2997
- @yitongh made their first contribution in #2966
- @qubvel made their first contribution in #2986
- @rbrugaro made their first contribution in #3009
- @candlewill made their first contribution in #3025
- @siddk made their first contribution in #3047
- @oraluben made their first contribution in #3056
- @tmm1 made their first contribution in #3055
- @zmoki688 made their first contribution in #3045
## Full Changelog
- Require safetensors>=0.4.3 by @byi8220 in #2957
- Fix torchvision to be compatible with torch version in CI by @SunMarc in #2982
- Enable Unwrapping for Model State Dicts (FSDP) by @alex-jw-brooks in #2959
- chore: Update runs-on configuration for CI workflows by @XciD in #2981
- add MLU devices for rng state saving and loading. by @huismiling in #2940
- remove .md to allow proper linking by @nbroad1881 in #2977
- Fix bug of clip_grad_norm_ for xla fsdp by @hanwen-sun in #2941
- Fix gated test by @muellerzr in #2993
- Explicit check for `step` when loading the state by @muellerzr in #2992
- Fix typo on warning str: "on the meta device device" -> "on the meta device" by @HeAndres in #2997
- Support skip_first_batches for XLA by @yitongh in #2966
- clear memory after offload by @SunMarc in #2994
- Fix deepspeed tests by @muellerzr in #3003
- Make env variables optional for FSDP by @muellerzr in #2998
- Add small util to enable FSDP offloading quickly by @muellerzr in #3006
- update version to 0.34.dev0 by @SunMarc in #3007
- Fix `find_tied_params` for models with shared layers by @qubvel in #2986
- Enable FSDP & Deepspeed + FP8 by @muellerzr in #2983
- fix default value for rank size in cpu threads_per_process assignment logic by @rbrugaro in #3009
- Wrong import check for TE by @muellerzr in #3016
- destroy process group in `end_training` by @SunMarc in #3012
- Tweak defaults for quantized-typed FP8 TE weights by @muellerzr in #3018
- Set correct NPU backend and distributed_type when using transfer_to_npu by @ArthurinRUC in #3021
- Fix torch version check by @muellerzr in #3024
- Add end_training/destroy_pg to everything and unpin numpy by @muellerzr in #3030
- Improve config handling and add a zoo by @muellerzr in #3029
- Add early support for `torchdata.stateful_dataloader.StatefulDataLoader` within the `Accelerator` by @byi8220 in #2895
- Fix fp8 benchmark on single GPU by @muellerzr in #3032
- Fix batch_sampler maybe None error by @candlewill in #3025
- Fixup dataloader state dict bugs + incorporate load/save_state API by @muellerzr in #3034
- Decouple `prepare_data_loader()` from Accelerator by @siddk in #3047
- Update CONTRIBUTING.md Setup Instructions by @siddk in #3046
- Add a SLURM example with minimal config by @muellerzr in #2950
- Add FP8 docker images by @muellerzr in #3048
- Update torchpippy by @muellerzr in #2938
- Do not import `transformer_engine` on import by @oraluben in #3056
- use duck-typing to ensure underlying optimizer supports schedulefree hooks by @tmm1 in #3055
- Fix typo in comment by @zmoki688 in #3045
- add set_epoch for MpDeviceLoaderWrapper by @hanwen-sun in #3053
- Speed up tests by shaving off subprocess when not needed by @muellerzr in #3042
- Remove `skip_first_batches` support for StatefulDataloader and fix all the tests by @muellerzr in #3068