Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFix] Fix UINT/INT8 dequantize implementation and optimize the schedule template for float32 accum #46

Merged
merged 6 commits into from
Jun 2, 2024

Conversation

LeiWang1999
Copy link
Contributor

This pull request includes changes to several Python files in the bitblas library, with the primary goal of improving support for different data types and making the code more robust. This includes changes to the hint.py, tensorcore.py, lop3.py, general_matmul.py, and matmul_dequantize_impl.py files. The changes can be grouped into three main categories: updates to the hint.py and tensorcore.py files to handle different data types, improvements to the lop3.py file to better handle different bit sizes, and changes to the general_matmul.py and matmul_dequantize_impl.py files to add assertions and handle different bit sizes.

Handling different data types:

Improvements to handle different bit sizes:

  • python/bitblas/gpu/intrin/lop3.py: Reformatted the get_fast_decode_intrin function calls in the LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN, and LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN registrations for better readability. Also added new LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN and LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN registrations. [1] [2] [3]

Adding assertions and handling different bit sizes:

  • python/bitblas/ops/general_matmul.py: Added the is_not_fast_decoding_supported function in the __initialize_fast_decoding method and updated the condition in the transform_weight method to check if bit is less than 8. [1] [2]
  • python/bitblas/ops/impl/matmul_dequantize_impl.py: Added assertions to check if bit is in [1, 2, 4, 8] in the matmul_nt_dequantize_b, matmul_nt_dequantize_b_propagate_b, and matmul_nt_dequantize_b_propagate_a_propagate_b functions. Also updated the decode_func function in these methods to handle the case where bit is 8. [1] [2] [3] [4] [5] [6] [7] [8]

Other changes:

@LeiWang1999 LeiWang1999 merged commit c4400b3 into microsoft:main Jun 2, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant