Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Mypy Linter Errors in plotting.py #405

Closed
26 changes: 16 additions & 10 deletions aepsych/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,12 @@ def _plot_strat_1d(
x, y = strat.x, strat.y
assert x is not None and y is not None, "No data to plot!"

grid = strat.model.dim_grid(gridsize=gridsize)
samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
phimean = samps.mean(0)
if strat.model is not None:
grid = strat.model.dim_grid(gridsize=gridsize)
samps = norm.cdf(strat.model.sample(grid, num_samples=10000).detach())
phimean = samps.mean(0)
else:
raise RuntimeError("Cannot plot without a model!")

JasonKChow marked this conversation as resolved.
Show resolved Hide resolved
ax.plot(np.squeeze(grid), phimean)
if cred_level is not None:
Expand Down Expand Up @@ -215,14 +218,14 @@ def _plot_strat_1d(
ax.scatter(
x[y == 0, 0],
np.zeros_like(x[y == 0, 0]),
marker=3,
marker="3",
color="r",
label=no_label,
)
ax.scatter(
x[y == 1, 0],
np.zeros_like(x[y == 1, 0]),
marker=3,
marker="3",
color="b",
label=yes_label,
)
Expand Down Expand Up @@ -253,11 +256,14 @@ def _plot_strat_2d(
assert x is not None and y is not None, "No data to plot!"

# make sure the model is fit well if we've been limiting fit time
strat.model.fit(train_x=x, train_y=y, max_fit_time=None)
if strat.model is not None:
strat.model.fit(train_x=x, train_y=y, max_fit_time=None)

grid = strat.model.dim_grid(gridsize=gridsize)
fmean, _ = strat.model.predict(grid)
phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T
grid = strat.model.dim_grid(gridsize=gridsize)
fmean, _ = strat.model.predict(grid)
phimean = norm.cdf(fmean.reshape(gridsize, gridsize).detach().numpy()).T
else:
raise RuntimeError("Cannot plot without a model!")

extent = np.r_[strat.lb[0], strat.ub[0], strat.lb[1], strat.ub[1]]
colormap = ax.imshow(
Expand All @@ -277,7 +283,7 @@ def _plot_strat_2d(

# hacky relabel to be in logspace
if logx:
locs = np.arange(strat.lb[0], strat.ub[0])
locs: np.ndarray = np.arange(strat.lb[0], strat.ub[0])
ax.set_xticks(ticks=locs)
ax.set_xticklabels(2.0**locs)

Expand Down
Loading