Skip to content

Commit

Permalink
Update discriminator option
Browse files Browse the repository at this point in the history
  • Loading branch information
lixfz committed Feb 23, 2024
1 parent 53fbb51 commit 0a6f962
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions hypernets/experiment/_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def make_experiment(hyper_model_cls,
- nf
optimize_direction : str, optional
Hypernets search reward metric direction, default is detected from reward_metric.
discriminator : instance of hypernets.discriminator.BaseDiscriminator, optional
Discriminator is used to determine whether to continue training
discriminator : instance of hypernets.discriminator.BaseDiscriminator or bool, optional
Discriminator is used to determine whether to continue training, set False to disable it.
hyper_model_options: dict, optional
Options to initlize HyperModel except *reward_metric*, *task*, *callbacks*, *discriminator*.
evaluation_metrics: str, list, or None (default='auto'),
Expand Down Expand Up @@ -365,10 +365,14 @@ def append_early_stopping_callbacks(cbs):
report_render = to_report_render_object(report_render, report_render_options)
callbacks.append(MLReportCallback(report_render))

if discriminator is None and cfg.experiment_discriminator is not None and len(cfg.experiment_discriminator) > 0:
if ((discriminator is None or discriminator is True)
and cfg.experiment_discriminator is not None
and len(cfg.experiment_discriminator) > 0):
discriminator = make_discriminator(cfg.experiment_discriminator,
optimize_direction=optimize_direction,
**(cfg.experiment_discriminator_options or {}))
elif discriminator is False:
discriminator = None

if id is None:
hasher = tb.data_hasher()
Expand Down

0 comments on commit 0a6f962

Please sign in to comment.