From 7b4ad9b53c1c8fbf1abe1fb9934582c97dde8888 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 19 Nov 2024 06:52:30 +0100 Subject: [PATCH] mypy fixes --- doctr/models/modules/transformer/pytorch.py | 4 ++-- doctr/transforms/functional/pytorch.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doctr/models/modules/transformer/pytorch.py b/doctr/models/modules/transformer/pytorch.py index 926066204a..66dca7cb85 100644 --- a/doctr/models/modules/transformer/pytorch.py +++ b/doctr/models/modules/transformer/pytorch.py @@ -46,10 +46,10 @@ def scaled_dot_product_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Scaled Dot-Product Attention""" - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) # type: ignore[assignment] + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) if mask is not None: # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition - scores = scores.masked_fill(mask == 0, float("-inf")) + scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined] p_attn = torch.softmax(scores, dim=-1) return torch.matmul(p_attn, value), p_attn diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py index 2b12f8f6df..3c65d76b7d 100644 --- a/doctr/transforms/functional/pytorch.py +++ b/doctr/transforms/functional/pytorch.py @@ -30,12 +30,12 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor: out = F.rgb_to_grayscale(img, num_output_channels=3) # Random RGB shift shift_shape = [img.shape[0], 3, 1, 1] if img.ndim == 4 else [3, 1, 1] - rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape) # type: ignore[assignment] + rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape) # Inverse the color if out.dtype == torch.uint8: - out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) + out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) # type: ignore[attr-defined] else: - out = out * rgb_shift.to(dtype=out.dtype) + out = out * rgb_shift.to(dtype=out.dtype) # type: ignore[attr-defined] # Inverse the color out = 255 - out if out.dtype == torch.uint8 else 1 - out return out