"""Reproducibility helpers.

Captures the exact git state (HEAD commit) of every repository involved in a
training run — llama2d itself plus its vendored `transformers` and
`llama-recipes` checkouts — and builds a shell snippet that re-creates the
current invocation from those commits.
"""
import os
import shlex
import subprocess
import sys

from common import llama_recipes_dir, root_dir, transformers_dir


def _git(args, repo_dir):
    """Run `git *args` inside *repo_dir* and return its stdout as text.

    Raises subprocess.CalledProcessError if git exits non-zero, so a broken
    or missing repository is reported instead of silently looking "clean"
    (os.popen, used previously, discarded the exit status).
    """
    result = subprocess.run(
        ["git", *args],
        cwd=repo_dir,  # run in the target repo; no os.chdir, so caller's cwd is never touched
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout


def check_all_code_committed(dir):
    """Return the HEAD commit hash of the repo at *dir*.

    Raises RuntimeError if the working tree has unstaged or staged-but-
    uncommitted changes. (An explicit raise, not `assert`, so the check
    survives `python -O`.)

    NOTE(review): untracked files are not detected — only modifications to
    tracked files show up in `git diff` / `git diff --cached`.
    """
    unstaged = _git(["diff"], dir)
    staged = _git(["diff", "--cached"], dir)

    if unstaged or staged:
        dir_name = os.path.basename(dir)
        raise RuntimeError(
            f"Please commit all code in {dir_name} before running this script."
        )

    return _git(["rev-parse", "HEAD"], dir).strip()


def check_llama2d_code():
    """Return {repo name: HEAD commit hash} for all three tracked repos.

    Raises RuntimeError if any of them has uncommitted changes.
    """
    return {
        "llama2d": check_all_code_committed(root_dir),
        "transformers": check_all_code_committed(transformers_dir),
        "llama_recipes": check_all_code_committed(llama_recipes_dir),
    }


def make_repro_command():
    """Build a shell snippet that reproduces the current run.

    The snippet checks out the exact commits of all three repos and re-runs
    the current command line. All repos must be fully committed (enforced by
    check_llama2d_code), otherwise the snippet could not reproduce the run.
    """
    commits = check_llama2d_code()

    # shlex.join preserves quoting of arguments containing spaces/metachars,
    # which a plain " ".join would lose.
    command = shlex.join(sys.argv)

    # TODO: fill in HF dataset name if it's not there

    return f"""
    # run in llama2d
    git checkout {commits["llama2d"]}
    cd transformers && git checkout {commits["transformers"]}
    cd ../llama-recipes && git checkout {commits["llama_recipes"]}
    cd src/llama2d/modal
    {command}
    """