Skip to content

Commit ad36d47

Browse files
Update chunk.py (#119)
For the compiler gods
1 parent 4578842 commit ad36d47

File tree

1 file changed

+3
-3
lines changed
  • fla/ops/generalized_delta_rule/dplr

1 file changed

+3
-3
lines changed

fla/ops/generalized_delta_rule/dplr/chunk.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ def chunk_dplr_delta_rule(
300300
a: torch.Tensor,
301301
b: torch.Tensor,
302302
gk: torch.Tensor,
303-
scale: float = None,
304-
initial_state: torch.Tensor = None,
303+
scale: Optional[float] = None,
304+
initial_state: Optional[torch.Tensor] = None,
305305
output_final_state: bool = False,
306306
offsets: Optional[torch.LongTensor] = None,
307307
head_first: bool = False
@@ -372,4 +372,4 @@ def chunk_dplr_delta_rule(
372372
offsets,
373373
head_first
374374
)
375-
return o, final_state
375+
return o, final_state

0 commit comments

Comments
 (0)