Skip to content

Commit d4841b2

Browse files
committed
set mypy version to 1.9.0
1 parent d679904 commit d4841b2

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

.github/workflows/mypy-type-checking.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
- name: Install dependencies
2626
run: |
2727
python -m pip install --upgrade pip
28-
pip install mypy
28+
pip install mypy==1.9.0
2929
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
3030
- name: Type checking with mypy
3131
run: |

encoder.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import math
3+
from typing import Optional
34

45
import torch
56
from torch import nn
@@ -38,7 +39,7 @@ def _reset_parameters(self):
3839
xavier_uniform_(p)
3940

4041
def forward(
41-
self, input_ids: torch.Tensor, src_padding_mask: torch.BoolTensor = None
42+
self, input_ids: torch.Tensor, src_padding_mask: Optional[torch.BoolTensor] = None
4243
):
4344
"""
4445
Performs one encoder forward pass given input token ids and an optional attention mask.
@@ -72,7 +73,7 @@ def __init__(self, hidden_dim: int, ff_dim: int, num_heads: int, dropout_p: floa
7273
self.layer_norm1 = nn.LayerNorm(hidden_dim)
7374
self.layer_norm2 = nn.LayerNorm(hidden_dim)
7475

75-
def forward(self, x: torch.FloatTensor, src_padding_mask: torch.BoolTensor = None):
76+
def forward(self, x: torch.FloatTensor, src_padding_mask: Optional[torch.BoolTensor] = None):
7677
"""
7778
Performs one encoder *block* forward pass given the previous block's output and an optional attention mask.
7879

0 commit comments

Comments
 (0)