diff --git a/aepsych/generators/completion_criterion/min_asks.py b/aepsych/generators/completion_criterion/min_asks.py index f4319538b..5669a8220 100644 --- a/aepsych/generators/completion_criterion/min_asks.py +++ b/aepsych/generators/completion_criterion/min_asks.py @@ -5,21 +5,30 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict - -from ax.core.experiment import Experiment -from ax.modelbridge.completion_criterion import CompletionCriterion +from typing import Any, Dict, Optional, Set from aepsych.config import Config, ConfigurableMixin +from ax.core.experiment import Experiment +from ax.modelbridge.transition_criterion import TransitionCriterion -class MinAsks(CompletionCriterion, ConfigurableMixin): + +class MinAsks(TransitionCriterion, ConfigurableMixin): def __init__(self, threshold: int) -> None: self.threshold = threshold def is_met(self, experiment: Experiment) -> bool: return experiment.num_asks >= self.threshold + def block_continued_generation_error( + self, + node_name: Optional[str], + model_name: Optional[str], + experiment: Optional[Experiment], + trials_from_node: Optional[Set[int]] = None, + ) -> None: + pass + @classmethod def get_config_options(cls, config: Config, name: str) -> Dict[str, Any]: min_asks = config.getint(name, "min_asks", fallback=1)