Refactor to Use Vectorized Operations and Transition from NumPy to PyTorch Tensors (facebookresearch#406)

Summary:
This PR partially addresses facebookresearch#365. The changes, illustrated by the sketch after this list, are as follows:

- **Grid Generation and Meshgrid Handling (`dim_grid` and `get_lse_interval`)**:
   Transitioned from `np.mgrid` to `torch.meshgrid` and `torch.linspace`, which simplifies grid setup, keeps the computation entirely in PyTorch, and removes NumPy-to-tensor conversion steps.

- **Interpolation (`interpolate_monotonic`)**:
   Switched from `np.searchsorted` to `torch.searchsorted`, with `torch.where` handling the out-of-range boundary cases, enabling efficient single-pass processing of batched targets while preserving the original semantics.

- **Probability and Quantile Calculations (`get_lse_interval`)**:
   Updated to use `torch.distributions.Normal`, `torch.median`, and `torch.quantile`.

- **Generalized Vectorization (`get_jnd_1d` and `get_jnd_multid`)**:
   These functions are now fully vectorized, replacing `np.apply_along_axis` and element-wise Python loops with batched tensor operations.
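
To make the translations concrete, here is a minimal, self-contained sketch of the patterns above (illustrative names and values only; this is not code from the diff):

```python
import torch

# 1) Grid generation: torch.linspace + torch.meshgrid replace np.mgrid.
lb, ub, gridsize = torch.tensor([0.0, -1.0]), torch.tensor([1.0, 1.0]), 5
axes = [torch.linspace(lb[i].item(), ub[i].item(), gridsize) for i in range(2)]
grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, 2)

# 2) Interpolation: torch.searchsorted brackets a whole batch of targets at
# once, where the NumPy version handled one scalar target per call.
x = torch.linspace(0, 1, 50)
y = torch.sigmoid(8 * (x - 0.5))   # monotonic curve
z = torch.tensor([0.2, 0.5, 0.8])  # batched targets
idx = torch.clamp(torch.searchsorted(y, z), 1, len(y) - 1)
x0, x1, y0, y1 = x[idx - 1], x[idx], y[idx - 1], y[idx]
x_star = x0 + (x1 - x0) * (z - y0) / (y1 - y0)

# 3) Probability and quantiles: torch.distributions.Normal(0, 1).cdf and
# torch.quantile replace scipy.stats.norm.cdf and np.quantile.
samps = torch.randn(500, gridsize)
probs = torch.distributions.Normal(0, 1).cdf(samps)
median = torch.quantile(probs, 0.5, dim=0)

print(grid.shape, x_star, median.shape)
```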

Pull Request resolved: facebookresearch#406

Reviewed By: crasanders

Differential Revision: D64563850

Pulled By: JasonKChow

fbshipit-source-id: 867b86b6822eb3380a2f1e0535849b3ea44a5a05
yalsaffar authored and facebook-github-bot committed Oct 18, 2024
1 parent bbda7fe commit a3d3a33
Showing 4 changed files with 13,700 additions and 126 deletions.
10 changes: 4 additions & 6 deletions aepsych/models/base.py
```diff
@@ -270,12 +270,10 @@ def get_jnd(
             return torch.tensor(1 / np.gradient(fmean, coords, axis=intensity_dim))
         elif method == "step":
             return torch.clip(
-                torch.tensor(
-                    get_jnd_multid(
-                        fmean.detach().numpy(),
-                        coords.detach().numpy(),
-                        mono_dim=intensity_dim,
-                    )
+                get_jnd_multid(
+                    fmean,
+                    coords,
+                    mono_dim=intensity_dim,
                 ),
                 0,
                 np.inf,
```
14 changes: 7 additions & 7 deletions aepsych/plotting.py
```diff
@@ -179,8 +179,8 @@ def _plot_strat_1d(
 
     threshold_samps = [
         interpolate_monotonic(
-            grid.squeeze().numpy(), s, target_level, strat.lb[0], strat.ub[0]
-        )
+            grid, s, target_level, strat.lb[0], strat.ub[0]
+        ).cpu().numpy()
         for s in samps
     ]
     thresh_med = np.mean(threshold_samps)
@@ -201,12 +201,12 @@
         ax.plot(grid, true_f.squeeze(), label="True function")
         if target_level is not None:
             true_thresh = interpolate_monotonic(
-                grid.squeeze().numpy(),
+                grid,
                 true_f.squeeze(),
                 target_level,
                 strat.lb[0],
                 strat.ub[0],
-            )
+            ).cpu().numpy()
 
             ax.plot(
                 true_thresh,
@@ -305,18 +305,18 @@ def _plot_strat_2d(
     )
     ax.plot(
         context_grid,
-        thresh_75,
+        thresh_75.cpu().numpy(),
         label=f"Est. {target_level*100:.0f}% threshold \n(with {cred_level*100:.0f}% posterior \nmass shaded)",
     )
     ax.fill_between(
-        context_grid, lower, upper, alpha=0.3, hatch="///", edgecolor="gray"
+        context_grid, lower.cpu().numpy(), upper.cpu().numpy(), alpha=0.3, hatch="///", edgecolor="gray"
    )
 
     if true_testfun is not None:
         true_f = true_testfun(grid).reshape(gridsize, gridsize)
         true_thresh = get_lse_contour(
             true_f, mono_grid, level=target_level, lb=strat.lb[-1], ub=strat.ub[-1]
-        )
+        ).cpu().numpy()
         ax.plot(context_grid, true_thresh, label="Ground truth threshold")
 
     ax.set_xlabel(xlabel)
```
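
The plotting changes all follow one rule: computation stays in torch, and tensors are converted with `.cpu().numpy()` only at the matplotlib boundary. A toy sketch of the pattern (fabricated data):

```python
import torch
import matplotlib.pyplot as plt

# Keep computation in torch; convert only when handing data to matplotlib.
xs = torch.linspace(0, 1, 50)
ys = torch.sin(2 * torch.pi * xs)
plt.plot(xs.cpu().numpy(), ys.cpu().numpy())
plt.show()
```
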
143 changes: 84 additions & 59 deletions aepsych/utils.py
```diff
@@ -7,13 +7,14 @@
 
 from collections.abc import Iterable
 from configparser import NoOptionError
-from typing import Dict, List, Mapping, Optional, Tuple
+from typing import Dict, List, Mapping, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from scipy.stats import norm
 from torch.quasirandom import SobolEngine
 
+from aepsych.config import Config
+
 
 def make_scaled_sobol(lb, ub, size, seed=None):
     lb, ub, ndim = _process_bounds(lb, ub, None)
@@ -59,11 +60,11 @@ def dim_grid(
 
     for i in range(dim):
         if i in slice_dims.keys():
-            mesh_vals.append(slice(slice_dims[i] - 1e-10, slice_dims[i] + 1e-10, 1))
+            mesh_vals.append(torch.tensor([slice_dims[i] - 1e-10, slice_dims[i] + 1e-10]))
         else:
-            mesh_vals.append(slice(lower[i].item(), upper[i].item(), gridsize * 1j))
+            mesh_vals.append(torch.linspace(lower[i].item(), upper[i].item(), gridsize))
 
-    return torch.Tensor(np.mgrid[mesh_vals].reshape(dim, -1).T)
+    return torch.stack(torch.meshgrid(*mesh_vals, indexing='ij'), dim=-1).reshape(-1, dim)
 
 
 def _process_bounds(lb, ub, dim) -> Tuple[torch.Tensor, torch.Tensor, int]:
```
```diff
@@ -98,95 +99,119 @@ def _process_bounds(lb, ub, dim) -> Tuple[torch.Tensor, torch.Tensor, int]:
     return lb, ub, dim
 
 
-def interpolate_monotonic(x, y, z, min_x=-np.inf, max_x=np.inf):
+def interpolate_monotonic(x: torch.Tensor, y: torch.Tensor, z: Union[torch.Tensor, float], min_x: Union[torch.Tensor, float] = -float('inf'), max_x: Union[torch.Tensor, float] = float('inf')) -> torch.Tensor:
     # Ben Letham's 1d interpolation code, assuming monotonicity.
     # basic idea is find the nearest two points to the LSE and
     # linearly interpolate between them (I think this is bisection
     # root-finding)
-    idx = np.searchsorted(y, z)
-    if idx == len(y):
-        return float(max_x)
-    elif idx == 0:
-        return float(min_x)
+    idx = torch.searchsorted(y, z, right=False)
+
+    # Handle edge cases where idx is 0 or at the end
+    idx = torch.clamp(idx, 1, len(y) - 1)
 
     x0 = x[idx - 1]
     x1 = x[idx]
     y0 = y[idx - 1]
     y1 = y[idx]
 
     x_star = x0 + (x1 - x0) * (z - y0) / (y1 - y0)
+    # Apply min and max boundaries
+    x_star = torch.where(z < y[0], min_x, x_star)
+    x_star = torch.where(z > y[-1], max_x, x_star)
 
     return x_star
```
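
The rewrite changes `interpolate_monotonic` from scalar-in/scalar-out to tensor-in/tensor-out. A minimal usage sketch (toy curve and targets, assuming a PyTorch version whose `torch.where` accepts scalar operands):

```python
import torch
from aepsych.utils import interpolate_monotonic

# Monotonic curve sampled on a grid (toy example).
x = torch.linspace(0, 1, 101)
y = torch.sigmoid(6 * (x - 0.5))

# A whole batch of target levels in one call; the old NumPy version
# processed one scalar target at a time with Python-level branching.
targets = torch.tensor([0.25, 0.50, 0.75])
x_star = interpolate_monotonic(x, y, targets)
print(x_star)  # ~tensor([0.317, 0.500, 0.683])
```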


```diff
 def get_lse_interval(
     model,
-    mono_grid,
-    target_level,
-    cred_level=None,
-    mono_dim=-1,
-    n_samps=500,
-    lb=-np.inf,
-    ub=np.inf,
-    gridsize=30,
+    mono_grid: Union[torch.Tensor, np.ndarray],
+    target_level: float,
+    cred_level: Optional[float] = None,
+    mono_dim: int = -1,
+    n_samps: int = 500,
+    lb: float = -float('inf'),
+    ub: float = float('inf'),
+    gridsize: int = 30,
     **kwargs,
-):
-    xgrid = torch.Tensor(
-        np.mgrid[
-            [
-                slice(model.lb[i].item(), model.ub[i].item(), gridsize * 1j)
-                for i in range(model.dim)
-            ]
-        ]
-        .reshape(model.dim, -1)
-        .T
-    )
+) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
+    # Create a meshgrid using torch.linspace
+    xgrid = torch.stack(
+        torch.meshgrid(
+            [torch.linspace(model.lb[i].item(), model.ub[i].item(), gridsize) for i in range(model.dim)]
+        ),
+        dim=-1
+    ).reshape(-1, model.dim)
 
     samps = model.sample(xgrid, num_samples=n_samps, **kwargs)
-    samps = [s.reshape((gridsize,) * model.dim) for s in samps.detach().numpy()]
-    contours = np.stack(
+    samps = [s.reshape((gridsize,) * model.dim) for s in samps]
+
+    # Define the normal distribution for the CDF
+    normal_dist = torch.distributions.Normal(0, 1)
+
+    # Calculate contours using torch.stack and the torch CDF for each sample
+    contours = torch.stack(
         [
-            get_lse_contour(norm.cdf(s), mono_grid, target_level, mono_dim, lb, ub)
+            get_lse_contour(normal_dist.cdf(s), mono_grid, target_level, mono_dim, lb, ub)
             for s in samps
         ]
     )
 
     if cred_level is None:
-        return np.mean(contours, 0.5, axis=0)
+        return torch.median(contours, dim=0).values
     else:
         alpha = 1 - cred_level
         qlower = alpha / 2
         qupper = 1 - alpha / 2
 
-        upper = np.quantile(contours, qupper, axis=0)
-        lower = np.quantile(contours, qlower, axis=0)
-        median = np.quantile(contours, 0.5, axis=0)
+        lower = torch.quantile(contours, qlower, dim=0)
+        upper = torch.quantile(contours, qupper, dim=0)
+        median = torch.quantile(contours, 0.5, dim=0)
 
         return median, lower, upper
```
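
The credible-band branch is a direct `torch.quantile` translation of the old `np.quantile` calls. A toy sketch with fabricated contour samples:

```python
import torch

# Stand-in for the (n_samps x gridsize) stack of threshold contours that
# get_lse_interval builds; values fabricated for illustration.
contours = torch.randn(500, 30).cumsum(dim=-1)

cred_level = 0.95
alpha = 1 - cred_level
lower = torch.quantile(contours, alpha / 2, dim=0)      # pointwise lower bound
upper = torch.quantile(contours, 1 - alpha / 2, dim=0)  # pointwise upper bound
median = torch.quantile(contours, 0.5, dim=0)
print(median.shape, lower.shape, upper.shape)  # each torch.Size([30])
```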


```diff
-def get_lse_contour(post_mean, mono_grid, level, mono_dim=-1, lb=-np.inf, ub=np.inf):
-    return np.apply_along_axis(
-        lambda p: interpolate_monotonic(mono_grid, p, level, lb, ub),
-        mono_dim,
-        post_mean,
-    )
-
-
-def get_jnd_1d(post_mean, mono_grid, df=1, mono_dim=-1, lb=-np.inf, ub=np.inf):
-    interpolate_to = post_mean + df
-    return (
-        np.array(
-            [interpolate_monotonic(mono_grid, post_mean, ito) for ito in interpolate_to]
-        )
-        - mono_grid
-    )
-
-
-def get_jnd_multid(post_mean, mono_grid, df=1, mono_dim=-1, lb=-np.inf, ub=np.inf):
-    return np.apply_along_axis(
-        lambda p: get_jnd_1d(p, mono_grid, df=df, mono_dim=mono_dim, lb=lb, ub=ub),
-        mono_dim,
-        post_mean,
-    )
+def get_lse_contour(post_mean: torch.Tensor, mono_grid: Union[torch.Tensor, np.ndarray], level: float, mono_dim: int = -1, lb: Union[torch.Tensor, float] = -float('inf'), ub: Union[torch.Tensor, float] = float('inf')) -> torch.Tensor:
+    post_mean = torch.tensor(post_mean, dtype=torch.float32)
+    mono_grid = torch.tensor(mono_grid, dtype=torch.float32)
+
+    # Move mono_dim to the last dimension if it isn't already
+    if mono_dim != -1:
+        post_mean = post_mean.transpose(mono_dim, -1)
+
+    # Apply interpolation across all rows at once
+    result = interpolate_monotonic(mono_grid, post_mean, level, lb, ub)
+
+    # Transpose back if necessary
+    if mono_dim != -1:
+        result = result.transpose(-1, mono_dim)
+
+    return result
+
+
+def get_jnd_1d(post_mean: torch.Tensor, mono_grid: torch.Tensor, df: int = 1, mono_dim: int = -1, lb: Union[torch.Tensor, float] = -float('inf'), ub: Union[torch.Tensor, float] = float('inf')) -> torch.Tensor:
+    # Calculate interpolate_to in a vectorized way
+    interpolate_to = post_mean + df
+
+    # Apply interpolation to the entire tensor
+    interpolated_values = interpolate_monotonic(mono_grid, post_mean, interpolate_to, lb, ub)
+
+    return interpolated_values - mono_grid
+
+
+def get_jnd_multid(post_mean: torch.Tensor, mono_grid: torch.Tensor, df: int = 1, mono_dim: int = -1, lb: Union[torch.Tensor, float] = -float('inf'), ub: Union[torch.Tensor, float] = float('inf')) -> torch.Tensor:
+    # Move mono_dim to the last dimension if it isn't already
+    if mono_dim != -1:
+        post_mean = post_mean.transpose(mono_dim, -1)
+
+    # Apply get_jnd_1d in a vectorized way
+    result = get_jnd_1d(post_mean, mono_grid, df=df, mono_dim=-1, lb=lb, ub=ub)
+
+    # Transpose back if necessary
+    if mono_dim != -1:
+        result = result.transpose(-1, mono_dim)
+
+    return result
 
 
 def _get_ax_parameters(config):
```
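
All three rewritten functions rely on the same transpose-then-batch idiom in place of `np.apply_along_axis`. A standalone demonstration of the idiom (fabricated data; `cumsum` stands in for the real per-row computation):

```python
import numpy as np
import torch

post_mean = torch.rand(4, 5, 6)
mono_dim = 1

# NumPy: apply_along_axis walks dim `mono_dim` one row at a time in Python.
looped = np.apply_along_axis(np.cumsum, mono_dim, post_mean.numpy())

# PyTorch: move the target dim last, run one batched op, move it back.
batched = post_mean.transpose(mono_dim, -1).cumsum(dim=-1).transpose(-1, mono_dim)

assert np.allclose(looped, batched.numpy())
```
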
13,659 changes: 13,605 additions & 54 deletions pubs/owenetal/code/test_functions.ipynb

Large diffs are not rendered by default.
