From e34c71f772145d89af87b989a7402e7a9261a8cb Mon Sep 17 00:00:00 2001 From: aaronphilip19 <97271769+aaronphilip19@users.noreply.github.com> Date: Fri, 26 Apr 2024 16:43:36 -0400 Subject: [PATCH] Update model_manager.py - trainer config change - Issue #240 --- zamba/models/model_manager.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py index a80c8014..fb32ac63 100644 --- a/zamba/models/model_manager.py +++ b/zamba/models/model_manager.py @@ -280,19 +280,17 @@ def train_model( accelerator, devices = configure_accelerator_and_devices_from_gpus(train_config.gpus) + + multiprocessing_strategy = getattr(train_config,"multiprocessing_strategy",None) + trainer = pl.Trainer( accelerator=accelerator, devices=devices, max_epochs=train_config.max_epochs, logger=tensorboard_logger, - callbacks=callbacks, - fast_dev_run=train_config.dry_run, - strategy=( - DDPStrategy(find_unused_parameters=False) - if (data_module.multiprocessing_context is not None) and (train_config.gpus > 1) - else "auto" - ), + strategy = multiprocessing_strategy, ) + #Set the strategy within trainer to reflect changes if video_loader_config.cache_dir is None: logger.info("No cache dir is specified. Videos will not be cached.")