Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

consider value matrix shapes for Jax conversion #10

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

William-Baker
Copy link
Contributor

For some programs, the value matrix could have a dimension larger than the largest key-query matrix resulting in a compilation error. By considering both ov and qk matrices when padding the Jax model we can resolve this.

Examples of programs that trigger this error are:

def rasp_prog():
        se1 = rasp.Select(rasp.indices, rasp.indices, lambda x, y: x == y)
        so2 = rasp.Aggregate(se1, rasp.indices)
        se2 = rasp.Select(rasp.tokens, rasp.tokens, lambda x, y: x < y)
        so3 = rasp.SelectorWidth(se2)
        se3 = rasp.Select(so3, so2, lambda x, y: x!=y)
        so6 = rasp.SequenceMap(lambda x,y: x-y, so3, so3)
        so7 = rasp.Aggregate(se3, so6)
        return so7
def rasp_prog():
        so2 = rasp.Map(lambda x: x - 2, rasp.indices)
        so4 = rasp.SequenceMap(lambda x,y: x or y, so2, rasp.indices)
        so1 = rasp.Map(lambda x: x > 'b', rasp.tokens)
        so3 = rasp.Map(lambda x: x < False, so1)
        se1 = rasp.Select(so3, so3, rasp.Comparison.LEQ)
        so6 = rasp.Aggregate(se1, so4)
        return so6

I have verified that this change does not break any of the test cases proposed in issue #9

For some programs, the value matrix could have a dimension larger than the largest key-query matrix resulting in a compliation error. By considering both ov and qk matrices when padding the Jax model we can resolve this
@david-lindner
Copy link
Collaborator

david-lindner commented Oct 18, 2023

Thanks for this! This fix makes sense.

Could you add a test case that catches the previous bug? Easiest would be to add a minimal version of your example to the test cases in test_cases.py. This will add it to all our integration tests. In particular, rasp_to_craft_integration_test.py should fail without your change and pass after making it.

The proper thing to do would be to also add a unit test to assemble_test.py. I don't think we need to do this here, but feel free to give it a shot if you want.

@langosco
Copy link
Contributor

Is this still blocked on the tests? Would be great if it could be merged soon.

@William-Baker
Copy link
Contributor Author

I haven't added any test cases yet as there are still discrepancies with RASP and CRAFT that should be resolved first, then we can check for consistency. I will prepare some test cases in the mean time that demonstrate the issue under the current compiler but things may change depending on issue #14

@William-Baker
Copy link
Contributor Author

I have added the test cases to test cases.py and checked that the validator doesn't flag issues, since these are categorical aggregates, I do think this only happens with categorical aggregation programs...

@david-lindner david-lindner self-assigned this Jan 25, 2024
Copy link
Collaborator

@david-lindner david-lindner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ov_test_case_1 still fails -- please fix

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants