Skip to content

Commit

Permalink
Update doc strings on TransitionCriterion to improve usability (#2278)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2278

Sait pointed out recently that the doc strings on transition criterion could be improved, this diff does so

Reviewed By: Balandat

Differential Revision: D54914622

fbshipit-source-id: 6fb7121b57a16ddb9b92139c71ac7d890db3ed78
  • Loading branch information
mgarrard authored and facebook-github-bot committed Mar 27, 2024
1 parent 7491233 commit 35642c1
Showing 1 changed file with 138 additions and 14 deletions.
152 changes: 138 additions & 14 deletions ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class TransitionCriterion(SortableBase, SerializationMixin):
# TODO: @mgarrard rename to ActionCriterion
"""
Simple class to descibe a condition which must be met for this GenerationNode to
Simple class to describe a condition which must be met for this GenerationNode to
take an action such as generation, transition, etc.
Args:
Expand All @@ -44,7 +44,7 @@ class TransitionCriterion(SortableBase, SerializationMixin):
until MinimumTrialsInStatus is met (thus overriding MaxTrials).
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: MaxGenerationParallelism
defaults to setting this to Flase since we can complete and move on from
defaults to setting this to False since we can complete and move on from
this node without ever reaching its threshold.
"""

Expand Down Expand Up @@ -101,7 +101,30 @@ def _unique_id(self) -> str:


class TrialBasedCriterion(TransitionCriterion):
"""Common class for action criterion that are based on trial information."""
"""Common class for transition criterion that are based on trial information.
Args:
threshold: The threshold as an integer for this criterion. Ex: If we want to
generate at most 3 trials, then the threshold is 3.
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: MaxGenerationParallelism
defaults to setting this to Flase since we can complete and move on from
this node without ever reaching its threshold.
block_gen_if_met: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: MinimumTrialsInStatus has not been met yet, but
MaxTrials has been reached. If this flag is set to true on MaxTrials then
we will raise an error, otherwise we will continue to generate trials
until MinimumTrialsInStatus is met (thus overriding MaxTrials).
only_in_statuses: A list of trial statuses to filter on when checking the
criterion threshold.
not_in_statuses: A list of trial statuses to exclude when checking the
criterion threshold.
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
only those generated by the current GenerationNode.
"""

def __init__(
self,
Expand Down Expand Up @@ -157,7 +180,7 @@ def experiment_trials_by_status(
Args:
experiment: The experiment associated with this GenerationStrategy.
statuses: The statuses to filter on.
statuses: The trial statuses to filter on.
Returns:
The trial indices in the experiment with the desired statuses.
"""
Expand Down Expand Up @@ -194,6 +217,7 @@ def num_contributing_to_threshold(
Args:
experiment: The experiment associated with this GenerationStrategy.
trials_from_node: The set of trials generated by this GenerationNode.
"""
all_trials_to_check = self.all_trials_to_check(experiment=experiment)
# Some criteria may rely on experiment level data, instead of only trials
Expand All @@ -212,10 +236,11 @@ def num_contributing_to_threshold(
def num_till_threshold(
self, experiment: Experiment, trials_from_node: Optional[Set[int]]
) -> int:
"""Returns the number of trials until the threshold is met.
"""Returns the number of trials needed to meet the threshold.
Args:
experiment: The experiment associated with this GenerationStrategy.
trials_from_node: The set of trials generated by this GenerationNode.
"""
return self.threshold - self.num_contributing_to_threshold(
experiment=experiment, trials_from_node=trials_from_node
Expand All @@ -227,7 +252,17 @@ def is_met(
trials_from_node: Optional[Set[int]] = None,
block_continued_generation: Optional[bool] = False,
) -> bool:
"""Returns if this criterion has been met given its constraints."""
"""Returns if this criterion has been met given its constraints.
Args:
experiment: The experiment associated with this GenerationStrategy.
trials_from_node: The set of trials generated by this GenerationNode.
block_continued_generation: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: MinimumTrialsInStatus has not been met yet, but
MaxTrials has been reached. If this flag is set to true on MaxTrials
then we will raise an error, otherwise we will continue to generate
trials until MinimumTrialsInStatus is met (thus overriding MaxTrials).
"""
return (
self.num_contributing_to_threshold(
experiment=experiment, trials_from_node=trials_from_node
Expand All @@ -237,6 +272,36 @@ def is_met(


class MaxGenerationParallelism(TrialBasedCriterion):
"""Specific TransitionCriterion implementation which defines the maximum number
of trials that can simultaneously be in the designated trial statuses. The
default behavior is to block generation from the associated GenerationNode if the
threshold is met. This is configured via the `block_gen_if_met` flag being set to
True. This criterion defaults to not blocking transition to another node via the
`block_transition_if_unmet` flag being set to False.
Args:
threshold: The threshold as an integer for this criterion. Ex: If we want to
generate at most 3 trials, then the threshold is 3.
only_in_statuses: A list of trial statuses to filter on when checking the
criterion threshold.
not_in_statuses: A list of trial statuses to exclude when checking the
criterion threshold.
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: MaxGenerationParallelism
defaults to setting this to False since we can complete and move on from
this node without ever reaching its threshold.
block_gen_if_met: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: MinimumTrialsInStatus has not been met yet, but
MaxTrials has been reached. If this flag is set to true on MaxTrials then
we will raise an error, otherwise we will continue to generate trials
until MinimumTrialsInStatus is met (thus overriding MaxTrials).
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
only those generated by the current GenerationNode.
"""

def __init__(
self,
threshold: int,
Expand Down Expand Up @@ -283,14 +348,32 @@ def block_continued_generation_error(

class MaxTrials(TrialBasedCriterion):
"""
Simple class to enforce a maximum threshold for the number of trials generated
by a specific GenerationNode.
Simple class to enforce a maximum threshold for the number of trials with the
designated statuses being generated by a specific GenerationNode. The default
behavior is to block transition to the next node if the threshold is unmet, but
not affect continued generation.
Args:
threshold: the designated maximum number of trials
enforce: whether or not to enforce the max trial constraint
only_in_status: optional argument for specifying only checking trials with
this status. If not specified, all trial statuses are counted.
threshold: The threshold as an integer for this criterion. Ex: If we want to
generate at most 3 trials, then the threshold is 3.
only_in_statuses: A list of trial statuses to filter on when checking the
criterion threshold.
not_in_statuses: A list of trial statuses to exclude when checking the
criterion threshold.
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: MaxGenerationParallelism
defaults to setting this to False since we can complete and move on from
this node without ever reaching its threshold.
block_gen_if_met: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: MinimumTrialsInStatus has not been met yet, but
MaxTrials has been reached. If this flag is set to true on MaxTrials then
we will raise an error, otherwise we will continue to generate trials
until MinimumTrialsInStatus is met (thus overriding MaxTrials).
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
only those generated by the current GenerationNode.
"""

def __init__(
Expand Down Expand Up @@ -333,8 +416,32 @@ def block_continued_generation_error(

class MinTrials(TrialBasedCriterion):
"""
Simple class to decide if the number of trials of a given status in the
GenerationStrategy experiment has reached a certain threshold.
Simple class to enforce a minimum threshold for the number of trials with the
designated statuses being generated by a specific GenerationNode. The default
behavior is to block transition to the next node if the threshold is unmet, but
not affect continued generation.
Args:
threshold: The threshold as an integer for this criterion. Ex: If we want to
generate at most 3 trials, then the threshold is 3.
only_in_statuses: A list of trial statuses to filter on when checking the
criterion threshold.
not_in_statuses: A list of trial statuses to exclude when checking the
criterion threshold.
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: MaxGenerationParallelism
defaults to setting this to False since we can complete and move on from
this node without ever reaching its threshold.
block_gen_if_met: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: MinimumTrialsInStatus has not been met yet, but
MaxTrials has been reached. If this flag is set to true on MaxTrials then
we will raise an error, otherwise we will continue to generate trials
until MinimumTrialsInStatus is met (thus overriding MaxTrials).
use_all_trials_in_exp: A flag to use all trials in the experiment, instead of
only those generated by the current GenerationNode.
"""

def __init__(
Expand Down Expand Up @@ -379,6 +486,23 @@ class MinimumPreferenceOccurances(TransitionCriterion):
In a preference Experiment (i.e. Metric values may either be zero for No and
nonzero for Yes) do not transition until a minimum number of both Yes and No
responses have been received.
Args:
metric_name: name of the metric to check for preference occurrences.
threshold: The threshold as an integer for this criterion. Ex: If we want to
generate at most 3 trials, then the threshold is 3.
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
block_gen_if_met: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: MinimumTrialsInStatus has not been met yet, but
MaxTrials has been reached. If this flag is set to true on MaxTrials then
we will raise an error, otherwise we will continue to generate trials
until MinimumTrialsInStatus is met (thus overriding MaxTrials).
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: MaxGenerationParallelism
defaults to setting this to False since we can complete and move on from
this node without ever reaching its threshold.
"""

def __init__(
Expand Down

0 comments on commit 35642c1

Please sign in to comment.