Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minimal patch to fix Windows compilation issues #876

Merged
merged 1 commit into from
Feb 1, 2024

Conversation

wkpark
Copy link
Contributor

@wkpark wkpark commented Nov 16, 2023

manually cherry-picked from PR #788 and PR #229 and cleaned up by me

Original work done by @Jamezo97 and @acpopescu

@wkpark
Copy link
Contributor Author

wkpark commented Nov 29, 2023

@wkpark
Copy link
Contributor Author

wkpark commented Dec 11, 2023

cmake + github workflows + cuda 12.1, 11.8 + python 3.10, 3.11 + windows, ubuntu build matrix available.
https://github.com/wkpark/bitsandbytes/actions/runs/7165411186/job/19507307550

@TimDettmers
Copy link
Collaborator

This is awesome, thanks for your contribution! We will discuss this internally and get back to you. @Titus-von-Koeller @younesbelkada please remind me that we talk about this.

@TimDettmers TimDettmers added the medium priority (will be worked on after all high priority issues) label Jan 2, 2024
@wkpark wkpark mentioned this pull request Jan 2, 2024

int thread_wave_size = 256;
// we chunk the thresds into waves of 256 since the max limit is
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
#ifdef _WIN32
std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to use std::thread everywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to use std::thread everywhere?

I just leaving this as the original patch creator intended.

@Titus-von-Koeller
Copy link
Collaborator

@wkpark

Do I understand correctly that this is ready for review?

@wkpark
Copy link
Contributor Author

wkpark commented Jan 29, 2024

@wkpark

Do I understand correctly that this is ready for review?

Yes. I think so.

minimal fixes have been cherry-picked and squashed. (several unrelated hunks were carefully removed/simplified manually)

see also PR #908,
https://github.com/wkpark/bitsandbytes/actions/runs/7700400727/job/20984178019

based on @Jamezo97 and @acpopescu work

manually cherry-picked from PR bitsandbytes-foundation#788 and PR bitsandbytes-foundation#229 and cleanup by wkpark

Signed-off-by: Won-Kyu Park <[email protected]>
Copy link
Collaborator

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @wkpark !
I think we should merge this PR before #908 right? Could you re-enable windows build in the workflow for #908 after we merge this PR?

@younesbelkada younesbelkada merged commit 89876bb into bitsandbytes-foundation:main Feb 1, 2024
2 checks passed
@wkpark
Copy link
Contributor Author

wkpark commented Feb 1, 2024

Thanks @wkpark ! I think we should merge this PR before #908 right? Could you re-enable windows build in the workflow for #908 after we merge this PR?

right and thank you.
now I can add windows build matrix!

@younesbelkada
Copy link
Collaborator

Perfect thanks @wkpark !

@younesbelkada
Copy link
Collaborator

Thanks for your great work @wkpark I we can confirm with @Titus-von-Koeller that this PR did not break anything with respect to bnb + HF (PEFT, transformers, etc.) Looking forward to your next contributions!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
medium priority (will be worked on after all high priority issues)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants