[DCP][RFC] Faster intermediate checkpoints with DCP async save in TorchTune #2006
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2006. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit d72a756 with merge base 1814feb. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I will review it tomorrow. Thanks for the PR!
recipes/full_finetune_distributed.py (Outdated)
        else None
    ),
)

if self._resume_from_checkpoint and self._enable_async_checkpointing:
I understand that this allowed for fast verification of the checkpoint time, and the improvements look really great. But I don't think we want the internals of how checkpointing works exposed in the recipe. How can the user exposure between DCP and the standard checkpointer be consolidated into the same thing, so that checkpointer logic only takes a few lines of code in the recipe?
I agree. Checkpointing in general needs a fair bit of refactoring; there is plenty of code duplication. I wanted to avoid a major refactor here since I believe @joecummings is planning this out already. Happy to help on that.
@@ -633,11 +744,17 @@ def save_checkpoint(

intermediate_checkpoint = epoch + 1 < self.total_epochs

# If async checkpointing is enabled, intermediate checkpoints will
# be saved asynchronously.
if intermediate_checkpoint and self._enable_async_checkpointing:
Is there ever a case where we wouldn't want async checkpointing for intermediate? Is the flag really needed? Also, it should be possible to merge the async save and full save logic into the same function to not proliferate state dict code everywhere.
@pbontrager This flag was added temporarily to make it an opt-in feature, which allows us to get more validation before we make it a default. Also, in case of any bugs, devs can turn it off and fall back to the prior behavior. Currently async save and full save happen via different Checkpointers with different formats, but we can probably create a helper method to move this logic out of the recipe code. I will refactor this a bit more.
@@ -614,6 +671,60 @@ def _setup_data(

    return sampler, dataloader

def save_checkpoint_async(
Would the idea be to eventually use async checkpoint everywhere and consolidate the files at the end? Or will we always do the all gather at the end?
By end, I assume you mean the final checkpoint at the end of the fine-tuning?
The final checkpoint format should be the one that the user configures in their job config. Intermediate checkpoints are only for fault tolerance, so we can choose the Checkpointer that best minimizes wasted training time and GPU hours, which in this case is DCP (DistributedCheckpointing) with async save. We will, however, provide DCP as another Checkpointer option, and users can choose to get the final checkpoint in that format as well, but we probably won't enforce it. One use case I can think of: when users train with TorchTitan, where DCP is the default Checkpointer, they can then plug the checkpoint into TorchTune directly for fine-tuning instead of converting to torch.save or HF formats.
I am unclear on whether the DistributedCheckpointer class will replace the HF or Meta checkpointers, or whether it should be used in tandem with them for the fault-tolerant intermediate checkpoints. How can users easily convert between checkpoint formats? Would it make sense to integrate the APIs into each checkpointer just for the purposes of async intermediate checkpoints, and then have the first load / final save still remain in the original HF / Meta format?
Maybe I just need to see the north star for the checkpoint redesign (@joecummings).
"No intermediate checkpoint found in the specified directory. Please ensure that a checkpoint exists." | ||
) | ||
|
||
if self._is_rank_zero: |
You can use the torchtune.utils.logging.log_rank_zero utility so you don't have to keep checking for rank.
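For reference, a minimal sketch of that pattern, assuming the utility takes a logger and a message as suggested above:

import logging

from torchtune.utils.logging import log_rank_zero  # path as referenced above

logger = logging.getLogger(__name__)

# Logs only on rank 0, replacing the repeated `if self._is_rank_zero:` checks.
log_rank_zero(logger, "Saving intermediate checkpoint asynchronously...")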
I feel that these if is_rank_zero checks are most of the code :P
Yeah. DCP APIs are trainer-agnostic, so really all the code in the integration is going to be state dict init & updates, along with the tests :)
_, rank = training.get_world_size_and_rank()
self._is_rank_zero = rank == 0

def _get_latest_intermediate_checkpoint_path(self) -> Optional[str]:
this could be generally useful for all other checkpointers, cc @joecummings to consider this for the redesign
    os.path.join(last_checkpoint, self._metadata_file)
):
    if self._is_rank_zero:
        logger.info(
Do we want to automatically remove the previous checkpoint? Some users may want to keep checkpoints around so they can evaluate the model over the course of training.
Without this, we will be limiting the number of checkpoints we can take, as it's bound by the storage available on the host, and that can run out quickly for large models. For example, Llama 3.2 Vision 90B is almost 150-200 GB in size.
For internal trainers, we keep only the last checkpoint, since for fault tolerance we always restore the latest one. We can possibly make it a config so users can configure the behavior if they want to evaluate intermediate checkpoints, and can also choose to keep only the latest if the checkpoints are large and taken frequently.
Yeah, agreed with making it configurable. For this PR I would actually lean towards keeping all intermediate checkpoints around, mainly because that's what we do for our other checkpointers and we shouldn't have diverging behavior across different checkpoint formats. I'd like to see us add a field like save_last_n_checkpoints to all checkpointers in one go.
Yes, exactly. It seems like we will need a set of checkpointing configs for the user to tune the performance and behavior. We should do it in one go across Checkpointers and get these configs documented. I will get it cleaned up from this PR.
Before making changes, can we align on the config to avoid extra work? The options seem to be:
- freq_save_intermediate_ckpt = N, and when N=0, it means we never save it.
- keep_only_last_intermediate_ckpt = True/False; if True, we don't keep them.
Is that it? Probably needs better naming.
@felipemello1 We decided to keep the current behavior consistent for the DistributedCheckpointer as well in this PR, which is to keep all checkpoints persisted for now. I believe when we introduce this feature across Checkpointers, just having keep_last_n_checkpoints should be sufficient:
- None -> all checkpoints are kept
- 0 -> no intermediate checkpoints
- 1 -> only the latest valid checkpoint is persisted (probably the default)
- N (N > 1) -> the latest N checkpoints are kept
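To make the semantics concrete, here is a minimal sketch of what such a retention helper could look like; keep_last_n_checkpoints, the epoch_<i> directory layout, and the helper name are assumptions for illustration, not the final API:

import os
import shutil
from typing import Optional


def prune_checkpoints(output_dir: str, keep_last_n_checkpoints: Optional[int]) -> None:
    # Assumed semantics: None -> keep everything, 0 -> keep no intermediate
    # checkpoints, N >= 1 -> keep only the newest N checkpoint directories.
    if keep_last_n_checkpoints is None:
        return
    # Assume intermediate checkpoints land in <output_dir>/epoch_<i> directories.
    ckpts = sorted(
        (d for d in os.listdir(output_dir) if d.startswith("epoch_")),
        key=lambda d: int(d.split("_")[-1]),
    )
    for stale in ckpts[: max(0, len(ckpts) - keep_last_n_checkpoints)]:
        shutil.rmtree(os.path.join(output_dir, stale))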
Absolutely. Happy to brainstorm on the future direction whenever the proposal is ready. For intermediate checkpoints, in this PR, I preferred the first suggestion of making it work in tandem with the other Checkpointers. Positioning DCP as a default Checkpointer and a default format is a bigger discussion. I will write up some checkpointing suggestions that you and @joecummings can consider as you plan out the refactor.
dcp_saver = DistributedCheckpointer(
    checkpoint_dir=self._checkpointer._checkpoint_dir,
    output_dir=self._checkpointer._output_dir,
)
Just wanna check my understanding here. We normally would define the checkpointer as part of the config to be instantiated, but here you basically use enable_async_checkpointing as a proxy for whether we're using the DistributedCheckpointer or not. Is the idea behind this to have the ability to save and load intermediate checkpoints with DistributedCheckpointer but still save the final checkpoint in the HF/Meta format?
Yes, exactly. I was less inclined to have users configure two different Checkpointers in the same job config: DistributedCheckpointer for intermediate checkpoints and the HF/Meta checkpointer for first load and final save. I exposed async checkpointing as a feature instead, under the enable_async_checkpointing config, for two reasons:
- DistributedCheckpointer is the only checkpointer which supports async checkpointing.
- Users will probably want to control the first load and final save checkpoints. Intermediate checkpoints are for fault tolerance, so we have the flexibility to enforce the Checkpointer that works best for that.
Basically, the end state I was thinking of is the following:
- For first load and final save, users can choose any of the following Checkpointers: HF, Meta, or DistributedCheckpointer. No async checkpointing here, since checkpoint loads and final saves will always need to be synchronous; we cannot finish the job until the entire final checkpoint has been persisted.
- For intermediate checkpoints, we use async save with DistributedCheckpointer by default.
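To make that end state concrete, a rough sketch of the recipe-side dispatch being described; the _save_checkpoint_async / _save_checkpoint_sync helper names are illustrative, not the actual recipe methods:

def save_checkpoint(self, epoch: int) -> None:
    # Intermediate checkpoints go through DCP async save when the opt-in flag
    # is set; everything else (including the final checkpoint) uses the
    # user-configured checkpointer (HF, Meta, or DistributedCheckpointer).
    intermediate_checkpoint = epoch + 1 < self.total_epochs
    if intermediate_checkpoint and self._enable_async_checkpointing:
        # Non-blocking: training resumes while shards are flushed in the background.
        self._save_checkpoint_async(epoch)
    else:
        # Synchronous: the job must not exit before the final checkpoint is
        # fully persisted.
        self._save_checkpoint_sync(epoch)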
recipes/full_finetune_distributed.py (Outdated)
resume_from_checkpoint: bool = (
    False if self._enable_async_checkpointing else self._resume_from_checkpoint
)
I'm a bit confused by this... is it just because of the if statement for self._update_recipe_state (i.e. we do not need to explicitly load state when using DCP)? If so, I think we should just use a different variable, e.g. should_load_recipe_state = self._resume_from_checkpoint and not self._enable_async_checkpointing, or something, to be more explicit.
Yes, that's right. I agree, the current way is a bit sloppy about avoiding loading the recipe state. I will update the variables to make it a bit more readable.
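Something along these lines, as a sketch of the more explicit variable suggested above:

# Load recipe state only when resuming through the standard checkpointer; with
# async DCP checkpointing the recipe state is restored from the distributed
# checkpoint itself.
should_load_recipe_state = (
    self._resume_from_checkpoint and not self._enable_async_checkpointing
)
if should_load_recipe_state:
    self._update_recipe_state(checkpoint_dict)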
else:
    logger.error(
        f"Checkpoint failed to save asynchronously to {checkpoint_path} with the exception {f.exception()}"
    )
I see a bunch of exceptions coming from rmtree when I run this on my machine. E.g.
exception calling callback for <Future at 0x7f7415a06a50 state=finished returned Metadata>
Traceback (most recent call last):
File "/home/ebs/.conda/envs/tt-alt-10-24/lib/python3.11/concurrent/futures/_base.py", line 340, in _invoke_callbacks
callback(self)
File "/data/users/ebs/ebs-torchtune-alt/torchtune/training/checkpointing/_checkpointer.py", line 1062, in callback
shutil.rmtree(last_checkpoint)
File "/home/ebs/.conda/envs/tt-alt-10-24/lib/python3.11/shutil.py", line 752, in rmtree
_rmtree_safe_fd(fd, path, onerror)
File "/home/ebs/.conda/envs/tt-alt-10-24/lib/python3.11/shutil.py", line 703, in _rmtree_safe_fd
onerror(os.unlink, fullname, sys.exc_info())
File "/home/ebs/.conda/envs/tt-alt-10-24/lib/python3.11/shutil.py", line 701, in _rmtree_safe_fd
os.unlink(entry.name, dir_fd=topfd)
FileNotFoundError: [Errno 2] No such file or directory: '__2_46.distcp'
(However I am still able to save the checkpoints successfully)
Ohh, my bad. Thanks for catching that; I somehow did not notice this in my testing. Basically the checkpoint deletion should happen on rank 0 only, otherwise two different ranks may race to delete the same file and run into the error above. It's not a problem during save, since DCP prepares the global plan and distributes it across ranks, so every rank knows exactly what needs to be saved and there is no collision there.
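A sketch of the fix being described, gating the cleanup on rank zero; the callback shape is assumed from the traceback above rather than taken from the PR:

import shutil
from concurrent.futures import Future


def make_cleanup_callback(last_checkpoint: str, is_rank_zero: bool):
    # Only rank 0 deletes the previous checkpoint directory; letting every rank
    # call rmtree on the same path causes the FileNotFoundError race above.
    def _callback(future: Future) -> None:
        if future.exception() is not None:
            return  # keep the old checkpoint if the new save failed
        if is_rank_zero:
            shutil.rmtree(last_checkpoint, ignore_errors=True)

    return _callback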
Since you suggested removing this functionality from this PR for now, this logic will get cleaned up. I will record this issue in my doc of refactor suggestions and we can get it added for all Checkpointers in one go.
state_dict,
storage_writer=FileSystemWriter(
    checkpoint_path,
    thread_count=8,
Out of curiosity, how is this value determined? (Similar question above when you set it to 16 for async_save)
This is set as per the available storage I/O bandwidth. The value determines the number of parallel I/O threads used while writing data to storage. For our internal storage like Manifold, or even S3, we know the available capacity and tune this value accordingly. This way we use the maximum I/O capacity with optimal performance without getting throttled.
For the TorchTune case, since we are writing to disk, we may have to run some experiments and tune it as per the available IOPS. Ideally, this should also be exposed as a checkpointing config so users can tune it for the best performance depending on the underlying storage: HDD, SSD, or cloud storage like S3.
I used 16 since that is the default for Manifold integrations internally. For disks, this number can be higher. Also, it only affects the upload speed to storage, so it has virtually no effect on the async checkpointing blocking time.
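For illustration, a minimal sketch of an async save with a tunable writer thread count, assuming a recent PyTorch where torch.distributed.checkpoint.async_save and FileSystemWriter(thread_count=...) are available:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter


def async_save_shards(state_dict: dict, checkpoint_path: str, thread_count: int = 16):
    # thread_count controls how many parallel writer threads each rank uses
    # when flushing shards, so size it to the IOPS of the target storage
    # (HDD vs. SSD vs. object stores like S3).
    writer = FileSystemWriter(checkpoint_path, thread_count=thread_count)
    # async_save returns a future immediately; the training loop keeps running
    # while the shards are written in the background.
    return dcp.async_save(state_dict, storage_writer=writer)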
f"Checkpoint is saved asynchronously to {checkpoint_path}" | ||
) | ||
|
||
for index in range(epoch): |
Also a nit, but do you really need a for loop here if this runs every epoch? (i.e. isn't it sufficient to clean up the epoch n-1 checkpoint after each epoch n save?)
True, we can do that. I added this code for edge cases where some prior checkpoint folder still did not get cleaned up because the user cancelled the job, or the job errored out, while the cleanup was happening asynchronously.
Thank you for this PR, the results look great! I left a bunch of questions (some probably pretty basic) but I don't have any huge concerns here. The main things needed in my mind are:
- Testing. One thing that we should definitely test is that resuming from an intermediate DCP checkpoint works as expected. I think we can do that by modifying this recipe test. Unit tests for the DistributedCheckpointer class would also be nice; you can check the unit tests for our other checkpointers in this file.
- Figuring out some of the UX points raised by @pbontrager. I think there are some larger points that @joecummings will be figuring out and we probably don't need to block this PR on them, but I would definitely like to figure out how we can get to a place where there is less checkpointing code in the recipe and more in reusable utilities.
Thanks @ebsmothers, @pbontrager, @RdoubleA for the reviews. Very helpful!
Just left small comments since the code will still be refactored. As I was reading through it, I was wondering whether the same utilities will also work for LoRA. I think it's great that you are focusing on a single recipe for now, but just food for thought as you are designing it. Thanks for the PR! :)
recipes/full_finetune_distributed.py (Outdated)
training.SEED_KEY: 0,
training.EPOCHS_KEY: 0,
training.TOTAL_EPOCHS_KEY: 0,
training.MAX_STEPS_KEY: 0,
This problem existed before this PR, but we need to keep track of the scheduler's LR too. Currently when we restart from a checkpoint, I believe we don't restart from the same LR. It's one of the issues in this list: #1551
#2017 introduces the lr_scheduler. Once both of these PRs land, it should be straightforward to add the checkpointing support. We can do it as a follow-up.
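When that follow-up happens, persisting the scheduler could look roughly like the sketch below; the "lr_scheduler" key and helper names are hypothetical:

def add_scheduler_state(ckpt_dict: dict, lr_scheduler) -> dict:
    # Persist scheduler state next to the other recipe-state keys
    # (hypothetical "lr_scheduler" key) so resumed runs continue from the same LR.
    ckpt_dict["lr_scheduler"] = lr_scheduler.state_dict()
    return ckpt_dict


def restore_scheduler_state(ckpt_dict: dict, lr_scheduler) -> None:
    # On resume, reload the scheduler before training continues.
    lr_scheduler.load_state_dict(ckpt_dict["lr_scheduler"])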
# Create the checkpoint dict to be sent to the checkpointer and ultimately persisted to storage
ckpt_dict = {
    training.SEED_KEY: self.seed,
    training.EPOCHS_KEY: self.epochs_run,
We probably shouldn't use self.epochs_run AND the input epoch. Maybe just stick to one or the other.
@felipemello1 Sorry, could you please explain this comment a bit more? Are you suggesting we should not be saving epochs_run in the checkpoint? How would we restore the training progress, in terms of the number of epochs that were run, up to when the last checkpoint was taken?
We can do that. save_checkpoint already takes the epoch, so I propagated it to the async API as well. It would require a minor change to update self._epochs_run only after taking the checkpoint for the current epoch here: https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py#L878
@felipemello1 I looked into it some more and decided to revert my changes back to the prior state.
- current_epoch refers to the epoch which just completed, and epochs_run gets incremented to the next epoch.
- current_epoch is used in the checkpoint name; however, in the training progress state dict, intermediate checkpoints save the next epoch to resume from, while final checkpoints save the epoch which just completed.
Therefore passing current_epoch into the checkpoint methods seems cleaner for now compared to doing +1s and -1s for different types of checkpoints. I do believe this logic should be refactored a bit so that all checkpoints save the total number of epochs completed, and during the training run we adjust the starting epoch depending on whether we are resuming from a checkpoint or training from scratch.
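For illustration only, the convention described above amounts to something like this sketch (helper names are hypothetical):

from typing import Optional


def epoch_to_record(current_epoch: int, is_intermediate: bool) -> int:
    # Intermediate checkpoints record the next epoch to resume from, while the
    # final checkpoint records the epoch that just completed.
    return current_epoch + 1 if is_intermediate else current_epoch


def starting_epoch(resumed_epochs_run: Optional[int]) -> int:
    # Fresh runs start at epoch 0; resumed runs continue from the saved value.
    return 0 if resumed_epochs_run is None else resumed_epochs_run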
@saumishr has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Context
What is the purpose of this PR?
This PR introduces DistributedCheckpointing-based asynchronous checkpointing and enables it for intermediate checkpoints.
Experiments
Full Finetune Distributed
Changelog
What are the changes made in this PR?
Usage
- enable_async_checkpointing enables asynchronous checkpoint saving for intermediate checkpoints.
- resume_from_checkpoint enables resuming from the latest intermediate checkpoint. No recipe state needs to be provided, as it is already saved in the distributed checkpoint.

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.