diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py index 4745ee9734..9e3ec07853 100644 --- a/botorch/models/transforms/outcome.py +++ b/botorch/models/transforms/outcome.py @@ -22,11 +22,9 @@ from __future__ import annotations -import warnings - from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, List, Mapping, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from botorch.models.transforms.utils import ( @@ -256,19 +254,6 @@ def __init__( self._batch_shape = batch_shape self._min_stdv = min_stdv - def load_state_dict( - self, state_dict: Mapping[str, Any], strict: bool = True - ) -> None: - r"""Custom logic for loading the state dict.""" - if "_is_trained" not in state_dict: - warnings.warn( - "Key '_is_trained' not found in state_dict. Setting to True. " - "In a future release, this will result in an error.", - DeprecationWarning, - ) - state_dict = {**state_dict, "_is_trained": torch.tensor(True)} - super().load_state_dict(state_dict, strict=strict) - def forward( self, Y: Tensor, Yvar: Optional[Tensor] = None ) -> Tuple[Tensor, Optional[Tensor]]: diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py index 44dadf79ef..d4acc388eb 100644 --- a/test/models/transforms/test_outcome.py +++ b/test/models/transforms/test_outcome.py @@ -355,13 +355,6 @@ def test_standardize_state_dict(self): self.assertFalse(new_transform._is_trained) new_transform.load_state_dict(state_dict) self.assertTrue(new_transform._is_trained) - # test deprecation error when loading state dict without _is_trained - state_dict.pop("_is_trained") - with self.assertWarnsRegex( - DeprecationWarning, - "Key '_is_trained' not found in state_dict. Setting to True.", - ): - new_transform.load_state_dict(state_dict) def test_log(self): ms = (1, 2)