From 4a20c181fcbd9646862013cc183b1848fc32f157 Mon Sep 17 00:00:00 2001 From: DocGarbanzo <47540921+DocGarbanzo@users.noreply.github.com> Date: Mon, 4 Apr 2022 21:55:20 +0100 Subject: [PATCH] Specify myconfig in training (#997) * Adding support for specifying own myconfig.py file in donkey train and small refactorings * Update version after rebase from main --- donkeycar/config.py | 16 ++++++---------- donkeycar/management/base.py | 25 ++++++++++++------------- scripts/profile.py | 2 +- setup.py | 2 +- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/donkeycar/config.py b/donkeycar/config.py index f5a559e84..607fd35f9 100644 --- a/donkeycar/config.py +++ b/donkeycar/config.py @@ -6,6 +6,9 @@ """ import os import types +from logging import getLogger + +logger = getLogger(__name__) class Config: @@ -51,25 +54,18 @@ def load_config(config_path=None, myconfig="myconfig.py"): if os.path.exists(local_config): config_path = local_config - print('loading config file: {}'.format(config_path)) + logger.info(f'loading config file: {config_path}') cfg = Config() cfg.from_pyfile(config_path) # look for the optional myconfig.py in the same path. personal_cfg_path = config_path.replace("config.py", myconfig) if os.path.exists(personal_cfg_path): - print("loading personal config over-rides from", myconfig) + logger.info(f"loading personal config over-rides from {myconfig}") personal_cfg = Config() personal_cfg.from_pyfile(personal_cfg_path) cfg.from_object(personal_cfg) else: - print("personal config: file not found ", personal_cfg_path) - - # derived settings - if hasattr(cfg, 'IMAGE_H') and hasattr(cfg, 'IMAGE_W'): - cfg.TARGET_H = cfg.IMAGE_H - cfg.TARGET_W = cfg.IMAGE_W - if hasattr(cfg, 'IMAGE_DEPTH'): - cfg.TARGET_D = cfg.IMAGE_DEPTH + logger.warning(f"personal config: file not found {personal_cfg_path}") return cfg diff --git a/donkeycar/management/base.py b/donkeycar/management/base.py index 95c0deb9f..d5a604447 100644 --- a/donkeycar/management/base.py +++ b/donkeycar/management/base.py @@ -27,22 +27,20 @@ def make_dir(path): return real_path -def load_config(config_path): - - ''' +def load_config(config_path, myconfig='myconfig.py'): + """ load a config from the given path - ''' + """ conf = os.path.expanduser(config_path) - if not os.path.exists(conf): - print("No config file at location: %s. Add --config to specify\ - location or run from dir containing config.py." % conf) + logger.error(f"No config file at location: {conf}. Add --config to " + f"specify location or run from dir containing config.py.") return None try: - cfg = dk.load_config(conf) - except: - print("Exception while loading config from", conf) + cfg = dk.load_config(conf, myconfig) + except Exception as e: + logger.error(f"Exception {e} while loading config from {conf}") return None return cfg @@ -541,7 +539,8 @@ def parse_args(self, args): def run(self, args): args = self.parse_args(args) args.tub = ','.join(args.tub) - cfg = load_config(args.config) + my_cfg = args.myconfig + cfg = load_config(args.config, my_cfg) framework = args.framework if args.framework \ else getattr(cfg, 'DEFAULT_AI_FRAMEWORK', 'tensorflow') @@ -554,8 +553,8 @@ def run(self, args): train(cfg, args.tub, args.model, args.type, checkpoint_path=args.checkpoint) else: - print(f"Unrecognized framework: {framework}. Please specify one of " - f"'tensorflow' or 'pytorch'") + logger.error(f"Unrecognized framework: {framework}. Please specify " + f"one of 'tensorflow' or 'pytorch'") class ModelDatabase(BaseCommand): diff --git a/scripts/profile.py b/scripts/profile.py index 275264775..c6d48f3b7 100755 --- a/scripts/profile.py +++ b/scripts/profile.py @@ -21,7 +21,7 @@ def profile(model_path, model_type): model = dk.utils.get_model_by_type(model_type, cfg) model.load(model_path) - h, w, ch = cfg.TARGET_H, cfg.TARGET_W, cfg.TARGET_D + h, w, ch = cfg.IMAGE_H, cfg.IMAGE_W, cfg.IMAGE_DEPTH # generate random array in the right shape in [0,1) img = np.random.randint(0, 255, size=(h, w, ch)) diff --git a/setup.py b/setup.py index b0afb46d3..c1fa28385 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def package_files(directory, strip_leading): long_description = fh.read() setup(name='donkeycar', - version="4.3.8", + version="4.3.9", long_description=long_description, description='Self driving library for python.', url='https://github.com/autorope/donkeycar',