
Fsdp2 support for activation checkpointing #359


Draft: wants to merge 64 commits into base: main

Commits (64)
725a218
refactor: towards fsdp2 support for activation checkpointing
le1nux Apr 18, 2025
8362e0d
feat: added test configs for AC
le1nux Apr 18, 2025
a546189
feat: added selective activation checkpointing configs
le1nux Apr 21, 2025
61070d1
feat: added selective activation checkpointing tests
le1nux Apr 21, 2025
9c7573b
feat: added selective activation checkpointing strategies
le1nux Apr 21, 2025
034662a
feat: wired up selective activation checkpointing
le1nux Apr 21, 2025
a9fe173
refactor: experiment id can now be passed to Main constructor to allo…
le1nux Apr 21, 2025
f977525
chore: added doc strings to the selective AC entry point
le1nux Apr 21, 2025
ba3c9e2
refactor: renamed typing and logging module to fix name shadowing issues
le1nux Apr 21, 2025
391d4eb
refactor: renamed typing module to typing_utils
le1nux Apr 21, 2025
c6b6e47
feat: added SAC benchmarking script
le1nux Apr 21, 2025
742ab46
chore: added activation checkpointing benchmark config
le1nux Apr 21, 2025
641fad1
chore: removed legacy code
le1nux Apr 24, 2025
f940b0e
refactor: extracted save_list to parent in AC
le1nux Apr 24, 2025
c346f0f
refactor: moved Main from __main__ to main
le1nux Apr 24, 2025
90c4bbe
feat: added batch generator util
le1nux Apr 24, 2025
c87688a
refactor: split experiment_id syncing into multiple utility functions
le1nux Apr 24, 2025
ebaf13e
feat: implemented grid search setup for profiling
le1nux Apr 24, 2025
bcc0e7b
refactor: added OOM error handling in CudaEnv
le1nux Apr 24, 2025
899401a
feat: added torchrun script for distributed profiling
le1nux Apr 24, 2025
87296d1
feat: added profiling README
le1nux Apr 24, 2025
d6dddc0
feat: added profiler implementation
le1nux Apr 24, 2025
011e41b
feat: drafted profile logs analyzer
le1nux Apr 24, 2025
e6da67f
chore: minor renamings
le1nux Apr 24, 2025
2d4a9ad
refactor: making sure that each compil
le1nux Apr 27, 2025
d990420
feat: added torchrun launcher
le1nux Apr 27, 2025
f4a5b48
refactor: wrapped up the profiler_starter
le1nux Apr 27, 2025
b2aa30b
feat: added activation checkpoint profiling example
le1nux Apr 27, 2025
94025af
feat: setup forward pass profiling
le1nux Apr 29, 2025
3ea7cf5
feat: ops in selective op activation checkpointing are now configurable
le1nux Apr 29, 2025
0b99db2
feat: added profiling logs analysis notebook
le1nux Apr 29, 2025
1889776
refactor: adapted the configs and evaluation code for selective op AC
le1nux Apr 29, 2025
c718182
chore: added profiling experiments to gitignore
le1nux Apr 29, 2025
00b5bb5
chore: Merge branch 'fsdp2_min_integration' into fsdp2_activation_che…
le1nux May 26, 2025
40db2da
chore: fixed failing AC tests
le1nux May 26, 2025
522489e
chore: Merge branch 'fsdp2_activation_checkpointing' into profiling_f…
le1nux May 26, 2025
4da7d2c
chore: fix failing unit tests
le1nux May 26, 2025
ed2772c
Merge pull request #360 from Modalities/profiling_feature
le1nux May 26, 2025
4b2f8dc
refactor: improved the activation checkpointing interface
le1nux May 27, 2025
6bb5d78
feat: added temporary env context manager
le1nux May 27, 2025
6e0888d
refactor: improved readability and structure of profiling
le1nux May 27, 2025
d37c352
chore: drafted scaling up tutorial
le1nux May 31, 2025
0388145
chore: extended profiling analysis notebook
le1nux May 31, 2025
cc39d11
chore: updated gitignore
le1nux Jun 1, 2025
085995f
feat: added slurm / sbatch scripts for the scale up tutorial
le1nux Jun 1, 2025
f64c8dc
refactor: rendezvous timeout is now configurable
le1nux Jun 1, 2025
b12cefb
feat: added jupyter notebook for scalability analysis
le1nux Jun 1, 2025
3a86ebe
refactor: removed legacy rdzv_timeout
le1nux Jun 1, 2025
8dc0fc0
chore: removed hardcoded stuff from sbatch script
le1nux Jun 3, 2025
80c9e6a
feat: more work on scaling notebook
le1nux Jun 3, 2025
4fb2adf
refactor: grid search is now part of config
le1nux Jun 3, 2025
7f3a1bf
feat: run_array_jobs.sh now accepts number of ranks as CMD arguments
le1nux Jun 3, 2025
5e4d08b
feat: added README.md for scaling tutorial
le1nux Jun 3, 2025
a7d94f5
refactor: renamed sac_params to ac_params
le1nux Jun 4, 2025
87667ae
feat: added pretraining benchmarking config
le1nux Jun 4, 2025
ffe7d8f
refactor: improved usability of submission scripts
le1nux Jun 4, 2025
da66a3c
chore: improved error handling / logging
le1nux Jun 4, 2025
226dc44
feat: added training run optimization part
le1nux Jun 5, 2025
6a0fa8e
refactor: defined grid search for training run optimization
le1nux Jun 5, 2025
ec56a05
feat: added peak memory tracking
le1nux Jun 5, 2025
cbab52c
feat: added results subscriber that writes JSONL to disc
le1nux Jun 5, 2025
c1561e2
chore: Merge remote-tracking branch 'refs/remotes/origin/fsdp2_activa…
le1nux Jun 5, 2025
edbda8c
chore: replaced timer with perf_counter
le1nux Jun 5, 2025
dd9a453
chore: resolved merge conflict
le1nux Jun 5, 2025
4 changes: 2 additions & 2 deletions .gitignore
@@ -163,5 +163,5 @@ tests/tmp/*
*wandb_storage*
.coverage/*
*.pbin

tutorials/profiling/experiments
tutorials/profiling/experiments
tutorials/scaling_up/experiments
@@ -68,7 +68,7 @@ settings:
config:
checkpoint_path: ${settings.warmstart_checkpoint_paths.checkpoint_folder_path}
warmstart_checkpoint_paths: # ${warmstart_env:checkpoint_paths}
checkpoint_folder_path: /raid/fromm/modalities/data/checkpoints/2025-04-16__12-40-51_6dcbb1a0/eid_2025-04-16__12-40-51_6dcbb1a0-seen_steps_32-seen_tokens_65536-target_steps_162-target_tokens_331776
checkpoint_folder_path: /raid/s3/opengptx/max_lue/repositories/modalities/data/checkpoints/2025-03-14__15-25-59_970fedec/eid_2025-03-14__15-25-59_970fedec-seen_steps_96-seen_tokens_196608-target_steps_162-target_tokens_331776

collate_fn:
component_key: collate_fn
255 changes: 57 additions & 198 deletions src/modalities/__main__.py
@@ -1,19 +1,14 @@
#!/usr/bin/env python

import json
import logging
import os
import shutil
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Callable, Optional, Type
from typing import Optional

import click
import click_pathlib
import yaml
from omegaconf import DictConfig
from pydantic import BaseModel, FilePath
from pydantic import FilePath

from modalities.api import (
FileExistencePolicy,
@@ -27,22 +22,12 @@
shuffle_jsonl_data,
shuffle_tokenized_data,
)
from modalities.batch import EvaluationResultBatch
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel, TrainingReportGenerator
from modalities.evaluator import Evaluator
from modalities.gym import Gym
from modalities.logging_broker.message_broker import MessageBroker
from modalities.logging_broker.messages import MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel
from modalities.main import Main
from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.running_env.cuda_env import CudaEnv
from modalities.trainer import Trainer
from modalities.util import get_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0
from modalities.utils.profilers.modalities_profiler import ModalitiesProfiler


@click.group()
@@ -511,187 +496,61 @@ def CMD_shuffle_jsonl_data(
)


class Main:
"""Main class that orchestrates the training process."""

def __init__(self, config_path: Path, additional_resolver_funs: Optional[dict[str, Callable]] = None) -> None:
experiment_id = get_experiment_id_of_run(config_path)
self.config_dict = load_app_config_dict(
config_file_path=config_path, experiment_id=experiment_id, additional_resolver_funs=additional_resolver_funs
)
self.config_path = config_path

self.registry = Registry(COMPONENTS)
self.component_factory = ComponentFactory(registry=self.registry)

def add_custom_component(
self, component_key: str, variant_key: str, custom_component: Type, custom_config: Type
) -> None:
"""Add a custom component to the registry.

This method comes in especially handy
when Modalities is used as a library and the user wants to add custom components
(e.g., custom model or custom loss function) to the registry.

Args:
component_key (str): Key of the component to be added to the registry
variant_key (str): Key of the variant to be added to the registry
custom_component (Type): The class type of the custom component
custom_config (Type): The pydantic config type of the custom component
"""
self.registry.add_entity(
component_key=component_key,
variant_key=variant_key,
component_type=custom_component,
component_config_type=custom_config,
)

def build_components(self, components_model_type: Type[BaseModel]) -> BaseModel:
"""Given a pydantic basemodel, this method builds the components specified in the config file.

Depending on the use case (e.g., training, inference, etc.), the user can pass different pydantic base models.
For instance, for tokenization, the basemodel would only have the tokenization-related components specified.

Args:
components_model_type (Type[BaseModel]): The pydantic basemodel type that should be
used to build the components.

Returns:
BaseModel: The components built based on the config file.
"""
components = self.component_factory.build_components(
config_dict=self.config_dict, components_model_type=components_model_type
)
return components

def run(self, components: TrainingComponentsInstantiationModel):
"""Entrypoint fo running the training process.

We pass in a TrainingComponentsInstantiationModel,
which is a pydantic model that contains all the components needed for the training process.

Args:
components (TrainingComponentsInstantiationModel): The components needed for the training process.
"""
# save the config file to the checkpointing path
if components.settings.cuda_env.global_rank == 0:
experiment_path = components.settings.paths.checkpoint_saving_path / components.settings.experiment_id
os.makedirs(experiment_path, exist_ok=True)
shutil.copy(self.config_path, experiment_path / self.config_path.name)
resolved_config_path = (experiment_path / self.config_path.name).with_suffix(".yaml.resolved")
with open(resolved_config_path, "w", encoding="utf-8") as f:
yaml.dump(self.config_dict, f)

evaluation_result_publisher, progress_publisher = self.get_logging_publishers(
progress_subscriber=components.progress_subscriber,
results_subscriber=components.evaluation_subscriber,
global_rank=components.settings.cuda_env.global_rank,
local_rank=components.settings.cuda_env.local_rank,
)

# Trainer
global_num_tokens_per_train_step = (
components.settings.step_profile.local_train_micro_batch_size
* components.settings.step_profile.sequence_length
* components.settings.step_profile.gradient_accumulation_steps
* components.settings.cuda_env.world_size
)
trainer = Trainer(
global_rank=components.settings.cuda_env.global_rank,
progress_publisher=progress_publisher,
num_target_steps=components.settings.training_target.num_target_steps,
num_target_tokens=components.settings.training_target.num_target_tokens,
num_seen_train_steps=components.settings.training_progress.num_seen_steps,
global_num_seen_tokens=components.settings.training_progress.global_num_seen_tokens,
evaluation_result_publisher=evaluation_result_publisher,
gradient_acc_steps=components.settings.step_profile.gradient_accumulation_steps,
gradient_clipper=components.gradient_clipper,
global_num_tokens_per_train_step=global_num_tokens_per_train_step,
mfu_calculator=components.mfu_calculator,
)

# Evaluator
evaluator = Evaluator(
progress_publisher=progress_publisher,
evaluation_result_publisher=evaluation_result_publisher,
)

# Gym
gym = Gym(
trainer=trainer,
evaluator=evaluator,
loss_fun=components.loss_fn,
num_ranks=components.settings.cuda_env.world_size,
)
num_params = get_total_number_of_trainable_parameters(components.app_state.model)
components.evaluation_subscriber.consume_dict({"No. parameters": num_params})
logging.info(f"Training model with {num_params} parameters.")

print_rank_0(f"Model initialized at {datetime.now()}.")

report = TrainingReportGenerator(
training_target=components.settings.training_target,
intervals=components.settings.intervals,
step_profile=components.settings.step_profile,
cuda_env=components.settings.cuda_env,
consistency_enforcement=components.settings.consistency_enforcement,
train_dataset=components.train_dataset,
training_progress=components.settings.training_progress,
).get_report()

print_rank_0(report)

gym.run(
train_data_loader=components.train_dataloader,
evaluation_data_loaders=components.eval_dataloaders,
checkpoint_saving=components.checkpoint_saving,
app_state=components.app_state,
checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps,
evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps,
training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps,
)
@main.group(name="profile")
def profile():
"""
Collection of utilities to profile modalities.
"""
pass

def get_logging_publishers(
self,
progress_subscriber: MessageSubscriberIF[ProgressUpdate],
results_subscriber: MessageSubscriberIF[EvaluationResultBatch],
global_rank: int,
local_rank: int,
) -> tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]:
"""Returns the logging publishers for the training.

These publishers are used to pass the evaluation results and the progress updates to the message broker.
The message broker is then used to pass the messages to the subscribers, such as WandB.

Args:
progress_subscriber (MessageSubscriberIF[ProgressUpdate]): The progress subscriber
results_subscriber (MessageSubscriberIF[EvaluationResultBatch]): The results subscriber
global_rank (int): The global rank of the current process
local_rank (int): The local rank of the current process on the current node

Returns:
tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation
result publisher and the progress publisher
"""
message_broker = MessageBroker()
progress_publisher = MessagePublisher[ProgressUpdate](
message_broker=message_broker,
global_rank=global_rank,
local_rank=local_rank,
)
evaluation_result_publisher = MessagePublisher[EvaluationResultBatch](
message_broker=message_broker,
global_rank=global_rank,
local_rank=local_rank,
)

message_broker.add_subscriber(subscription=MessageTypes.EVALUATION_RESULT, subscriber=results_subscriber)
message_broker.add_subscriber(
subscription=MessageTypes.BATCH_PROGRESS_UPDATE,
subscriber=progress_subscriber,
)
@profile.command(name="train_step")
@click.option(
"--config_file_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the YAML training config file.",
)
@click.option(
"--experiment_folder_path",
type=click_pathlib.Path(file_okay=False),
required=True,
help="Path to the experiment output directory.",
)
@click.option(
"--num_warmup_steps",
type=int,
default=1,
show_default=True,
help="Number of warmup steps to skip in profiling.",
)
@click.option(
"--num_measurement_steps",
type=int,
default=3,
show_default=True,
help="Number of steps to measure during profiling.",
)
def CMD_entry_point_run_train_step_profiler(
config_file_path: Path,
experiment_folder_path: Path,
num_warmup_steps: int,
num_measurement_steps: int,
):
"""Run train step profiler and write result to JSON if global rank=0.

return evaluation_result_publisher, progress_publisher
Args:
config_file_path (Path): Path to the YAML training config file.
experiment_folder_path (Path): Path to the experiment output directory.
num_warmup_steps (int): Number of warmup steps to skip in profiling.
num_measurement_steps (int): Number of steps to measure during profiling.
"""
ModalitiesProfiler.get_train_step_statistics(
config_file_path=config_file_path,
experiment_folder_path=experiment_folder_path,
num_warmup_steps=num_warmup_steps,
num_measurement_steps=num_measurement_steps,
)


if __name__ == "__main__":
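
For orientation, here is a minimal sketch of what the new `profile train_step` subcommand ultimately calls. The paths are placeholders and the torchrun launch hint is an assumption; only the function and its parameters are taken from this diff.

# Hypothetical usage sketch: placeholder paths, not part of this PR.
# In a distributed setting this would typically be launched via torchrun,
# e.g. `torchrun --nproc_per_node=8 -m modalities profile train_step ...` (assumed invocation).
from pathlib import Path

from modalities.utils.profilers.modalities_profiler import ModalitiesProfiler

ModalitiesProfiler.get_train_step_statistics(
    config_file_path=Path("configs/pretraining.yaml"),         # YAML training config
    experiment_folder_path=Path("experiments/profiling_run"),  # experiment output directory
    num_warmup_steps=1,        # steps skipped before measurement
    num_measurement_steps=3,   # steps included in the reported statistics
)
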
2 changes: 1 addition & 1 deletion src/modalities/api.py
@@ -24,7 +24,7 @@
from modalities.preprocessing.shuffle_data import DataShuffler
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.utils.logging import get_logger
from modalities.utils.logger_utils import get_logger
from modalities.utils.seeding import calculate_hashed_seed


@@ -10,7 +10,7 @@
from modalities.checkpointing.checkpoint_loading import DistributedCheckpointLoadingIF, FSDP1CheckpointLoadingIF
from modalities.checkpointing.stateful.app_state import AppState
from modalities.running_env.env_utils import MixedPrecisionSettings
from modalities.utils.logging import get_logger
from modalities.utils.logger_utils import get_logger


class FSDP1CheckpointLoading(FSDP1CheckpointLoadingIF):
@@ -13,7 +13,7 @@
from modalities.checkpointing.stateful.app_state import AppState
from modalities.exceptions import CheckpointingError
from modalities.training.training_progress import TrainingProgress
from modalities.utils.logging import get_logger
from modalities.utils.logger_utils import get_logger


class CheckpointingEntityType(Enum):
26 changes: 25 additions & 1 deletion src/modalities/config/config.py
@@ -39,6 +39,9 @@
has_bfloat_support,
)
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees
from modalities.training.activation_checkpointing.activation_checkpointing_variants import (
ActivationCheckpointingVariants,
)
from modalities.util import parse_enum_by_name


@@ -299,11 +302,27 @@ class WeightInitializedModelConfig(BaseModel):
model_config = ConfigDict(protected_namespaces=())


class ActivationCheckpointedModelConfig(BaseModel):
class FSDP1ActivationCheckpointedModelConfig(BaseModel):
model: PydanticFSDP1ModuleType
activation_checkpointing_modules: Optional[list[str]] = Field(default_factory=list)


class ActivationCheckpointedModelConfig(BaseModel):
class FullACParams(BaseModel):
pass

class SelectiveLayerACParams(BaseModel):
ac_freq: int

class SelectiveOpACParams(BaseModel):
save_ops_keys: list[str]

ac_variant: ActivationCheckpointingVariants
layers_fqn: str
model: PydanticPytorchModuleType | PydanticFSDP1ModuleType
ac_fun_params: Optional[FullACParams | SelectiveLayerACParams | SelectiveOpACParams] = None


class RawAppStateConfig(BaseModel):
model: PydanticPytorchModuleType
optimizer: PydanticOptimizerIFType
@@ -411,6 +430,11 @@ class DummyResultSubscriberConfig(BaseModel):
pass


class EvaluationResultToDiscSubscriberConfig(BaseModel):
output_folder_path: Path
experiment_id: str


class WandBEvaluationResultSubscriberConfig(BaseModel):
global_rank: int
project: str
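
To make the new config surface concrete, here is a hedged sketch of how the nested parameter models could be populated. The save-ops key format is an assumption; the class and field names come from this diff.

# Hypothetical sketch: the ops key format is assumed, not taken from this PR.
from modalities.config.config import ActivationCheckpointedModelConfig

# Selective layer checkpointing: re-compute activations of every 2nd layer.
layer_params = ActivationCheckpointedModelConfig.SelectiveLayerACParams(ac_freq=2)

# Selective op checkpointing: save the outputs of the listed ops, re-compute the rest.
op_params = ActivationCheckpointedModelConfig.SelectiveOpACParams(
    save_ops_keys=["torch.ops.aten.mm.default"]  # assumed key format
)
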
4 changes: 4 additions & 0 deletions src/modalities/config/pydantic_if_types.py
@@ -25,6 +25,7 @@
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
from modalities.utils.mfu import MFUCalculatorABC
from modalities.utils.profilers.batch_generator import DatasetBatchGeneratorIF


class PydanticThirdPartyTypeIF:
@@ -79,3 +80,6 @@ def __get_pydantic_core_schema__(
PydanticDeviceMeshIFType = Annotated[DeviceMesh, PydanticThirdPartyTypeIF(DeviceMesh)]
PydanticAppStateType = Annotated[AppState, PydanticThirdPartyTypeIF(AppState)]
PydanticMFUCalculatorABCType = Annotated[MFUCalculatorABC, PydanticThirdPartyTypeIF(MFUCalculatorABC)]
PydanticDatasetBatchGeneratorIFType = Annotated[
DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF)
]
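
For context, a hedged sketch of how such an Annotated alias is typically consumed by a pydantic component config; the config class and field names below are illustrative, not part of this PR.

# Illustrative only: config class and field names are assumed.
from pydantic import BaseModel

from modalities.config.pydantic_if_types import PydanticDatasetBatchGeneratorIFType

class TrainStepProfilerConfig(BaseModel):
    # Accepts an already-built DatasetBatchGeneratorIF component injected by the component factory.
    batch_generator: PydanticDatasetBatchGeneratorIFType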