diff --git a/train_model.py b/train_model.py index be4f25f9..8ecc148b 100644 --- a/train_model.py +++ b/train_model.py @@ -1,3 +1,4 @@ +import random import torch import pytorch_lightning as pl from lightning_fabric.utilities import seed @@ -82,6 +83,9 @@ def main(): assert args.step_length <= 3, "Too high step length" assert args.eval in (None, "val", "test"), f"Unknown eval setting: {args.eval}" + # Get an (actual) random run id as a unique identifier + random_run_id = random.randint(0, 9999) + # Set seed seed.seed_everything(args.seed) @@ -120,7 +124,7 @@ def main(): if args.eval: prefix = prefix + f"eval-{args.eval}-" run_name = f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"\ - f"{time.strftime('%m_%d_%H_%M_%S')}" + f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}" checkpoint_callback = pl.callbacks.ModelCheckpoint( dirpath=f"saved_models/{run_name}", filename="min_val_loss", monitor="val_mean_loss", mode="min", save_last=True)