diff --git a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py index 78b54edf22..737d065feb 100644 --- a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py +++ b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py @@ -26,7 +26,7 @@ import logging import math from packaging.version import parse -from typing import Any, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING import numpy as np from tqdm.auto import trange @@ -513,13 +513,15 @@ def generate( # type: ignore if hasattr(self.estimator, "nb_classes"): + y_array: np.ndarray + if y is None: # pragma: no cover logger.info("Setting labels to estimator classification predictions.") - y_array: np.ndarray = to_categorical( + y_array = to_categorical( np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes ) else: - y_array: np.ndarray = y + y_array = cast(np.ndarray, y) y = check_and_transform_label_format(labels=y_array, nb_classes=self.estimator.nb_classes)