Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support acc fsdp optim_state_dict #6

Open
wants to merge 2 commits into
base: acc
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2959,16 +2959,14 @@ def _save_optimizer_and_scheduler(self, output_dir):
if is_torch_xla_available():
xm.rendezvous("saving_optimizer_states")
if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
optm = {
"optimizer": self.optimizer.state_dict(),
"shard_metadata": self.model.get_shard_metadata(),
}
from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP

optm = FSDP.full_optim_state_dict(self.model, self.optimizer)
xm.save(
optm,
os.path.join(
output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
),
master_only=False,
output_dir, f"{OPTIMIZER_NAME}"
)
)
else:
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
Expand Down Expand Up @@ -3050,23 +3048,27 @@ def _load_optimizer_and_scheduler(self, checkpoint):
)
)
checkpoint_file_exists = (
glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
glob.glob(os.path.join(checkpoint, f"{OPTIMIZER_NAME}"))
if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
else checkpoint_file_exists
)

if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# Load in optimizer and scheduler states
if is_torch_xla_available():
from torchacc.dist.fsdp import FullyShardedDataParallel as FSDP
# On TPU we have to take some extra precautions to properly load the states on the right device.
if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
optimizer_state = torch.load(
os.path.join(
checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
),
map_location="cpu",
)
# We only need `optimizer` when resuming from checkpoint
optimizer_state = optimizer_state["optimizer"]
optimizer_state = None
if self.args.process_index == 0:
optimizer_state = torch.load(
os.path.join(
checkpoint, f"{OPTIMIZER_NAME}"
),
map_location="cpu",
)

optimizer_state = FSDP.load_optim_state_dict(self.model, optimizer_state, self.optimizer)
else:
optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
with warnings.catch_warnings(record=True) as caught_warnings:
Expand Down