Commit 6dbdb08

yeqingli authored and tensorflower-gardener committed
Refactors the run_experiment function for better reusability.
PiperOrigin-RevId: 458550388
1 parent f1add1b commit 6dbdb08

File tree

2 files changed: +292, -91 lines


official/core/train_lib.py

Lines changed: 233 additions & 87 deletions
@@ -32,6 +32,226 @@
 maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
 
 
+class OrbitExperimentRunner:
+  """Runs experiment with Orbit training loop.
+
+  The default experiment runner for model garden experiments. Users can
+  customize the experiment pipeline by subclassing this class and replacing
+  components or functions.
+
+  For example, an experiment runner with a customized checkpoint manager:
+
+  ```python
+  class MyExpRunnerWithExporter(OrbitExperimentRunner):
+    def _maybe_build_checkpoint_manager(self):
+      return MyCheckpointManager(*args)
+
+  # In user code
+  MyExpRunnerWithExporter(**needed_kwargs).run()
+  ```
+
+  A similar override can be done to other components.
+  """
+
+  def __init__(
+      self,
+      distribution_strategy: tf.distribute.Strategy,
+      task: base_task.Task,
+      mode: str,
+      params: config_definitions.ExperimentConfig,
+      model_dir: str,
+      run_post_eval: bool = False,
+      save_summary: bool = True,
+      train_actions: Optional[List[orbit.Action]] = None,
+      eval_actions: Optional[List[orbit.Action]] = None,
+      trainer: Optional[base_trainer.Trainer] = None,
+      controller_cls=orbit.Controller
+  ):
+    """Constructor.
+
+    Args:
+      distribution_strategy: A distribution strategy.
+      task: A Task instance.
+      mode: A 'str', specifying the mode. Can be 'train', 'eval',
+        'train_and_eval' or 'continuous_eval'.
+      params: ExperimentConfig instance.
+      model_dir: A 'str', a path to store model checkpoints and summaries.
+      run_post_eval: Whether to run post eval once after training; metrics
+        logs are returned.
+      save_summary: Whether to save train and validation summaries.
+      train_actions: Optional list of Orbit train actions.
+      eval_actions: Optional list of Orbit eval actions.
+      trainer: the base_trainer.Trainer instance. It should be created within
+        the strategy.scope().
+      controller_cls: The controller class to manage the train and eval
+        process. Must be an orbit.Controller subclass.
+    """
+    self.strategy = distribution_strategy or tf.distribute.get_strategy()
+    self._params = params
+    self._model_dir = model_dir
+    self._mode = mode
+    self._run_post_eval = run_post_eval
+
+    self._trainer = trainer or self._build_trainer(
+        task,
+        train='train' in mode,
+        evaluate=('eval' in mode) or run_post_eval)
+    assert self.trainer is not None
+    self._checkpoint_manager = self._maybe_build_checkpoint_manager()
+    self._controller = self._build_controller(
+        trainer=self.trainer if 'train' in mode else None,
+        evaluator=self.trainer,
+        save_summary=save_summary,
+        train_actions=train_actions,
+        eval_actions=eval_actions,
+        controller_cls=controller_cls)
+
+  @property
+  def params(self) -> config_definitions.ExperimentConfig:
+    return self._params
+
+  @property
+  def model_dir(self) -> str:
+    return self._model_dir
+
+  @property
+  def trainer(self) -> base_trainer.Trainer:
+    return self._trainer
+
+  @property
+  def checkpoint_manager(self) -> tf.train.CheckpointManager:
+    return self._checkpoint_manager
+
+  @property
+  def controller(self) -> orbit.Controller:
+    return self._controller
+
+  def _build_trainer(self, task: base_task.Task, train: bool,
+                     evaluate: bool) -> base_trainer.Trainer:
+    """Create trainer."""
+    with self.strategy.scope():
+      trainer = train_utils.create_trainer(
+          self.params,
+          task,
+          train=train,
+          evaluate=evaluate,
+          checkpoint_exporter=self._build_best_checkpoint_exporter())
+    return trainer
+
+  def _build_best_checkpoint_exporter(self):
+    return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
+
+  def _maybe_build_checkpoint_manager(
+      self) -> Optional[tf.train.CheckpointManager]:
+    """Maybe create a CheckpointManager."""
+    assert self.trainer is not None
+    if self.trainer.checkpoint:
+      if self.model_dir is None:
+        raise ValueError('model_dir must be specified, but got None')
+      checkpoint_manager = tf.train.CheckpointManager(
+          self.trainer.checkpoint,
+          directory=self.model_dir,
+          max_to_keep=self.params.trainer.max_to_keep,
+          step_counter=self.trainer.global_step,
+          checkpoint_interval=self.params.trainer.checkpoint_interval,
+          init_fn=self.trainer.initialize)
+    else:
+      checkpoint_manager = None
+    return checkpoint_manager
+
+  def _build_controller(self,
+                        trainer,
+                        evaluator,
+                        save_summary: bool = True,
+                        train_actions: Optional[List[orbit.Action]] = None,
+                        eval_actions: Optional[List[orbit.Action]] = None,
+                        controller_cls=orbit.Controller) -> orbit.Controller:
+    """Builds an Orbit controller."""
+    train_actions = [] if not train_actions else train_actions
+    if trainer:
+      train_actions += actions.get_train_actions(
+          self.params,
+          trainer,
+          self.model_dir,
+          checkpoint_manager=self.checkpoint_manager)
+
+    eval_actions = [] if not eval_actions else eval_actions
+    if evaluator:
+      eval_actions += actions.get_eval_actions(self.params, evaluator,
+                                               self.model_dir)
+
+    controller = controller_cls(
+        strategy=self.strategy,
+        trainer=trainer,
+        evaluator=evaluator,
+        global_step=self.trainer.global_step,
+        steps_per_loop=self.params.trainer.steps_per_loop,
+        checkpoint_manager=self.checkpoint_manager,
+        summary_dir=os.path.join(self.model_dir, 'train')
+        if save_summary else None,
+        eval_summary_dir=os.path.join(
+            self.model_dir, self.params.trainer.validation_summary_subdir)
+        if save_summary else None,
+        summary_interval=self.params.trainer.summary_interval
+        if save_summary else None,
+        train_actions=train_actions,
+        eval_actions=eval_actions)
+    return controller
+
+  def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
+    """Run experiments by mode.
+
+    Returns:
+      A 2-tuple of (model, eval_logs).
+        model: `tf.keras.Model` instance.
+        eval_logs: returns eval metrics logs when run_post_eval is set to True,
+          otherwise, returns {}.
+    """
+    mode = self._mode
+    params = self.params
+    logging.info('Starts to execute mode: %s', mode)
+    with self.strategy.scope():
+      if mode == 'train' or mode == 'train_and_post_eval':
+        self.controller.train(steps=params.trainer.train_steps)
+      elif mode == 'train_and_eval':
+        self.controller.train_and_evaluate(
+            train_steps=params.trainer.train_steps,
+            eval_steps=params.trainer.validation_steps,
+            eval_interval=params.trainer.validation_interval)
+      elif mode == 'eval':
+        self.controller.evaluate(steps=params.trainer.validation_steps)
+      elif mode == 'continuous_eval':
+
+        def timeout_fn():
+          if self.trainer.global_step.numpy() >= params.trainer.train_steps:
+            return True
+          return False
+
+        self.controller.evaluate_continuously(
+            steps=params.trainer.validation_steps,
+            timeout=params.trainer.continuous_eval_timeout,
+            timeout_fn=timeout_fn)
+      else:
+        raise NotImplementedError('The mode is not implemented: %s' % mode)
+
+    num_params = train_utils.try_count_params(self.trainer.model)
+    if num_params is not None:
+      logging.info('Number of trainable params in model: %f Millions.',
+                   num_params / 10.**6)
+
+    flops = train_utils.try_count_flops(self.trainer.model)
+    if flops is not None:
+      logging.info('FLOPs (multi-adds) in model: %f Billions.',
+                   flops / 10.**9 / 2)
+
+    if self._run_post_eval or mode == 'train_and_post_eval':
+      with self.strategy.scope():
+        return self.trainer.model, self.controller.evaluate(
+            steps=params.trainer.validation_steps)
+    else:
+      return self.trainer.model, {}
+
+
 def run_experiment(
     distribution_strategy: tf.distribute.Strategy,
     task: base_task.Task,
@@ -70,91 +290,17 @@ def run_experiment(
     eval_logs: returns eval metrics logs when run_post_eval is set to True,
       otherwise, returns {}.
   """
-
-  with distribution_strategy.scope():
-    if not trainer:
-      trainer = train_utils.create_trainer(
-          params,
-          task,
-          train='train' in mode,
-          evaluate=('eval' in mode) or run_post_eval,
-          checkpoint_exporter=maybe_create_best_ckpt_exporter(
-              params, model_dir))
-
-  if trainer.checkpoint:
-    if model_dir is None:
-      raise ValueError('model_dir must be specified, but got None')
-    checkpoint_manager = tf.train.CheckpointManager(
-        trainer.checkpoint,
-        directory=model_dir,
-        max_to_keep=params.trainer.max_to_keep,
-        step_counter=trainer.global_step,
-        checkpoint_interval=params.trainer.checkpoint_interval,
-        init_fn=trainer.initialize)
-  else:
-    checkpoint_manager = None
-
-  train_actions = [] if not train_actions else train_actions
-  train_actions += actions.get_train_actions(
-      params, trainer, model_dir, checkpoint_manager=checkpoint_manager)
-
-  eval_actions = [] if not eval_actions else eval_actions
-  eval_actions += actions.get_eval_actions(params, trainer, model_dir)
-
-  controller = controller_cls(
-      strategy=distribution_strategy,
-      trainer=trainer if 'train' in mode else None,
-      evaluator=trainer,
-      global_step=trainer.global_step,
-      steps_per_loop=params.trainer.steps_per_loop,
-      checkpoint_manager=checkpoint_manager,
-      summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
-      eval_summary_dir=os.path.join(model_dir,
-                                    params.trainer.validation_summary_subdir) if
-      (save_summary) else None,
-      summary_interval=params.trainer.summary_interval if
-      (save_summary) else None,
+  runner = OrbitExperimentRunner(
+      distribution_strategy=distribution_strategy,
+      task=task,
+      mode=mode,
+      params=params,
+      model_dir=model_dir,
+      run_post_eval=run_post_eval,
+      save_summary=save_summary,
       train_actions=train_actions,
-      eval_actions=eval_actions)
-
-  logging.info('Starts to execute mode: %s', mode)
-  with distribution_strategy.scope():
-    if mode == 'train' or mode == 'train_and_post_eval':
-      controller.train(steps=params.trainer.train_steps)
-    elif mode == 'train_and_eval':
-      controller.train_and_evaluate(
-          train_steps=params.trainer.train_steps,
-          eval_steps=params.trainer.validation_steps,
-          eval_interval=params.trainer.validation_interval)
-    elif mode == 'eval':
-      controller.evaluate(steps=params.trainer.validation_steps)
-    elif mode == 'continuous_eval':
-
-      def timeout_fn():
-        if trainer.global_step.numpy() >= params.trainer.train_steps:
-          return True
-        return False
-
-      controller.evaluate_continuously(
-          steps=params.trainer.validation_steps,
-          timeout=params.trainer.continuous_eval_timeout,
-          timeout_fn=timeout_fn)
-    else:
-      raise NotImplementedError('The mode is not implemented: %s' % mode)
-
-  num_params = train_utils.try_count_params(trainer.model)
-  if num_params is not None:
-    logging.info('Number of trainable params in model: %f Millions.',
-                 num_params / 10.**6)
-
-  flops = train_utils.try_count_flops(trainer.model)
-  if flops is not None:
-    logging.info('FLOPs (multi-adds) in model: %f Billions.',
-                 flops / 10.**9 / 2)
-
-  if run_post_eval or mode == 'train_and_post_eval':
-    with distribution_strategy.scope():
-      return trainer.model, controller.evaluate(
-          steps=params.trainer.validation_steps)
-  else:
-    return trainer.model, {}
+      eval_actions=eval_actions,
+      trainer=trainer,
+      controller_cls=controller_cls,
+  )
+  return runner.run()
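
For context, a minimal sketch of how the refactored entry point might be driven after this change. Here `my_task`, `my_params`, and `my_model_dir` are hypothetical placeholders for an already-built `base_task.Task`, `ExperimentConfig`, and output directory, and the strategy choice is only an example; existing callers of `run_experiment` keep the same signature, while custom pipelines can subclass `OrbitExperimentRunner` and override a single component, as the class docstring suggests.

```python
# Sketch only: my_task, my_params, and my_model_dir are hypothetical
# placeholders; build them with your own task/config plumbing.
import tensorflow as tf

from official.core import train_lib

strategy = tf.distribute.get_strategy()  # or MirroredStrategy, TPUStrategy, ...

# Existing call sites are unchanged: run_experiment now wraps the runner.
model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=my_task,
    mode='train_and_eval',
    params=my_params,
    model_dir=my_model_dir)


# Custom pipelines can instead subclass the runner, replace one piece, and
# drive it directly via run().
class MyRunner(train_lib.OrbitExperimentRunner):

  def _maybe_build_checkpoint_manager(self):
    # Defer to the default manager here; a real override could return a
    # customized tf.train.CheckpointManager instead.
    return super()._maybe_build_checkpoint_manager()


model, eval_logs = MyRunner(
    distribution_strategy=strategy,
    task=my_task,
    mode='train_and_eval',
    params=my_params,
    model_dir=my_model_dir).run()
```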
