Skip to content

Commit

Permalink
Move Jenatton test function to appropriate file (#2679)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2679

I didn't do this in the previous diff so that it would be easier to review.

Reviewed By: Balandat

Differential Revision: D61431983

fbshipit-source-id: 3f376b793b9627917e5093c960e2691aaef5f3de
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 22, 2024
1 parent 4c3e87d commit b9dd1b0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 46 deletions.
40 changes: 0 additions & 40 deletions ax/benchmark/metrics/jenatton.py

This file was deleted.

29 changes: 25 additions & 4 deletions ax/benchmark/problems/synthetic/hss/jenatton.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.metrics.benchmark import BenchmarkMetric
from ax.benchmark.metrics.jenatton import jenatton_test_function
from ax.benchmark.runners.botorch_test import (
ParamBasedTestProblem,
ParamBasedTestProblemRunner,
Expand All @@ -21,17 +20,39 @@
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import HierarchicalSearchSpace
from ax.core.types import TParameterization
from pyre_extensions import none_throws


@dataclass(kw_only=True)
class Jenatton(ParamBasedTestProblem):
r"""Jenatton test function for hierarchical search spaces.
def jenatton_test_function(
x1: Optional[int] = None,
x2: Optional[int] = None,
x3: Optional[int] = None,
x4: Optional[float] = None,
x5: Optional[float] = None,
x6: Optional[float] = None,
x7: Optional[float] = None,
r8: Optional[float] = None,
r9: Optional[float] = None,
) -> float:
"""Jenatton test function for hierarchical search spaces.
This function is taken from:
R. Jenatton, C. Archambeau, J. González, and M. Seeger. Bayesian
optimization with tree-structured dependencies. ICML 2017.
"""
if x1 == 0:
if x2 == 0:
return none_throws(x4) ** 2 + 0.1 + none_throws(r8)
return none_throws(x5) ** 2 + 0.2 + none_throws(r8)
if x3 == 0:
return none_throws(x6) ** 2 + 0.3 + none_throws(r9)
return none_throws(x7) ** 2 + 0.4 + none_throws(r9)


@dataclass(kw_only=True)
class Jenatton(ParamBasedTestProblem):
"""Jenatton test function for hierarchical search spaces."""

noise_std: Optional[float] = None
negate: bool = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

from ax.benchmark.metrics.benchmark import BenchmarkMetric, GroundTruthBenchmarkMetric

from ax.benchmark.metrics.jenatton import jenatton_test_function
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
from ax.benchmark.problems.synthetic.hss.jenatton import (
get_jenatton_benchmark_problem,
jenatton_test_function,
)
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
from ax.core.arm import Arm
Expand Down

0 comments on commit b9dd1b0

Please sign in to comment.