Skip to content

Commit

Permalink
explicit scale not support with torch==2.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Nov 29, 2023
1 parent 215cc72 commit d46a6fc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3904,7 +3904,12 @@ def scaled_dot_product_attention(self, inputs, input_types):
attn_mask = inputs[3]
dropout_p = inputs[4]
is_causal = inputs[5]
scale = inputs[6]

# Explicit scale can be used from torch>=2.1.0
if len(inputs) == 7:
scale = inputs[6]
else:
scale = None

assert (
input_types[0] == input_types[1] == input_types[2]
Expand Down
14 changes: 2 additions & 12 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5493,9 +5493,9 @@ def test_scaled_dot_product_attention():
"""test_scaled_dot_product_attention"""
torch.set_grad_enabled(False)

def test_fn(attn_mask=None, is_causal=False, scale=None):
def test_fn(attn_mask=None, is_causal=False):
return lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, is_causal=is_causal, scale=scale
query, key, value, attn_mask=attn_mask, is_causal=is_causal
)

L, S, E, Ev = 5, 7, 11, 13
Expand Down Expand Up @@ -5535,16 +5535,6 @@ def test_fn(attn_mask=None, is_causal=False, scale=None):
verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_4d])
verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_3d])

scale = 0.5
verify_model(test_fn(scale=scale), [query_4d, key_4d, value_4d])
verify_model(test_fn(scale=scale), [query_4d, key_4d, value_3d])
verify_model(test_fn(scale=scale), [query_4d, key_3d, value_4d])
verify_model(test_fn(scale=scale), [query_4d, key_3d, value_3d])
verify_model(test_fn(scale=scale), [query_3d, key_4d, value_4d])
verify_model(test_fn(scale=scale), [query_3d, key_4d, value_3d])
verify_model(test_fn(scale=scale), [query_3d, key_3d, value_4d])
verify_model(test_fn(scale=scale), [query_3d, key_3d, value_3d])

# Test with float64
query_4d = torch.randn(2, 3, L, E, dtype=torch.float64)
query_3d = torch.randn(3, L, E, dtype=torch.float64)
Expand Down

0 comments on commit d46a6fc

Please sign in to comment.