Skip to content

Commit

Permalink
Update min_asks to directly inherit from TransitionCriterion (#328)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #328

This diff we update the 3rd of 4 aepsych completion criterion to directly rely on transitioncriterion

In following diffs we will:
- Completely remove the completion criterion file
- revisit storage
- remove all todos in gennode, genstrat, and transitioncriterion classes related to maintaining this deprecated code
- update AEPsych GSs as needed
- determine if run indefinetly can be replaced by simply having gen_unlimited_trials = true
- determine if additional updates to min_asks and run_indefinetly are necessary

Reviewed By: crasanders

Differential Revision: D52852941

fbshipit-source-id: b515689c199f5d5f8ba1d547b12d171c4bbae6ba
  • Loading branch information
mgarrard authored and facebook-github-bot committed Jan 29, 2024
1 parent 1aa4280 commit 3177eb4
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions aepsych/generators/completion_criterion/min_asks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,30 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

from ax.core.experiment import Experiment
from ax.modelbridge.completion_criterion import CompletionCriterion
from typing import Any, Dict, Optional, Set

from aepsych.config import Config, ConfigurableMixin

from ax.core.experiment import Experiment
from ax.modelbridge.transition_criterion import TransitionCriterion

class MinAsks(CompletionCriterion, ConfigurableMixin):

class MinAsks(TransitionCriterion, ConfigurableMixin):
def __init__(self, threshold: int) -> None:
self.threshold = threshold

def is_met(self, experiment: Experiment) -> bool:
return experiment.num_asks >= self.threshold

def block_continued_generation_error(
self,
node_name: Optional[str],
model_name: Optional[str],
experiment: Optional[Experiment],
trials_from_node: Optional[Set[int]] = None,
) -> None:
pass

@classmethod
def get_config_options(cls, config: Config, name: str) -> Dict[str, Any]:
min_asks = config.getint(name, "min_asks", fallback=1)
Expand Down

0 comments on commit 3177eb4

Please sign in to comment.