 maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter


+class OrbitExperimentRunner:
+  """Runs an experiment with the Orbit training loop.
+
+  This is the default experiment runner for Model Garden experiments. Users
+  can customize the experiment pipeline by subclassing this class and
+  replacing components or functions.
+
+  For example, an experiment runner with a customized checkpoint manager:
+
+  ```python
+  class MyExpRunnerWithExporter(OrbitExperimentRunner):
+    def _maybe_build_checkpoint_manager(self):
+      return MyCheckpointManager(*args)
+
+  # In user code
+  MyExpRunnerWithExporter(**needed_kwargs).run()
+  ```
+
+  A similar override can be applied to other components.
+  """
+
+  def __init__(
+      self,
+      distribution_strategy: tf.distribute.Strategy,
+      task: base_task.Task,
+      mode: str,
+      params: config_definitions.ExperimentConfig,
+      model_dir: str,
+      run_post_eval: bool = False,
+      save_summary: bool = True,
+      train_actions: Optional[List[orbit.Action]] = None,
+      eval_actions: Optional[List[orbit.Action]] = None,
+      trainer: Optional[base_trainer.Trainer] = None,
+      controller_cls=orbit.Controller
+  ):
+    """Constructor.
+
+    Args:
+      distribution_strategy: A distribution strategy.
+      task: A Task instance.
+      mode: A 'str', specifying the mode. Can be 'train', 'eval',
+        'train_and_eval', 'continuous_eval' or 'train_and_post_eval'.
+      params: An ExperimentConfig instance.
+      model_dir: A 'str', a path to store model checkpoints and summaries.
+      run_post_eval: Whether to run a post-training evaluation once; if True,
+        its metrics logs are returned by `run()`.
+      save_summary: Whether to save train and validation summaries.
+      train_actions: Optional list of Orbit train actions.
+      eval_actions: Optional list of Orbit eval actions.
+      trainer: The base_trainer.Trainer instance. It should be created within
+        the strategy.scope().
+      controller_cls: The controller class to manage the train and eval
+        process. Must be an orbit.Controller subclass.
+    """
+    self.strategy = distribution_strategy or tf.distribute.get_strategy()
+    self._params = params
+    self._model_dir = model_dir
+    self._mode = mode
+    self._run_post_eval = run_post_eval
+
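+    # Build a trainer unless the caller supplied one; the mode string decides
+    # whether it needs to support training, evaluation, or both.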
+    self._trainer = trainer or self._build_trainer(
+        task,
+        train='train' in mode,
+        evaluate=('eval' in mode) or run_post_eval)
+    assert self.trainer is not None
+    self._checkpoint_manager = self._maybe_build_checkpoint_manager()
+    self._controller = self._build_controller(
+        trainer=self.trainer if 'train' in mode else None,
+        evaluator=self.trainer,
+        save_summary=save_summary,
+        train_actions=train_actions,
+        eval_actions=eval_actions,
+        controller_cls=controller_cls)
+
+  @property
+  def params(self) -> config_definitions.ExperimentConfig:
+    return self._params
+
+  @property
+  def model_dir(self) -> str:
+    return self._model_dir
+
+  @property
+  def trainer(self) -> base_trainer.Trainer:
+    return self._trainer
+
+  @property
+  def checkpoint_manager(self) -> tf.train.CheckpointManager:
+    return self._checkpoint_manager
+
+  @property
+  def controller(self) -> orbit.Controller:
+    return self._controller
+
+  def _build_trainer(self, task: base_task.Task, train: bool,
+                     evaluate: bool) -> base_trainer.Trainer:
+    """Creates the trainer within the strategy scope."""
+    with self.strategy.scope():
+      trainer = train_utils.create_trainer(
+          self.params,
+          task,
+          train=train,
+          evaluate=evaluate,
+          checkpoint_exporter=self._build_best_checkpoint_exporter())
+    return trainer
+
+  def _build_best_checkpoint_exporter(self):
+    return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
+
+  def _maybe_build_checkpoint_manager(
+      self) -> Optional[tf.train.CheckpointManager]:
+    """Maybe creates a CheckpointManager."""
+    assert self.trainer is not None
+    if self.trainer.checkpoint:
+      if self.model_dir is None:
+        raise ValueError('model_dir must be specified, but got None')
+      checkpoint_manager = tf.train.CheckpointManager(
+          self.trainer.checkpoint,
+          directory=self.model_dir,
+          max_to_keep=self.params.trainer.max_to_keep,
+          step_counter=self.trainer.global_step,
+          checkpoint_interval=self.params.trainer.checkpoint_interval,
+          init_fn=self.trainer.initialize)
+    else:
+      checkpoint_manager = None
+    return checkpoint_manager
+
+  def _build_controller(self,
+                        trainer,
+                        evaluator,
+                        save_summary: bool = True,
+                        train_actions: Optional[List[orbit.Action]] = None,
+                        eval_actions: Optional[List[orbit.Action]] = None,
+                        controller_cls=orbit.Controller) -> orbit.Controller:
+    """Builds an Orbit controller."""
+    train_actions = [] if not train_actions else train_actions
+    if trainer:
+      train_actions += actions.get_train_actions(
+          self.params,
+          trainer,
+          self.model_dir,
+          checkpoint_manager=self.checkpoint_manager)
+
+    eval_actions = [] if not eval_actions else eval_actions
+    if evaluator:
+      eval_actions += actions.get_eval_actions(self.params, evaluator,
+                                               self.model_dir)
+
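+    # Train and eval summaries go to separate subdirectories of model_dir;
+    # all summary writing is disabled when save_summary is False.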
+    controller = controller_cls(
+        strategy=self.strategy,
+        trainer=trainer,
+        evaluator=evaluator,
+        global_step=self.trainer.global_step,
+        steps_per_loop=self.params.trainer.steps_per_loop,
+        checkpoint_manager=self.checkpoint_manager,
+        summary_dir=os.path.join(self.model_dir, 'train')
+        if save_summary else None,
+        eval_summary_dir=os.path.join(
+            self.model_dir, self.params.trainer.validation_summary_subdir)
+        if save_summary else None,
+        summary_interval=self.params.trainer.summary_interval
+        if save_summary else None,
+        train_actions=train_actions,
+        eval_actions=eval_actions)
+    return controller
+
+  def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
+    """Runs the experiment for the mode given at construction time.
+
+    Returns:
+      A 2-tuple of (model, eval_logs).
+        model: The `tf.keras.Model` instance.
+        eval_logs: The eval metrics logs when run_post_eval is set to True;
+          otherwise, an empty dictionary.
+    """
+    mode = self._mode
+    params = self.params
+    logging.info('Starting to execute mode: %s', mode)
+    with self.strategy.scope():
+      if mode in ('train', 'train_and_post_eval'):
+        self.controller.train(steps=params.trainer.train_steps)
+      elif mode == 'train_and_eval':
+        self.controller.train_and_evaluate(
+            train_steps=params.trainer.train_steps,
+            eval_steps=params.trainer.validation_steps,
+            eval_interval=params.trainer.validation_interval)
+      elif mode == 'eval':
+        self.controller.evaluate(steps=params.trainer.validation_steps)
+      elif mode == 'continuous_eval':
+
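+        # Stop continuous evaluation once training has reached its final step.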
+        def timeout_fn():
+          if self.trainer.global_step.numpy() >= params.trainer.train_steps:
+            return True
+          return False
+
+        self.controller.evaluate_continuously(
+            steps=params.trainer.validation_steps,
+            timeout=params.trainer.continuous_eval_timeout,
+            timeout_fn=timeout_fn)
+      else:
+        raise NotImplementedError('The mode is not implemented: %s' % mode)
+
+    num_params = train_utils.try_count_params(self.trainer.model)
+    if num_params is not None:
+      logging.info('Number of trainable params in model: %f Millions.',
+                   num_params / 10.**6)
+
+    flops = train_utils.try_count_flops(self.trainer.model)
+    if flops is not None:
+      logging.info('FLOPs (multi-adds) in model: %f Billions.',
+                   flops / 10.**9 / 2)
+
+    if self._run_post_eval or mode == 'train_and_post_eval':
+      with self.strategy.scope():
+        return self.trainer.model, self.controller.evaluate(
+            steps=params.trainer.validation_steps)
+    else:
+      return self.trainer.model, {}
+
+
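A minimal usage sketch of the new runner (an annotation, not part of the diff). It assumes a typical Model Garden setup where `exp_factory`/`task_factory` supply the config and task; the experiment name, paths, and strategy choice are illustrative assumptions.

```python
import tensorflow as tf

from official.core import exp_factory
from official.core import task_factory
from official.core import train_lib
# Assumed: experiment registrations (e.g. official.vision) imported elsewhere.

model_dir = '/tmp/my_experiment'  # illustrative path
params = exp_factory.get_exp_config('resnet_imagenet')  # assumed experiment
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  task = task_factory.get_task(params.task, logging_dir=model_dir)

runner = train_lib.OrbitExperimentRunner(
    distribution_strategy=strategy,
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir=model_dir,
)
model, eval_logs = runner.run()  # eval_logs is {} unless run_post_eval=True
```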
 def run_experiment(
     distribution_strategy: tf.distribute.Strategy,
     task: base_task.Task,
@@ -70,91 +290,17 @@ def run_experiment(
     eval_logs: returns eval metrics logs when run_post_eval is set to True,
       otherwise, returns {}.
   """
-
-  with distribution_strategy.scope():
-    if not trainer:
-      trainer = train_utils.create_trainer(
-          params,
-          task,
-          train='train' in mode,
-          evaluate=('eval' in mode) or run_post_eval,
-          checkpoint_exporter=maybe_create_best_ckpt_exporter(
-              params, model_dir))
-
-  if trainer.checkpoint:
-    if model_dir is None:
-      raise ValueError('model_dir must be specified, but got None')
-    checkpoint_manager = tf.train.CheckpointManager(
-        trainer.checkpoint,
-        directory=model_dir,
-        max_to_keep=params.trainer.max_to_keep,
-        step_counter=trainer.global_step,
-        checkpoint_interval=params.trainer.checkpoint_interval,
-        init_fn=trainer.initialize)
-  else:
-    checkpoint_manager = None
-
-  train_actions = [] if not train_actions else train_actions
-  train_actions += actions.get_train_actions(
-      params, trainer, model_dir, checkpoint_manager=checkpoint_manager)
-
-  eval_actions = [] if not eval_actions else eval_actions
-  eval_actions += actions.get_eval_actions(params, trainer, model_dir)
-
-  controller = controller_cls(
-      strategy=distribution_strategy,
-      trainer=trainer if 'train' in mode else None,
-      evaluator=trainer,
-      global_step=trainer.global_step,
-      steps_per_loop=params.trainer.steps_per_loop,
-      checkpoint_manager=checkpoint_manager,
-      summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
-      eval_summary_dir=os.path.join(model_dir,
-                                    params.trainer.validation_summary_subdir) if
-      (save_summary) else None,
-      summary_interval=params.trainer.summary_interval if
-      (save_summary) else None,
+  runner = OrbitExperimentRunner(
+      distribution_strategy=distribution_strategy,
+      task=task,
+      mode=mode,
+      params=params,
+      model_dir=model_dir,
+      run_post_eval=run_post_eval,
+      save_summary=save_summary,
       train_actions=train_actions,
-      eval_actions=eval_actions)
-
-  logging.info('Starts to execute mode: %s', mode)
-  with distribution_strategy.scope():
-    if mode == 'train' or mode == 'train_and_post_eval':
-      controller.train(steps=params.trainer.train_steps)
-    elif mode == 'train_and_eval':
-      controller.train_and_evaluate(
-          train_steps=params.trainer.train_steps,
-          eval_steps=params.trainer.validation_steps,
-          eval_interval=params.trainer.validation_interval)
-    elif mode == 'eval':
-      controller.evaluate(steps=params.trainer.validation_steps)
-    elif mode == 'continuous_eval':
-
-      def timeout_fn():
-        if trainer.global_step.numpy() >= params.trainer.train_steps:
-          return True
-        return False
-
-      controller.evaluate_continuously(
-          steps=params.trainer.validation_steps,
-          timeout=params.trainer.continuous_eval_timeout,
-          timeout_fn=timeout_fn)
-    else:
-      raise NotImplementedError('The mode is not implemented: %s' % mode)
-
-  num_params = train_utils.try_count_params(trainer.model)
-  if num_params is not None:
-    logging.info('Number of trainable params in model: %f Millions.',
-                 num_params / 10.**6)
-
-  flops = train_utils.try_count_flops(trainer.model)
-  if flops is not None:
-    logging.info('FLOPs (multi-adds) in model: %f Billions.',
-                 flops / 10.**9 / 2)
-
-  if run_post_eval or mode == 'train_and_post_eval':
-    with distribution_strategy.scope():
-      return trainer.model, controller.evaluate(
-          steps=params.trainer.validation_steps)
-  else:
-    return trainer.model, {}
+      eval_actions=eval_actions,
+      trainer=trainer,
+      controller_cls=controller_cls,
+  )
+  return runner.run()
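For existing callers nothing changes: `run_experiment` keeps its signature and now delegates to `OrbitExperimentRunner`. A sketch of the equivalent wrapper call, reusing the assumed objects from the sketch above:

```python
# Equivalent to constructing OrbitExperimentRunner directly and calling run().
model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode='train_and_eval',
    params=params,
    model_dir=model_dir,
)
```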