Optimize CDF Calculation and Convert NumPy Arrays to Tensors in Benchmark #399

Status: Closed · wants to merge 17 commits
50 changes: 29 additions & 21 deletions aepsych/benchmark/problem.py
@@ -51,22 +51,25 @@ def metadata(self) -> Dict[str, Any]:
Benchmark's output dataframe, with its associated value stored in each row."""
return {"name": self.name}

def p(self, x: np.ndarray) -> np.ndarray:
"""Evaluate response probability from test function.
def p(self, x: torch.Tensor) -> torch.Tensor:
"""
Evaluate response probability from test function.

Args:
x (np.ndarray): Points at which to evaluate.
x (torch.Tensor): Points at which to evaluate.

Returns:
np.ndarray: Response probability at queries points.
torch.Tensor: Response probability at queried points.
"""
return norm.cdf(self.f(x))
normal_dist = torch.distributions.Normal(0, 1) # Standard normal distribution
return normal_dist.cdf(self.f(x)) # Use PyTorch's CDF equivalent


def sample_y(self, x: np.ndarray) -> np.ndarray:
def sample_y(self, x: torch.Tensor) -> np.ndarray:
"""Sample a response from test function.

Args:
x (np.ndarray): Points at which to sample.
x (torch.Tensor): Points at which to sample.

Returns:
np.ndarray: A single (bernoulli) sample at points.
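The change above swaps scipy's `norm.cdf` for `torch.distributions.Normal(0, 1).cdf` so `p` stays on tensors end to end. A minimal sanity check of that equivalence (a standalone sketch, not code from this PR):

```python
import numpy as np
import torch
from scipy.stats import norm

# torch's standard-normal CDF should agree with scipy's norm.cdf
# to within float32 precision.
x = torch.linspace(-3.0, 3.0, 7)
torch_cdf = torch.distributions.Normal(0, 1).cdf(x)
scipy_cdf = norm.cdf(x.numpy())
assert np.allclose(torch_cdf.numpy(), scipy_cdf, atol=1e-6)
```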
@@ -211,7 +214,7 @@ class LSEProblem(Problem):
def __init__(self, thresholds: Union[float, List]):
super().__init__()
thresholds = [thresholds] if isinstance(thresholds, float) else thresholds
self.thresholds = np.array(thresholds)
self.thresholds = torch.tensor(thresholds)

@property
def metadata(self) -> Dict[str, Any]:
@@ -225,27 +228,32 @@ def metadata(self) -> Dict[str, Any]:
)
return md

def f_threshold(self, model=None):

def f_threshold(self, model=None) -> torch.Tensor:
try:
inverse_torch = model.likelihood.objective.inverse

def inverse_link(x):
return inverse_torch(torch.tensor(x)).numpy()
return inverse_torch(torch.tensor(x))

except AttributeError:
inverse_link = norm.ppf
return inverse_link(self.thresholds).astype(np.float32)
def inverse_link(x):
normal_dist = torch.distributions.Normal(0, 1)
return normal_dist.icdf(torch.tensor(x)) # Same as norm.ppf but using Torch

return inverse_link(self.thresholds).float() # Return as float32 tensor
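In the fallback branch, `Normal(0, 1).icdf` stands in for `norm.ppf`, as the inline comment notes. A quick round-trip check of that assumption (illustrative values, not PR code):

```python
import torch

thresholds = torch.tensor([0.55, 0.75, 0.95])
normal = torch.distributions.Normal(0, 1)

# icdf (the quantile function) inverts cdf, matching scipy's norm.ppf.
f_thresh = normal.icdf(thresholds)
recovered = normal.cdf(f_thresh)
assert torch.allclose(recovered, thresholds, atol=1e-6)
```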




@cached_property
def true_below_threshold(self) -> np.ndarray:
def true_below_threshold(self) -> torch.Tensor:
"""
Evaluate whether the true function is below threshold over the eval grid
(used for proper scoring and the threshold misclassification metric).
"""
return (
self.p(self.eval_grid).reshape(1, -1) <= self.thresholds.reshape(-1, 1)
).astype(float)
).to(torch.float32)
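The reshape-and-compare above relies on broadcasting: a `(1, n)` row of probabilities against an `(m, 1)` column of thresholds yields an `(m, n)` indicator matrix in one shot. A toy illustration (hypothetical numbers):

```python
import torch

p = torch.tensor([0.2, 0.6, 0.9]).reshape(1, -1)      # (1, 3) probabilities
thresholds = torch.tensor([0.5, 0.8]).reshape(-1, 1)  # (2, 1) thresholds

below = (p <= thresholds).to(torch.float32)           # broadcasts to (2, 3)
print(below)
# tensor([[1., 0., 0.],
#         [1., 1., 0.]])
```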

def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, float]:
"""Evaluate the model with respect to this problem.
@@ -284,16 +292,16 @@ def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, float]:
and p_l.shape[0] == len(self.thresholds)
)

# Predict p(below threshold) at test points
brier_p_below_thresh = np.mean(2 * np.square(true_p_l - p_l), axis=1)
# Now, perform the Brier score calculation and classification error in PyTorch
brier_p_below_thresh = torch.mean(2 * torch.square(true_p_l - p_l), dim=1)
# Classification error
misclass_on_thresh = np.mean(
p_l * (1 - true_p_l) + (1 - p_l) * true_p_l, axis=1
misclass_on_thresh = torch.mean(
p_l * (1 - true_p_l) + (1 - p_l) * true_p_l, dim=1
)

for i_threshold, threshold in enumerate(self.thresholds):
metrics[f"brier_p_below_{threshold}"] = brier_p_below_thresh[i_threshold]
metrics[f"misclass_on_thresh_{threshold}"] = misclass_on_thresh[i_threshold]
metrics[f"brier_p_below_{threshold}"] = brier_p_below_thresh.detach().cpu().numpy()[i_threshold]
metrics[f"misclass_on_thresh_{threshold}"] = misclass_on_thresh.detach().cpu().numpy()[i_threshold]
return metrics
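As a worked check of the vectorized metrics (toy numbers, not from the benchmark): the Brier term averages `2 * (true - predicted)^2` over the grid axis, giving one score per threshold, and the misclassification term is the expected 0/1 error when "below threshold" is answered with probability `p_l`:

```python
import torch

true_p_l = torch.tensor([[1.0, 0.0], [1.0, 1.0]])  # (n_thresholds, n_grid)
p_l = torch.tensor([[0.9, 0.2], [0.7, 0.6]])

brier = torch.mean(2 * torch.square(true_p_l - p_l), dim=1)
misclass = torch.mean(p_l * (1 - true_p_l) + (1 - p_l) * true_p_l, dim=1)
print(brier)     # tensor([0.0500, 0.2500])
print(misclass)  # tensor([0.1500, 0.3500])
```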


9 changes: 6 additions & 3 deletions aepsych/models/base.py
@@ -105,7 +105,7 @@ def update(
) -> None:
pass

def p_below_threshold(self, x, f_thresh) -> np.ndarray:
def p_below_threshold(self, x, f_thresh) -> torch.Tensor:
pass


@@ -378,9 +378,12 @@ def _fit_mll(
)
return res

def p_below_threshold(self, x, f_thresh) -> np.ndarray:
def p_below_threshold(self, x, f_thresh) -> torch.Tensor: # Return a tensor instead of NumPy array
f, var = self.predict(x)
f_thresh = f_thresh.reshape(-1, 1)
f = f.reshape(1, -1)
var = var.reshape(1, -1)
return norm.cdf((f_thresh - f.detach().numpy()) / var.sqrt().detach().numpy())

# Perform all operations in PyTorch (no .detach().numpy())
z = (f_thresh - f) / var.sqrt()
return torch.distributions.Normal(0, 1).cdf(z) # Use PyTorch's CDF equivalent
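The rewritten method computes Φ((f_thresh - f) / sqrt(var)), the posterior probability that the latent function lies below each threshold, and the reshapes make the result one row per threshold. A shape sketch with made-up posterior values (not PR code):

```python
import torch

f = torch.tensor([0.0, 1.0, 2.0]).reshape(1, -1)    # posterior means, (1, 3)
var = torch.tensor([1.0, 1.0, 4.0]).reshape(1, -1)  # posterior variances, (1, 3)
f_thresh = torch.tensor([0.0, 1.0]).reshape(-1, 1)  # thresholds, (2, 1)

z = (f_thresh - f) / var.sqrt()                     # broadcasts to (2, 3)
p = torch.distributions.Normal(0, 1).cdf(z)
print(p[0])  # tensor([0.5000, 0.1587, 0.1587])
```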
41 changes: 22 additions & 19 deletions tests/test_benchmark.py
@@ -70,34 +70,35 @@ def setUp(self):
np.random.seed(1)

self.n_thresholds = 5
self.thresholds = np.linspace(0.55, 0.95, self.n_thresholds)
self.thresholds = torch.linspace(0.55, 0.95, self.n_thresholds)
self.test_problem = example_problems.DiscrimLowDim(thresholds=self.thresholds)
self.model = GPClassificationModel(
lb=self.test_problem.lb, ub=self.test_problem.ub
)

def unvectorized_p_below_threshold(self, x, f_thresh) -> np.ndarray:
def unvectorized_p_below_threshold(self, x, f_thresh) -> torch.Tensor:
"""this is the original p_below_threshold method in the AEPsychMixin that calculates model prediction
of the probability of the stimulus being below a threshold
for one single threshold"""
f, var = self.model.predict(x)
return norm.cdf((f_thresh - f.detach().numpy()) / var.sqrt().detach().numpy())

# Perform all operations in PyTorch (no .detach().numpy())
z = (f_thresh - f) / var.sqrt()
return torch.distributions.Normal(0, 1).cdf(z) # Use PyTorch's CDF equivalent

def unvectorized_true_below_threshold(self, threshold):
"""the original true_below_threshold method in the LSEProblem class"""
return (self.test_problem.p(self.test_problem.eval_grid) <= threshold).astype(
float
)
return (self.test_problem.p(self.test_problem.eval_grid) <= threshold).to(torch.float32)

def test_vectorized_score_calculation(self):
f_thresholds = self.test_problem.f_threshold(self.model)
p_l = self.model.p_below_threshold(self.test_problem.eval_grid, f_thresholds)
true_p_l = self.test_problem.true_below_threshold
# Predict p(below threshold) at test points
brier_p_below_thresh = np.mean(2 * np.square(true_p_l - p_l), axis=1)
# Now, perform the Brier score calculation and classification error in PyTorch
brier_p_below_thresh = torch.mean(2 * torch.square(true_p_l - p_l), dim=1)
# Classification error
misclass_on_thresh = np.mean(
p_l * (1 - true_p_l) + (1 - p_l) * true_p_l, axis=1
misclass_on_thresh = torch.mean(
p_l * (1 - true_p_l) + (1 - p_l) * true_p_l, dim=1
)
assert (
p_l.ndim == 2
@@ -106,31 +107,33 @@ def test_vectorized_score_calculation(self):
)

for i_threshold, single_threshold in enumerate(self.thresholds):
single_f_threshold = norm.ppf(single_threshold)
assert np.isclose(single_f_threshold, f_thresholds[i_threshold])
normal_dist = torch.distributions.Normal(0, 1)
single_f_threshold = normal_dist.icdf(single_threshold).float() # equivalent to norm.ppf

assert torch.isclose(single_f_threshold, f_thresholds[i_threshold])

unvectorized_p_l = self.unvectorized_p_below_threshold(
self.test_problem.eval_grid, single_f_threshold
)
assert np.all(np.isclose(unvectorized_p_l, p_l[i_threshold]))
assert torch.all(torch.isclose(unvectorized_p_l, p_l[i_threshold]))

unvectorized_true_p_l = self.unvectorized_true_below_threshold(
single_threshold
)
assert np.all(np.isclose(unvectorized_true_p_l, true_p_l[i_threshold]))
assert torch.all(torch.isclose(unvectorized_true_p_l, true_p_l[i_threshold]))

unvectorized_brier_score = np.mean(
2 * np.square(unvectorized_true_p_l - unvectorized_p_l)
unvectorized_brier_score = torch.mean(
2 * torch.square(unvectorized_true_p_l - unvectorized_p_l)
)
assert np.isclose(
assert torch.isclose(
unvectorized_brier_score, brier_p_below_thresh[i_threshold]
)

unvectorized_misclass_err = np.mean(
unvectorized_misclass_err = torch.mean(
unvectorized_p_l * (1 - unvectorized_true_p_l)
+ (1 - unvectorized_p_l) * unvectorized_true_p_l
)
assert np.isclose(
assert torch.isclose(
unvectorized_misclass_err, misclass_on_thresh[i_threshold]
)
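One pitfall the explicit `.float()` casts in this test guard against, to the best of my understanding (worth verifying on your PyTorch version): `torch.isclose` does not type-promote its operands, so comparing a `float64` tensor against a `float32` tensor raises a `RuntimeError` instead of returning a result:

```python
import torch

a = torch.tensor(0.5, dtype=torch.float64)
b = torch.tensor(0.5, dtype=torch.float32)

print(torch.isclose(a, b.double()))  # tensor(True) once the dtypes agree
# torch.isclose(a, b)                # dtype mismatch: raises RuntimeError
```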
