diff --git a/doctr/models/modules/transformer/pytorch.py b/doctr/models/modules/transformer/pytorch.py
index 66dca7cb85..312eba9a26 100644
--- a/doctr/models/modules/transformer/pytorch.py
+++ b/doctr/models/modules/transformer/pytorch.py
@@ -50,7 +50,7 @@ def scaled_dot_product_attention(
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
         scores = scores.masked_fill(mask == 0, float("-inf"))  # type: ignore[attr-defined]
-    p_attn = torch.softmax(scores, dim=-1)
+    p_attn = torch.softmax(scores, dim=-1)  # type: ignore[call-overload]
     return torch.matmul(p_attn, value), p_attn