diff --git a/generation/maisi/scripts/diff_model_train.py b/generation/maisi/scripts/diff_model_train.py index e6bfcdd7c..e47f58f93 100644 --- a/generation/maisi/scripts/diff_model_train.py +++ b/generation/maisi/scripts/diff_model_train.py @@ -357,7 +357,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat )[local_rank] train_loader = prepare_data( - train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"] + train_files, device, args.diffusion_unet_train["cache_rate"], batch_size=args.diffusion_unet_train["batch_size"] ) unet = load_unet(args, device, logger)