Skip to content

Commit

Permalink
Remove Standardize.load_state_dict
Browse files Browse the repository at this point in the history
Summary: This was a temporary wrokaround for backwards. compatibility, added in #1875

Reviewed By: Balandat

Differential Revision: D56801096

fbshipit-source-id: c867fe2ca2c65d5f4c6ca9a09fd4798452767cc0
  • Loading branch information
saitcakmak authored and facebook-github-bot committed May 1, 2024
1 parent b438b2d commit 4f362be
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 23 deletions.
17 changes: 1 addition & 16 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@

from __future__ import annotations

import warnings

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, List, Mapping, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from botorch.models.transforms.utils import (
Expand Down Expand Up @@ -256,19 +254,6 @@ def __init__(
self._batch_shape = batch_shape
self._min_stdv = min_stdv

def load_state_dict(
self, state_dict: Mapping[str, Any], strict: bool = True
) -> None:
r"""Custom logic for loading the state dict."""
if "_is_trained" not in state_dict:
warnings.warn(
"Key '_is_trained' not found in state_dict. Setting to True. "
"In a future release, this will result in an error.",
DeprecationWarning,
)
state_dict = {**state_dict, "_is_trained": torch.tensor(True)}
super().load_state_dict(state_dict, strict=strict)

def forward(
self, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
Expand Down
7 changes: 0 additions & 7 deletions test/models/transforms/test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,6 @@ def test_standardize_state_dict(self):
self.assertFalse(new_transform._is_trained)
new_transform.load_state_dict(state_dict)
self.assertTrue(new_transform._is_trained)
# test deprecation error when loading state dict without _is_trained
state_dict.pop("_is_trained")
with self.assertWarnsRegex(
DeprecationWarning,
"Key '_is_trained' not found in state_dict. Setting to True.",
):
new_transform.load_state_dict(state_dict)

def test_log(self):
ms = (1, 2)
Expand Down

0 comments on commit 4f362be

Please sign in to comment.