From d7e9f2b8dbbd58ae125b414e8891a0bbd69a3cde Mon Sep 17 00:00:00 2001 From: Beat Buesser Date: Mon, 20 Jan 2025 15:06:46 +0100 Subject: [PATCH] Fix typing Signed-off-by: Beat Buesser --- .../adversarial_patch/adversarial_patch_pytorch.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py index 78b54edf22..737d065feb 100644 --- a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py +++ b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py @@ -26,7 +26,7 @@ import logging import math from packaging.version import parse -from typing import Any, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING import numpy as np from tqdm.auto import trange @@ -513,13 +513,15 @@ def generate( # type: ignore if hasattr(self.estimator, "nb_classes"): + y_array: np.ndarray + if y is None: # pragma: no cover logger.info("Setting labels to estimator classification predictions.") - y_array: np.ndarray = to_categorical( + y_array = to_categorical( np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes ) else: - y_array: np.ndarray = y + y_array = cast(np.ndarray, y) y = check_and_transform_label_format(labels=y_array, nb_classes=self.estimator.nb_classes)