-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
consider value matrix shapes for Jax conversion #10
base: main
Are you sure you want to change the base?
consider value matrix shapes for Jax conversion #10
Conversation
For some programs, the value matrix could have a dimension larger than the largest key-query matrix resulting in a compliation error. By considering both ov and qk matrices when padding the Jax model we can resolve this
Thanks for this! This fix makes sense. Could you add a test case that catches the previous bug? Easiest would be to add a minimal version of your example to the test cases in test_cases.py. This will add it to all our integration tests. In particular, rasp_to_craft_integration_test.py should fail without your change and pass after making it. The proper thing to do would be to also add a unit test to assemble_test.py. I don't think we need to do this here, but feel free to give it a shot if you want. |
Is this still blocked on the tests? Would be great if it could be merged soon. |
I haven't added any test cases yet as there are still discrepancies with RASP and CRAFT that should be resolved first, then we can check for consistency. I will prepare some test cases in the mean time that demonstrate the issue under the current compiler but things may change depending on issue #14 |
I have added the test cases to test cases.py and checked that the validator doesn't flag issues, since these are categorical aggregates, I do think this only happens with categorical aggregation programs... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ov_test_case_1 still fails -- please fix
For some programs, the value matrix could have a dimension larger than the largest key-query matrix resulting in a compilation error. By considering both ov and qk matrices when padding the Jax model we can resolve this.
Examples of programs that trigger this error are:
I have verified that this change does not break any of the test cases proposed in issue #9