Skip to content

Commit ed3d847

Browse files
committed
Support refinement iterations on ONNX (Fixes baudm#12 baudm#66)
1 parent 3dbf9c7 commit ed3d847

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ print('Decoded label = {}'.format(label[0]))
9191

9292
## Frequently Asked Questions
9393
- How do I train on a new language? See Issues [#5](https://github.com/baudm/parseq/issues/5) and [#9](https://github.com/baudm/parseq/issues/9).
94-
- Can you export to TorchScript or ONNX? Yes, with a [caveat (Issue #12)](https://github.com/baudm/parseq/issues/12#issuecomment-1267842315).
94+
- Can you export to TorchScript or ONNX? Yes, see Issue [#12](https://github.com/baudm/parseq/issues/12#issuecomment-1267842315).
9595
- How do I test on my own dataset? See Issue [#27](https://github.com/baudm/parseq/issues/27).
9696
- How do I finetune and/or create a custom dataset? See Issue [#7](https://github.com/baudm/parseq/issues/7).
9797
- What is `val_NED`? See Issue [#10](https://github.com/baudm/parseq/issues/10).

strhub/models/parseq/system.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
144144
for i in range(self.refine_iters):
145145
# Prior context is the previous output.
146146
tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
147-
tgt_padding_mask = ((tgt_in == self.eos_id).cumsum(-1) > 0) # mask tokens beyond the first EOS token.
147+
tgt_padding_mask = ((tgt_in == self.eos_id).int().cumsum(-1) > 0) # mask tokens beyond the first EOS token.
148148
tgt_out = self.decode(tgt_in, memory, tgt_mask, tgt_padding_mask,
149149
tgt_query=pos_queries, tgt_query_mask=query_mask[:, :tgt_in.shape[1]])
150150
logits = self.head(tgt_out)

0 commit comments

Comments
 (0)