From 683ad8ac63a4b1d7b6e8e13887f59fc28cc4629f Mon Sep 17 00:00:00 2001 From: Michael Shvartsman Date: Fri, 19 Jul 2024 00:48:51 -0700 Subject: [PATCH] Fix mypy (#363) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/363 I think mypy is broken in main right now. This should fix. Reviewed By: crasanders Differential Revision: D59802966 fbshipit-source-id: fa29f14b3e7f2de7c75b87b588632688ede4a96d --- aepsych/plotting.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/aepsych/plotting.py b/aepsych/plotting.py index 30a51fbb0..4a054d0c7 100644 --- a/aepsych/plotting.py +++ b/aepsych/plotting.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. import warnings -from typing import Callable, Iterable, List, Optional, Union +from typing import Any, Callable, Iterable, List, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -375,8 +375,14 @@ def plot_strat_3d( raise TypeError("slice_vals must be either an integer or a list of values") else: slices = np.array(slice_vals) - - _, axs = plt.subplots(1, len(slices), constrained_layout=True, figsize=(20, 3)) + + # make mypy happy, note that this can't be more specific + # because of https://github.com/numpy/numpy/issues/24738 + axs: np.ndarray[Any, Any] + _, axs = plt.subplots(1, len(slices), constrained_layout=True, figsize=(20, 3)) # type: ignore + + assert len(slices) > 1, "Must have at least 2 slices" + for _i, dim_val in enumerate(slices): img = plot_slice(