From 7b9050798ddc2a13c6c77dd7fa4574ff0ed5b5ca Mon Sep 17 00:00:00 2001
From: Carl Doersch
Date: Mon, 4 Mar 2024 03:12:19 -0800
Subject: [PATCH] modify ModelWithAux for use with custom train steps and multi-task learning.

PiperOrigin-RevId: 612383436
---
 kauldron/train/train_step.py | 58 ++++++++++++++++++++++++++++++++++--
 1 file changed, 55 insertions(+), 3 deletions(-)

diff --git a/kauldron/train/train_step.py b/kauldron/train/train_step.py
index 36315333..e025b194 100644
--- a/kauldron/train/train_step.py
+++ b/kauldron/train/train_step.py
@@ -185,14 +185,66 @@ class ModelWithAux(config_util.UpdateFromRootCfg):
   """Wrapper around model which also compute the summaries and metrics."""
 
   model: nn.Module = config_util.ROOT_CFG_REF.model
-  losses: Mapping[str, kd_losses.Loss] = config_util.ROOT_CFG_REF.train_losses
-  metrics: Mapping[str, kd_metrics.Metric] = (
+
+  # These fields are configured via the root config, which is the common
+  # approach.
+  root_losses: Mapping[str, kd_losses.Loss] = (
+      config_util.ROOT_CFG_REF.train_losses
+  )
+  root_metrics: Mapping[str, kd_metrics.Metric] = (
       config_util.ROOT_CFG_REF.train_metrics
   )
-  summaries: Mapping[str, kd_summaries.Summary] = (
+  root_summaries: Mapping[str, kd_summaries.Summary] = (
       config_util.ROOT_CFG_REF.train_summaries
   )
 
+  # These fields are configured via any custom calls to ModelWithAux, e.g. a
+  # multi-task setup where different ModelWithAux may have different losses.
+  # These will override any matching keys in the root.
+  local_losses: Mapping[str, kd_losses.Loss] = dataclasses.field(
+      default_factory=flax.core.FrozenDict
+  )
+  local_metrics: Mapping[str, kd_metrics.Metric] = dataclasses.field(
+      default_factory=flax.core.FrozenDict
+  )
+  local_summaries: Mapping[str, kd_summaries.Summary] = dataclasses.field(
+      default_factory=flax.core.FrozenDict
+  )
+
+  def __init__(
+      self,
+      model: Optional[nn.Module] = None,
+      losses: Optional[Mapping[str, kd_losses.Loss]] = None,
+      metrics: Optional[Mapping[str, kd_metrics.Metric]] = None,
+      summaries: Optional[Mapping[str, kd_summaries.Summary]] = None,
+  ):
+    if model is not None:
+      self.model = model
+    if losses is not None:
+      self.local_losses = flax.core.FrozenDict(losses)
+    if metrics is not None:
+      self.local_metrics = flax.core.FrozenDict(metrics)
+    if summaries is not None:
+      self.local_summaries = flax.core.FrozenDict(summaries)
+
+  @property
+  def losses(self) -> Mapping[str, kd_losses.Loss]:
+    return flax.core.FrozenDict(
+        dict(self.root_losses) | dict(self.local_losses)
+    )
+
+  @property
+  def metrics(self) -> Mapping[str, kd_metrics.Metric]:
+    return flax.core.FrozenDict(
+        dict(self.root_metrics) | dict(self.local_metrics)
+    )
+
+  @property
+  def summaries(self) -> Mapping[str, kd_summaries.Summary]:
+    return flax.core.FrozenDict(
+        dict(self.root_summaries) | dict(self.local_summaries)
+    )
+
   def init(  # pylint:disable=missing-function-docstring
       self,
       init_rngs: rngs_lib.Rngs,