From 6ada11b120486f31ff120eb9e59396861456056c Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 19 Nov 2024 06:57:55 +0100 Subject: [PATCH] mypy fixes --- doctr/models/modules/transformer/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/modules/transformer/pytorch.py b/doctr/models/modules/transformer/pytorch.py index 66dca7cb85..312eba9a26 100644 --- a/doctr/models/modules/transformer/pytorch.py +++ b/doctr/models/modules/transformer/pytorch.py @@ -50,7 +50,7 @@ def scaled_dot_product_attention( if mask is not None: # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined] - p_attn = torch.softmax(scores, dim=-1) + p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload] return torch.matmul(p_attn, value), p_attn