diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py index a80c8014..fb32ac63 100644 --- a/zamba/models/model_manager.py +++ b/zamba/models/model_manager.py @@ -280,19 +280,17 @@ def train_model( accelerator, devices = configure_accelerator_and_devices_from_gpus(train_config.gpus) + + multiprocessing_strategy = getattr(train_config,"multiprocessing_strategy",None) + trainer = pl.Trainer( accelerator=accelerator, devices=devices, max_epochs=train_config.max_epochs, logger=tensorboard_logger, - callbacks=callbacks, - fast_dev_run=train_config.dry_run, - strategy=( - DDPStrategy(find_unused_parameters=False) - if (data_module.multiprocessing_context is not None) and (train_config.gpus > 1) - else "auto" - ), + strategy = multiprocessing_strategy, ) + #Set the strategy within trainer to reflect changes if video_loader_config.cache_dir is None: logger.info("No cache dir is specified. Videos will not be cached.")