Skip to content

Commit

Permalink
Save to /retriever subdir
Browse files Browse the repository at this point in the history
  • Loading branch information
tleyden committed Jun 11, 2024
1 parent 4431d51 commit d3c92cf
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions dalm/training/retriever_only/train_retriever_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def train_retriever(
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

logger.info("***** Running training *****")
logger.info("***** Running training (retriever only) *****")
logger.info(f" Num examples = {len(processed_datasets)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
Expand Down Expand Up @@ -411,10 +411,12 @@ def train_retriever(
if isinstance(checkpointing_steps, str):
accelerator.save_state(os.path.join(output_dir, f"epoch_{epoch}"))

retriever_ckpt_path = output_dir + "/retriever"

accelerator.unwrap_model(model.model).save_pretrained(
output_dir, state_dict=accelerator.get_state_dict(accelerator.unwrap_model(model.model))
retriever_ckpt_path, state_dict=accelerator.get_state_dict(accelerator.unwrap_model(model.model))
)
tokenizer.save_pretrained(output_dir)
tokenizer.save_pretrained(retriever_ckpt_path)
accelerator.wait_for_everyone()
if with_tracking:
accelerator.end_training()
Expand Down

0 comments on commit d3c92cf

Please sign in to comment.