Skip to content

Commit

Permalink
Partial reproducibility script
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-healey committed Oct 6, 2023
1 parent e311edd commit 467691c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
52 changes: 52 additions & 0 deletions src/llama2d/modal/repro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from common import transformers_dir,llama_recipes_dir,root_dir
import os
import sys

def check_all_code_committed(dir):
    """Verify that the git repository at *dir* has no uncommitted changes.

    Runs ``git diff`` and ``git diff --cached`` inside *dir*; both must be
    empty (i.e. no unstaged and no staged changes).

    Args:
        dir: Path to a git repository checkout.

    Returns:
        The full HEAD commit hash of the repository, as a string.

    Raises:
        AssertionError: If the repository has unstaged or staged changes.
    """
    old_dir = os.getcwd()
    os.chdir(dir)
    try:
        # Unstaged and staged changes, respectively; both must be empty.
        git_diff = os.popen("git diff").read()
        git_diff_cached = os.popen("git diff --cached").read()

        dir_name = os.path.basename(dir)
        assert (
            git_diff == "" and git_diff_cached == ""
        ), f"Please commit all code in {dir_name} before running this script."

        git_commit_hash = os.popen("git rev-parse HEAD").read().strip()
    finally:
        # Restore the caller's working directory even if the assert fires,
        # so a caller catching AssertionError isn't left in the wrong cwd.
        os.chdir(old_dir)

    return git_commit_hash

def check_llama2d_code():
    """Ensure all three repos (llama2d, transformers, llama-recipes) are
    fully committed.

    Returns:
        Dict mapping repo name to its HEAD commit hash.
    """
    repo_dirs = {
        "llama2d": root_dir,
        "transformers": transformers_dir,
        "llama_recipes": llama_recipes_dir,
    }
    # Dicts preserve insertion order, so repos are checked in the same
    # order as before: llama2d, transformers, llama-recipes.
    return {name: check_all_code_committed(path) for name, path in repo_dirs.items()}

def make_repro_command():
    """Build a shell snippet that reproduces the current training run.

    Verifies all repos are committed (raises otherwise), then emits the
    git-checkout commands for each repo plus the exact command line used
    to invoke this script.

    Returns:
        A multi-line string of shell commands, ready to paste.
    """
    commits = check_llama2d_code()

    # Reconstruct the invocation exactly as typed on the command line.
    command = " ".join(sys.argv)

    # TODO: fill in HF dataset name if it's not there

    script_lines = [
        "",
        "# run in llama2d",
        f"git checkout {commits['llama2d']}",
        f"cd transformers && git checkout {commits['transformers']}",
        f"cd ../llama-recipes && git checkout {commits['llama_recipes']}",
        "cd src/llama2d/modal",
        command,
        "",
    ]
    return "\n".join(script_lines)
6 changes: 5 additions & 1 deletion src/llama2d/modal/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys

from repro import make_repro_command
from common import BASE_MODELS, GPU_MEM, N_GPUS, VOLUME_CONFIG, stub
from modal import Mount, Secret, gpu

Expand Down Expand Up @@ -88,7 +89,6 @@ def train(train_kwargs):
print("Committing results volume (no progress bar) ...")
stub.results_volume.commit()


@stub.local_entrypoint() # Runs locally to kick off remote training job.
def main(
dataset: str,
Expand Down Expand Up @@ -120,6 +120,10 @@ def main(
print(f"Syncing base model {model_name} to volume.")
download.remote(model_name)

cmd = make_repro_command()
print(cmd)
raise Exception("Done")

if not run_id:
import secrets

Expand Down

0 comments on commit 467691c

Please sign in to comment.