Commit 7c7fa88

Convert tensors to float64 uniformly (facebookresearch#407)
Summary:
Pull Request resolved: facebookresearch#407

BoTorch best practices recommend working in float64, so this change moves as much of the codebase as possible to float64.

Reviewed By: crasanders

Differential Revision: D64507796

fbshipit-source-id: f0b082f0b5e9961caafb08ac1a4bdecff47738bd
JasonKChow authored and facebook-github-bot committed Oct 17, 2024
1 parent 09b1d59 commit 7c7fa88
Showing 5 changed files with 17 additions and 13 deletions.
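
For context, the heart of this change is PyTorch's global default dtype. A minimal standalone sketch (not part of the commit) of what torch.set_default_dtype(torch.float64) does:

    import torch

    # By default, new floating-point tensors are float32.
    x = torch.tensor([1.0, 2.0])
    assert x.dtype == torch.float32

    # After setting the default (as aepsych/__init__.py now does on import),
    # freshly created floating-point tensors are float64.
    torch.set_default_dtype(torch.float64)
    y = torch.tensor([1.0, 2.0])
    assert y.dtype == torch.float64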
4 changes: 4 additions & 0 deletions aepsych/__init__.py
@@ -7,6 +7,8 @@
 
 import sys
 
+import torch
+
 from gpytorch.likelihoods import BernoulliLikelihood, GaussianLikelihood
 
 from . import acquisition, config, factory, generators, models, strategy, utils
@@ -15,6 +17,8 @@
 from .models import GPClassificationModel
 from .strategy import SequentialStrategy, Strategy
 
+torch.set_default_dtype(torch.float64)
+
 __all__ = [
     # modules
     "acquisition",
2 changes: 1 addition & 1 deletion aepsych/config.py
@@ -180,7 +180,7 @@ def _str_to_array(self, v: str) -> np.ndarray:
         return np.array(v, dtype=float)
 
     def _str_to_tensor(self, v: str) -> torch.Tensor:
-        return torch.Tensor(self._str_to_list(v))
+        return torch.Tensor(self._str_to_list(v)).to(torch.float64)
 
     def _str_to_obj(self, v: str, fallback_type: _T = str, warn: bool = True) -> object:
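
Note that torch.Tensor(...) constructs with the current default dtype, so once the default is float64 the result is already double; the explicit .to(torch.float64) makes the conversion hold even if the default is never set. A small sketch (assuming the default dtype has not been changed):

    import torch

    t = torch.Tensor([0.1, 0.2])                       # follows the default dtype (float32 here)
    t64 = torch.Tensor([0.1, 0.2]).to(torch.float64)   # double regardless of the default
    assert t64.dtype == torch.float64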
2 changes: 0 additions & 2 deletions aepsych/models/base.py
@@ -30,8 +30,6 @@
 
 logger = getLogger()
 
-torch.set_default_dtype(torch.double)  # TODO: find a better way to prevent type errors
-
 
 class ModelProtocol(Protocol):
     @property
17 changes: 9 additions & 8 deletions aepsych/strategy.py
@@ -19,10 +19,7 @@
 from aepsych.generators.base import AEPsychGenerator
 from aepsych.generators.sobol_generator import SobolGenerator
 from aepsych.models.base import ModelProtocol
-from aepsych.utils import (
-    _process_bounds,
-    make_scaled_sobol,
-)
+from aepsych.utils import _process_bounds, make_scaled_sobol
 from aepsych.utils_logging import getLogger
 from botorch.exceptions.errors import ModelFittingError
 
@@ -170,7 +167,9 @@
 
         self.name = name
 
-    def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
+    def normalize_inputs(
+        self, x: torch.Tensor, y: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
         """converts inputs into normalized format for this strategy
         Args:
@@ -185,11 +184,11 @@ def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor
         assert (
             x.shape == self.event_shape or x.shape[1:] == self.event_shape
         ), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}"
 
+        # Handle scalar y values
+        if y.ndim == 0:
+            y = y.unsqueeze(0)
 
         if x.shape == self.event_shape:
             x = x[None, :]
@@ -309,7 +308,9 @@ def n_trials(self):
         )
         return self.min_asks
 
-    def add_data(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]):
+    def add_data(
+        self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]
+    ):
         """
         Adds new data points to the strategy, and normalizes the inputs.
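
The new scalar-y branch in normalize_inputs promotes a 0-d response to a batch of one so it can be stacked with existing data. A quick sketch of that behavior:

    import torch

    y = torch.tensor(1.0)       # 0-d scalar, e.g. a single binary response
    if y.ndim == 0:
        y = y.unsqueeze(0)      # shape (1,), concatenable with prior responses
    assert y.shape == (1,)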
5 changes: 3 additions & 2 deletions aepsych/utils.py
@@ -76,8 +76,9 @@ def _process_bounds(lb, ub, dim) -> Tuple[torch.Tensor, torch.Tensor, int]:
     if not isinstance(ub, torch.Tensor):
         ub = torch.tensor(ub)
 
-    lb = lb.float()
-    ub = ub.float()
+    lb = lb.to(torch.float64)
+    ub = ub.to(torch.float64)
+
     assert lb.shape[0] == ub.shape[0], "bounds should be of equal shape!"
 
     if dim is not None:
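
For reference, Tensor.float() is shorthand for .to(torch.float32), so the old code silently downcast the bounds; .to(torch.float64) keeps them in double precision. A sketch:

    import torch

    lb = torch.tensor([0, 1])   # int64 input bounds
    assert lb.float().dtype == torch.float32
    assert lb.to(torch.float64).dtype == torch.float64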
