Skip to content

Commit

Permalink
Add extra guards
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead authored Oct 17, 2023
1 parent bffa428 commit 473cfbf
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,20 @@ def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if "_target_" in cfg.strategy:
log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
if "mixed_precision" in strategy.__dict__ and strategy.mixed_precision is not None:
if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
strategy.mixed_precision.param_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
if cfg.strategy.mixed_precision.param_dtype is not None
if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
else None
)
strategy.mixed_precision.reduce_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
if cfg.strategy.mixed_precision.reduce_dtype is not None
if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
else None
)
strategy.mixed_precision.buffer_dtype = (
resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
if cfg.strategy.mixed_precision.buffer_dtype is not None
if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
else None
)

Expand Down

0 comments on commit 473cfbf

Please sign in to comment.