Skip to content

Commit

Permalink
fix comments in plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
yalsaffar committed Oct 22, 2024
1 parent ffbef0a commit 344336b
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions aepsych/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,20 @@ def plot_strat_3d(
save_path: Optional[str] = None,
show: bool = True,
) -> None:
"""Creates a plot of a 2d slice of a 3D strategy, showing the estimated model or probability response and contours
Args:
strat (Strategy): Strategy object to be plotted. Must have a dimensionality of 3.
parnames (str list): list of the parameter names
outcome_label (str): The label of the outcome variable
slice_dim (int): dimension to slice on
dim_vals (list of floats or int): values to take slices; OR number of values to take even slices from
contour_levels (iterable of floats or bool, optional): List contour values to plot. Default: None. If true, all integer levels.
probability_space (bool): Whether to plot probability. Default: False
gridsize (int): The number of points to sample each dimension at. Default: 30.
extent_multiplier (list, optional): multipliers for each of the dimensions when plotting. Default:None
save_path (str, optional): File name to save the plot to. Default: None.
show (bool): Whether the plot should be shown in an interactive window. Default: True.
"""
assert strat.model is not None, "Cannot plot without a model!"

contour_levels_list: List[float] = []
Expand All @@ -353,7 +367,7 @@ def plot_strat_3d(

if parnames is None:
parnames = ["x1", "x2", "x3"]

# Get global min/max for all slices
if probability_space:
vmax = 1
vmin = 0
Expand All @@ -370,15 +384,18 @@ def plot_strat_3d(

if not isinstance(contour_levels_list, Sized):
raise TypeError("contour_levels_list must be Sized (e.g., a list or an array).")


# slice_vals is either a list of values or an integer number of values to slice on
if isinstance(slice_vals, int):
slices = np.linspace(strat.lb[slice_dim], strat.ub[slice_dim], slice_vals)
slices = np.around(slices, 4)
elif not isinstance(slice_vals, list):
raise TypeError("slice_vals must be either an integer or a list of values")
else:
slices = np.array(slice_vals)


# make mypy happy, note that this can't be more specific
# because of https://github.com/numpy/numpy/issues/24738
axs: np.ndarray
_, axs = plt.subplots(1, len(slices), constrained_layout=True, figsize=(20, 3)) # type: ignore

Expand All @@ -405,8 +422,7 @@ def plot_strat_3d(
cbar.ax.set_ylabel(f"Probability of {outcome_label}")
else:
cbar.ax.set_ylabel(outcome_label)

for clevel in contour_levels_list:
for clevel in contour_levels_list: # type: ignore
cbar.ax.axhline(y=clevel, c="w")

if save_path is not None:
Expand Down

0 comments on commit 344336b

Please sign in to comment.