From e14734610d9841e59cd2e16f75beada5e50cc0f5 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Tue, 14 Feb 2023 16:31:25 -0800 Subject: [PATCH] Add isinstance_af (#1664) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1664 Creates a new helper method for checking both if a given AF is an instance of a class or if the given AF wraps a base AF that is an instance of a class Differential Revision: D43127722 fbshipit-source-id: e46a289a1e1fe815ded61bc271e62d21563eb1c7 --- botorch/acquisition/utils.py | 17 ++++++++-- test/acquisition/test_utils.py | 61 +++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 486fdd0cff..ccbbf471b2 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -11,7 +11,7 @@ from __future__ import annotations import math -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401 @@ -22,6 +22,7 @@ MCAcquisitionObjective, PosteriorTransform, ) +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models.fully_bayesian import MCMC_DIM from botorch.models.model import Model @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None): return -(lb.clamp_max(0.0)) +def isinstance_af( + __obj: object, + __class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]], +) -> bool: + r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions.""" + if isinstance(__obj, AbstractAcquisitionFunctionWrapper): + isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple) + else: + isinstance_base_af = False + return isinstance_base_af or isinstance(__obj, __class_or_tuple) + + def is_nonnegative(acq_function: AcquisitionFunction) -> bool: r"""Determine whether a given acquisition function is non-negative. @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool: >>> qEI = qExpectedImprovement(model, best_f=0.1) >>> is_nonnegative(qEI) # returns True """ - return isinstance( + return isinstance_af( acq_function, ( analytic.ExpectedImprovement, diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index d12b5f6da4..39b8017ea2 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -8,7 +8,8 @@ from unittest import mock import torch -from botorch.acquisition import monte_carlo +from botorch.acquisition import analytic, monte_carlo, multi_objective +from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.multi_objective import ( MCMultiOutputObjective, monte_carlo as moo_monte_carlo, @@ -18,10 +19,13 @@ MCAcquisitionObjective, ScalarizedPosteriorTransform, ) +from botorch.acquisition.proximal import ProximalAcquisitionFunction from botorch.acquisition.utils import ( expand_trace_observations, get_acquisition_function, get_infeasible_cost, + is_nonnegative, + isinstance_af, project_to_sample_points, project_to_target_fidelity, prune_inferior_points, @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self): self.assertAllClose(M4, torch.tensor([1.0], **tkwargs)) +class TestIsNonnegative(BotorchTestCase): + def test_is_nonnegative(self): + nonneg_afs = ( + analytic.ExpectedImprovement, + analytic.ConstrainedExpectedImprovement, + analytic.ProbabilityOfImprovement, + analytic.NoisyExpectedImprovement, + monte_carlo.qExpectedImprovement, + monte_carlo.qNoisyExpectedImprovement, + monte_carlo.qProbabilityOfImprovement, + multi_objective.analytic.ExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, + ) + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + with mock.patch( + "botorch.acquisition.utils.isinstance_af", return_value=True + ) as mock_isinstance_af: + self.assertTrue(is_nonnegative(acq_function=acq_func)) + mock_isinstance_af.assert_called_once() + cargs, _ = mock_isinstance_af.call_args + self.assertIs(cargs[0], acq_func) + self.assertEqual(cargs[1], nonneg_afs) + acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) + self.assertFalse(is_nonnegative(acq_function=acq_func)) + + +class TestIsinstanceAf(BotorchTestCase): + def test_isinstance_af(self): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound)) + wrapped_af = FixedFeatureAcquisitionFunction( + acq_function=acq_func, d=2, columns=[1], values=[0.0] + ) + # test base af class + self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound)) + # test wrapper class + self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction)) + self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction)) + + class TestPruneInferiorPoints(BotorchTestCase): def test_prune_inferior_points(self): for dtype in (torch.float, torch.double):