From d3c92cf041d831afb804135299c40a249efc6194 Mon Sep 17 00:00:00 2001 From: Traun Leyden Date: Tue, 11 Jun 2024 15:43:26 +0000 Subject: [PATCH] Save to /retriever subdir --- dalm/training/retriever_only/train_retriever_only.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dalm/training/retriever_only/train_retriever_only.py b/dalm/training/retriever_only/train_retriever_only.py index c928625..10d4260 100644 --- a/dalm/training/retriever_only/train_retriever_only.py +++ b/dalm/training/retriever_only/train_retriever_only.py @@ -309,7 +309,7 @@ def train_retriever( accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) - logger.info("***** Running training *****") + logger.info("***** Running training (retriever only) *****") logger.info(f" Num examples = {len(processed_datasets)}") logger.info(f" Num Epochs = {num_train_epochs}") logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}") @@ -411,10 +411,12 @@ def train_retriever( if isinstance(checkpointing_steps, str): accelerator.save_state(os.path.join(output_dir, f"epoch_{epoch}")) + retriever_ckpt_path = output_dir + "/retriever" + accelerator.unwrap_model(model.model).save_pretrained( - output_dir, state_dict=accelerator.get_state_dict(accelerator.unwrap_model(model.model)) + retriever_ckpt_path, state_dict=accelerator.get_state_dict(accelerator.unwrap_model(model.model)) ) - tokenizer.save_pretrained(output_dir) + tokenizer.save_pretrained(retriever_ckpt_path) accelerator.wait_for_everyone() if with_tracking: accelerator.end_training()