Skip to content

Commit

Permalink
Update pocs/transf.py
Browse files Browse the repository at this point in the history
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
leonvanbokhorst and sourcery-ai[bot] authored Dec 13, 2024
1 parent 657d3e5 commit 173a06e
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions pocs/transf.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,8 @@ def test_addition(model, num1, num2, max_digits=3):
predicted = torch.argmax(output, dim=-1)

# Filter out padding tokens
result = ''.join(str(x) for x in predicted[0].tolist() if x != 0).lstrip('0')
if not result: # If result is empty after filtering
result = '0'
result = ''.join(str(x) for x in predicted[0].tolist() if x != 0).lstrip('0') or '0'


print(f"{num1} + {num2} = {result} (expected {num1 + num2})")
return result
Expand Down

0 comments on commit 173a06e

Please sign in to comment.