From 63dd0cd62bdcbf24fef32b7fa99db19e19327a14 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Tue, 14 Feb 2023 16:12:54 -0800 Subject: [PATCH] Bug fix: `_filter_kwargs` was erroring when provided a function without a `__name__` attribute (#1678) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1678 See https://github.com/pytorch/botorch/issues/1667 Reviewed By: danielrjiang Differential Revision: D43286116 fbshipit-source-id: 3da3e6ff23b517f5379ee90f407dc04d4f2ad06e --- botorch/optim/utils/common.py | 7 ++++++- test/optim/test_optimize.py | 6 ------ test/optim/utils/test_common.py | 19 ++++++++++++++----- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/botorch/optim/utils/common.py b/botorch/optim/utils/common.py index 5cd687b104..0bd4650630 100644 --- a/botorch/optim/utils/common.py +++ b/botorch/optim/utils/common.py @@ -23,9 +23,14 @@ def _filter_kwargs(function: Callable, **kwargs: Any) -> Any: allowed_params = signature(function).parameters removed = {k for k in kwargs.keys() if k not in allowed_params} if len(removed) > 0: + fn_descriptor = ( + f" for function {function.__name__}" + if hasattr(function, "__name__") + else "" + ) warn( f"Keyword arguments {list(removed)} will be ignored because they are" - f" not allowed parameters for function {function.__name__}. Allowed " + f" not allowed parameters{fn_descriptor}. Allowed " f"parameters are {list(allowed_params.keys())}." ) return {k: v for k, v in kwargs.items() if k not in removed} diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 421467ef7c..cdafe3755c 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -115,13 +115,10 @@ def test_optimize_acqf_joint( mock_gen_candidates_scipy, mock_gen_candidates_torch, ): - # Mocks don't have a __name__ attribute. - # Set the attribute, since it is needed for testing _filter_kwargs if mock_gen_candidates == mock_gen_candidates_torch: mock_signature.return_value = signature(gen_candidates_torch) else: mock_signature.return_value = signature(gen_candidates_scipy) - mock_gen_candidates.__name__ = "gen_candidates" mock_gen_batch_initial_conditions.return_value = torch.zeros( num_restarts, q, 3, device=self.device, dtype=dtype @@ -835,13 +832,10 @@ def nlc(x): mock_gen_candidates_torch, mock_gen_candidates_scipy, ): - # Mocks don't have a __name__ attribute. - # Set the attribute, since it is needed for testing _filter_kwargs if mock_gen_candidates == mock_gen_candidates_torch: mock_signature.return_value = signature(gen_candidates_torch) else: mock_signature.return_value = signature(gen_candidates_scipy) - mock_gen_candidates.__name__ = "gen_candidates" for dtype in (torch.float, torch.double): mock_acq_function = MockAcquisitionFunction() diff --git a/test/optim/utils/test_common.py b/test/optim/utils/test_common.py index 0c157a125c..b2d68f5851 100644 --- a/test/optim/utils/test_common.py +++ b/test/optim/utils/test_common.py @@ -25,14 +25,23 @@ def mock_adam(params, lr: float = 0.001) -> None: return # pragma: nocover kwargs = {"lr": 0.01, "maxiter": 3000} - with catch_warnings(record=True) as ws: + expected_msg = ( + r"Keyword arguments \['maxiter'\] will be ignored because they are " + r"not allowed parameters for function mock_adam. Allowed parameters " + r"are \['params', 'lr'\]." 
+        )
+
+        with self.assertWarnsRegex(Warning, expected_msg):
             valid_kwargs = _filter_kwargs(mock_adam, **kwargs)
+        self.assertEqual(set(valid_kwargs.keys()), {"lr"})
+
+        mock_partial = partial(mock_adam, lr=2.0)
         expected_msg = (
-            "Keyword arguments ['maxiter'] will be ignored because they are not"
-            " allowed parameters for function mock_adam. Allowed parameters are "
-            "['params', 'lr']."
+            r"Keyword arguments \['maxiter'\] will be ignored because they are "
+            r"not allowed parameters. Allowed parameters are \['params', 'lr'\]."
         )
-        self.assertEqual(expected_msg, str(ws[0].message))
+        with self.assertWarnsRegex(Warning, expected_msg):
+            valid_kwargs = _filter_kwargs(mock_partial, **kwargs)
         self.assertEqual(set(valid_kwargs.keys()), {"lr"})
 
     def test_handle_numerical_errors(self):
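
For reference, the failure mode addressed here is easiest to reproduce with a `functools.partial`, which has no `__name__` attribute: before this change, `_filter_kwargs` raised `AttributeError: 'functools.partial' object has no attribute '__name__'` while formatting its warning, and afterwards it simply omits the function name from the message. A minimal sketch, assuming a BoTorch checkout with this change applied (`_filter_kwargs` is the private helper in `botorch/optim/utils/common.py`, and `mock_adam` mirrors the test fixture above):

    # Sketch only: requires a BoTorch build that includes this fix.
    import warnings
    from functools import partial

    from botorch.optim.utils.common import _filter_kwargs


    def mock_adam(params, lr: float = 0.001) -> None:
        """Stand-in optimizer constructor, mirroring the test fixture."""


    # A partial has no __name__, which is what used to crash _filter_kwargs
    # while it built the warning message.
    mock_partial = partial(mock_adam, lr=2.0)

    with warnings.catch_warnings(record=True) as ws:
        warnings.simplefilter("always")
        valid_kwargs = _filter_kwargs(mock_partial, lr=0.01, maxiter=3000)

    print(valid_kwargs)  # {'lr': 0.01} -- 'maxiter' is not an allowed parameter
    print(str(ws[0].message))
    # Keyword arguments ['maxiter'] will be ignored because they are not
    # allowed parameters. Allowed parameters are ['params', 'lr'].

The same call with a plain function (e.g. `mock_adam` itself) still includes "for function mock_adam" in the warning, which is what the first half of the updated test asserts.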