[BUG] Fix bug with f16 overflow #482

Closed
wants to merge 2 commits

Conversation

vadiklyutiy
Collaborator

Add a preliminary conversion to f32 for quantization.

To describe it with an example: quantize f16 to int16 (it is a bit unclear why we need this, but we support and test it). f16 cannot hold the exact value of int16's maximum, so we get an f16 number that is larger than int16's maximum. As a result, the quantized int16 contains int16's minimum value instead of its maximum.

In fact, converting a float to an int has unpredictable behavior whenever the float exceeds the maximum value the target int type can hold.
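
To make the failure mode concrete, here is a minimal host-side sketch (my illustration, not the PR's code; compile with nvcc, since it uses cuda_fp16.h):

#include <cstdio>
#include <cstdint>
#include <cuda_fp16.h>

int main() {
    // f16 has a 10-bit mantissa, so near 32767 its spacing is 16:
    // 32767 (int16 max) is not representable and rounds up to 32768.0.
    __half h = __float2half(32767.0f);
    float back = __half2float(h);    // 32768.0f, already above int16 max
    std::printf("f16(32767) round-trips to %.1f\n", back);

    // Casting a float larger than int16 max to int16_t is undefined
    // behavior in C++; in practice it often yields int16's minimum value,
    // which is exactly the symptom described above. Clamping in f32
    // before the cast keeps the conversion well defined.
    float clamped = back > 32767.0f ? 32767.0f : back;
    std::int16_t q = (std::int16_t)clamped;    // 32767
    std::printf("quantized: %d\n", q);
    return 0;
}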

@yaoyaoding
Member

it is a bit unclear why we need this, but we support and test it

Maybe it was added accidentally. There is no reason to quantize f16 to an integer type with the same number of bits.

converting a float to an int has unpredictable behavior whenever the float exceeds the maximum value the target int type can hold

This is true in general. But in the case of quantization, we can make sure that the quantized floating-point numbers are within the representable range of the integer data type.
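
As a concrete illustration (a minimal sketch, not hidet's actual quantization code; quantize_symmetric is a hypothetical helper), clamping in f32 before the cast guarantees the operand is in range:

#include <cstdint>
#include <cmath>

// Sketch of symmetric quantization to int8. The clamp runs in f32, so the
// value handed to the integer cast is always within int8's representable range.
std::int8_t quantize_symmetric(float x, float scale) {
    float q = std::roundf(x / scale);
    q = std::fminf(std::fmaxf(q, -128.0f), 127.0f);
    return (std::int8_t)q;
}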

@yaoyaoding
Member

We can remove the "f16 quantized to int16" test case. The PR also looks good.

@vadiklyutiy
Collaborator Author

We can remove the "f16 quantized to int16" test case. The PR also looks good.

Yes, I think that's what we should do.

@vadiklyutiy
Collaborator Author

@yaoyaoding
Maybe you know of something or have seen something similar. The reason I touched this at all is an error in the Publish workflow:

                   Command: /usr/local/cuda/bin/nvcc -I/__w/_tool/Python/3.9.21/x64/lib/python3.9/site-packages/hidet/include -L/__w/_tool/Python/3.9.21/x64/lib/python3.9/site-packages/hidet/lib -lcuda -O3 -Xcompiler -fPIC,-m64,-march=cascadelake,-O3,-funroll-loops,-ffast-math -std=c++11 -gencode arch=compute_89,code=sm_89 --ptxas-options=-v -lineinfo -ftz=true -prec-div=false -lhidet_runtime --cudart shared --diag-suppress 177 --diag-suppress 179 --diag-suppress 39 --shared  /__w/hidet/hidet/outs/cache/ops/cuda_space_0/symmetric_quantization/b8e96bb57ff53b72/source.cu -o /__w/hidet/hidet/outs/cache/ops/cuda_space_0/symmetric_quantization/b8e96bb57ff53b72/lib.so
E                   /__w/hidet/hidet/outs/cache/ops/cuda_space_0/symmetric_quantization/b8e96bb57ff53b72/source.cu(68): error: more than one conversion function from "__nv_bfloat16" to "int8_t" applies:
E                               function "__nv_bfloat16::operator float() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(177): here
E                               function "__nv_bfloat16::operator short() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(195): here
E                               function "__nv_bfloat16::operator unsigned short() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(198): here
E                               function "__nv_bfloat16::operator int() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(201): here
E                               function "__nv_bfloat16::operator unsigned int() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(204): here
E                               function "__nv_bfloat16::operator long long() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(207): here
E                               function "__nv_bfloat16::operator unsigned long long() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(210): here
E                               function "__nv_bfloat16::operator __nv_bool() const"
E                   /usr/local/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(214): here
E                   
E                   1 error detected in the compilation of "/__w/hidet/hidet/outs/cache/ops/cuda_space_0/symmetric_quantization/b8e96bb57ff53b72/source.cu".

For some reason it works in the Tests workflow but fails in Publish.

So, the code is just:

bfloat16 x;
int8 y = (int8)x;

It seems a different CUDA version is used there. What is the right way to convert to int8?

@vadiklyutiy deleted the vadim/quant-fix2 branch December 26, 2024 19:16
vadiklyutiy pushed a commit that referenced this pull request Dec 26, 2024
### PR Comment:

In the new version of the `transformers` library (version 4.45.0) used
in CI, the `merges` field in the configuration has changed to a list of
lists. To ensure compatibility with this update, I have modified our
code base to accommodate this change.

Without this adjustment, the `test_tokenizer` in CI would fail to
execute successfully. This update ensures that the tests run as expected
with the new library version.
@yaoyaoding
Member

What is the right way to convert to int8?

I have encountered this problem multiple times. The reason is that the C++ compiler finds multiple candidate intermediate types when converting bfloat16 to int8, because there is no direct conversion from bf16 to int8. To address the issue, we can explicitly specify an intermediate type, e.g. int8 y = (int8)(int32)x;.
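
For instance, in CUDA device code the disambiguated cast could be written like this (a sketch; bf16_to_int8 is a hypothetical helper, not code from the PR):

#include <cuda_bf16.h>
#include <cstdint>

__device__ int8_t bf16_to_int8(__nv_bfloat16 x) {
    // A direct (int8_t)x is ambiguous here: several user-defined conversion
    // operators (to float, short, int, ...) are equally viable candidates.
    // Naming the intermediate type makes the conversion sequence unique.
    return (int8_t)(int32_t)x;
    // Alternatively (int8_t)(float)x, if float conversion semantics are wanted.
}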

@vadiklyutiy
Collaborator Author

Actually, I fixed it by updating the test image. With the new one there is no such problem.
