Make Distributions.jl GPU friendly #1067
Many distributions are already generic. It would be great if you could prepare PRs for the ones that aren't.
Can't you use broadcasting for this?
Yeah I'll start looking into this.
Yes, we'll need to define broadcasting styles to do this packing. Another thing I thought of yesterday: the sampling code will need to be modified to not use scalar indexing. That would be very slow on the GPU, and most people have scalar indexing disabled when loading CuArrays.jl.
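To make the scalar-indexing concern concrete, here is a minimal sketch, assuming CUDA.jl (the successor to CuArrays.jl) and a working GPU; it illustrates the general point and is not code from this thread:

```julia
# Minimal sketch: scalar indexing into a device array is disallowed by default,
# while vectorized (broadcast/kernel) operations stay fast on the GPU.
using CUDA

CUDA.allowscalar(false)           # most setups disable scalar indexing like this

x = CUDA.zeros(Float32, 10)
# x[1] = 1f0                      # element-by-element write: throws a scalar-indexing error
x .= CUDA.rand(Float32, 10)       # vectorized sampling: runs as a GPU kernel
```

This is why sampling code written as per-element loops needs to be restated in terms of whole-array operations to run well on the GPU.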
Curious if this is likely to be prioritized any time soon!
I got pulled into other projects that prevent me from focusing on this issue right now. But I know @srinjoyganguly is working on a DistributionsGPU.jl package.
Thanks for the prompt response! Including a link to srinjoyganguly's repo for anyone who ends up here: DistributionsGPU.jl
Are GPU features from DistributionsGPU.jl eventually going to be ported into Distributions.jl?
Hello @darsnack @johnurbanik, I am sorry for the late response. I got busy with some assignments. I will start working on the issue soon. @azev77 I truly hope these features are added to the Distributions.jl package. It will help a lot with computation speed during sampling. Thanks so much.
Just want to point out we don't have to solve all these GPU things at once, and the abstract types would help integrate with other libraries regardless of GPU implications.
I recently ran into trouble too, trying to integrate Distributions into GPU code. Maybe we could make Distributions.jl more GPU friendly step by step? Currently, even simple things like

```julia
using Distributions, CUDA

Mu = cu(rand(10))
Sigma = cu(rand(10))
D = Normal.(Mu, Sigma)
```

fail with an error.

This works, however:

```julia
D = Normal{eltype(Mu)}.(Mu, Sigma)
```

But then we run into the same error as above again with `logpdf.(D, X)`. This is not surprising. If we override

```julia
Distributions.Normal(mu::T, sigma::T) where {T<:Real} = Normal{eltype(mu)}(mu, sigma)
```

then both the broadcasted construction and `logpdf` stay on the GPU:

```julia
D = Normal.(Mu, Sigma)
logpdf.(D, X) isa CuArray
```

That's just one specific case, of course, but I have a feeling that with a bit of tweaking here and there, Distributions could support GPU operations quite well.
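For reference, here is a minimal sketch that collects the snippets above into one runnable script; it assumes CUDA.jl and a CUDA-capable GPU, the definition of `X` is my own (the comment above never defines it), and the constructor override is type piracy meant only for experimentation:

```julia
# Sketch only: the workaround described above, gathered into one script.
# Assumes CUDA.jl and a CUDA-capable GPU with scalar indexing disabled.
using Distributions, CUDA

# Route the two-argument constructor to the inner, check-free constructor.
# Note: overwriting a Distributions method like this is type piracy; it is
# shown purely to reproduce the experiment from the comment above.
Distributions.Normal(mu::T, sigma::T) where {T<:Real} = Normal{eltype(mu)}(mu, sigma)

Mu    = cu(rand(10))
Sigma = cu(rand(10)) .+ 1f-3   # assumption: keep the scale strictly positive
X     = cu(rand(10))           # assumption: evaluation points (not defined above)

D = Normal.(Mu, Sigma)         # broadcasts on the GPU with the override in place
logpdf.(D, X) isa CuArray      # true in the scenario described above
```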
Is this a Distributions or a CUDA issue? What exactly is the problem? The …
I think it's because most distribution default ctors (like …) … But …
Yeah, usually (e.g. in constructors without sigma or …) …
I think we'll find a few CUDA/GPU-incompatible things like that scattered throughout Distributions, but I hope many will be easy to fix in that regard (except calls to Rmath, obviously, but luckily there are a lot fewer of those than in the early days). I have some GPU-related statistics tasks coming up, so I can do PRs along the way when I hit issues like the above.
I opened #1487.
Thanks!
Is there currently a way to transfer an arbitrary distribution to the GPU? Maybe something like this: …
Given that there are no constraints or specifications on fields, parameters, and constructors, I don't think any such implementation will work in general. The cleanest approach would be to implement support for https://github.com/JuliaGPU/Adapt.jl, but that has to be done for each distribution separately to ensure it is correct (possibly one could use ConstructionBase, but I guess it is not even needed for this task). Hence I think it's much easier to just construct distributions with GPU-compatible parameters instead of trying to move an existing distribution to the GPU.
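To make the Adapt.jl suggestion concrete, here is a minimal sketch of what per-distribution support could look like. `MyGaussianVec` is a made-up struct used purely for illustration (Distributions.jl defines nothing like it); `Adapt.adapt_structure` is the standard Adapt.jl extension point:

```julia
# Sketch only: per-distribution Adapt.jl support for a hypothetical struct.
using Adapt

struct MyGaussianVec{T<:AbstractVector}
    mu::T
    sigma::T
end

# Rebuild the struct with each array field converted by `adapt`
# (e.g. Vector -> CuArray when adapting to CuArray).
Adapt.adapt_structure(to, d::MyGaussianVec) =
    MyGaussianVec(Adapt.adapt(to, d.mu), Adapt.adapt(to, d.sigma))

# Usage with CUDA.jl loaded (not shown here):
#   d_gpu = Adapt.adapt(CuArray, MyGaussianVec(rand(10), rand(10)))
```

Doing this correctly for every distribution type is exactly the per-distribution effort described above, which is why constructing with GPU-compatible parameters in the first place is the simpler route.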
Perhaps WebGPU, being an abstraction over GPU-specific stuff, could enable a more hardware-agnostic and forward-compatible approach? https://github.com/cshenton/WebGPU.jl and https://github.com/JuliaWGPU/WGPUNative.jl could be examples, or usable for E2E testing to automatically check whether Distributions.jl works on the GPU. Then you could handle AMD/NVIDIA/Intel/whatever, and the code would potentially last longer than if we focus on some particular version of CUDA.

Another idea would be to enforce a linter rule whereby all the functions in this library must act on abstract number types instead of concrete number kinds like Float64. Otherwise it's difficult to go the last mile to the GPU, because you might need a different (usually smaller) element type in your arrays to actually fit your problem on the GPU in practice.

Does Julia have a concept of traits? You could probably modularize this better:

```rust
use std::fmt::{Debug, Display};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};

/// A field is a set on which addition, subtraction, multiplication, and division are
/// defined and behave as the corresponding operations on rational and real numbers do.
pub trait Field:
Clone
+ Default
+ std::fmt::Debug
+ std::fmt::Display
+ Sized
+ Eq
+ PartialEq
+ Add<Output = Option<Self>>
+ AddAssign<Self>
+ Sub<Output = Option<Self>>
+ SubAssign<Self>
+ Mul<Output = Option<Self>>
+ MulAssign<Self>
+ Div<Output = Option<Self>>
+ DivAssign<Self>
+ Neg<Output = Option<Self>>
{
const FIELD_NAME: &'static str;
const DATA_TYPE_STR: &'static str;
type Data: Add + Sub + Mul + Div + Default + Debug + Display + Clone;
type Shape;
// Basic Info
fn field_name(&self) -> &'static str {
Self::FIELD_NAME
}
fn data_type_str(&self) -> &'static str {
Self::DATA_TYPE_STR
}
fn shape(&self) -> &Self::Shape;
// Construction and value access
fn of<X>(value: X) -> Self
where
Self::Data: From<X>;
fn from_dtype(value: Self::Data) -> Self {
Self::of(value)
}
fn try_from_n_or_0<T>(value: T) -> Self
where
Self::Data: TryFrom<T>,
{
match value.try_into() {
Ok(converted) => Self::of(converted),
Err(_e) => {
println!("[from_n_or_0] failed to convert a value!");
Self::zero()
}
}
}
fn try_from_n_or_1<T>(value: T) -> Self
where
Self::Data: TryFrom<T>,
{
match value.try_into() {
Ok(converted) => Self::of(converted),
Err(_) => {
println!("[from_n_or_1] failed to convert a value!");
Self::one()
}
}
}
fn get_value(&self) -> Option<&Self::Data>;
fn set_value(&mut self, value: Self::Data);
// Additive & Multiplicative Identity
fn is_zero(&self) -> bool;
fn is_one(&self) -> bool;
fn zero() -> Self;
fn one() -> Self;
// Basic operations with references
fn add_ref(&self, other: &Self) -> Option<Self>;
fn sub_ref(&self, other: &Self) -> Option<Self>;
fn mul_ref(&self, other: &Self) -> Option<Self>;
fn safe_div_ref(&self, other: &Self) -> Option<Self>;
fn rem_ref(&self, other: &Self) -> Option<Self>;
// Field has a multiplicative inverse -- requires rational number field for unsigned integers
// fn mul_inv(&self) -> Option<Ratio<Self, Self>>;
}
```

Thus, this trait should be modularized into smaller traits in a progression of capability. Anyway, I'll spare you the implementation details; this could probably be done much better if informed mathematically and modularized further. If you think about it, when Socrates was talking about Forms, he was talking about traits; when we think about Float or Number, that's a trait, not a concrete type. So if you really want to make it hella easy to work with GPUs in Distributions.jl, maybe the solution is to write a test with WebGPU and some abstract number type that works independently of the hardware and precision. Then our current implicit data type `Float { processor: CPU, dtype: Float64 }` can become `Float { processor: WebGPU, dtype: Bfloat16 }`; as long as a `Bfloat16` implements the same required functionality as a `Float64`, you could swap them out like Indiana Jones!
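As an aside on the traits question: Julia has no trait keyword, but abstract number types plus the so-called "Holy traits" pattern cover similar ground. The sketch below is my own illustration (names like `gpu_friendly` are made up and not from this thread or from Distributions.jl):

```julia
# Sketch: Julia's rough equivalents of the trait bounds above.

# Abstract number types: write methods against Real instead of Float64,
# so Float32, Float16, dual numbers, etc. all dispatch to the same code.
standardize(x::Real, mu::Real, sigma::Real) = (x - mu) / sigma

# "Holy traits": a trait is a type returned by a classifying function,
# and other methods dispatch on that type.
struct IsGPUFriendly end
struct NotGPUFriendly end

gpu_friendly(::Type{<:Union{Float16,Float32,Float64}}) = IsGPUFriendly()
gpu_friendly(::Type) = NotGPUFriendly()

describe(::Type{T}) where {T} = describe(gpu_friendly(T), T)
describe(::IsGPUFriendly,  ::Type{T}) where {T} = "$T is an isbits type suitable for GPU kernels"
describe(::NotGPUFriendly, ::Type{T}) where {T} = "$T likely needs adaptation before GPU use"
```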
An increasing number of ML applications require sampling, and being able to reuse Distributions.jl instead of rolling your own sampler would be a major step forward. I have already forked this package to fix issues where parametric types are hardcoded to `Float64` instead of generic (to support an ML project I am working on). At this point, being able to sample on the GPU instead of moving data back and forth is pretty important to me. I am willing to put some effort into making this happen, but Distributions.jl covers a lot of stuff that I am not familiar with. To that end, I'd like some help coming up with the list of changes needed to bring us towards this GPU-friendly goal. Here's what I have so far:

- Replace hardcoded concrete types (e.g. `Float64`) with generics (e.g. `AbstractFloat`/`Real`); see the sketch below
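As a hedged illustration of that list item, the snippet below contrasts a hard-coded signature with a generic one; `mynormlogpdf` is a hypothetical helper used only for illustration, not a Distributions.jl function:

```julia
# Illustration only: a hard-coded signature versus a generic, type-stable one.

# Hard-coded: accepts only Float64, so Float32 GPU data either fails to dispatch
# or gets promoted to Float64 along the way.
mynormlogpdf_f64(mu::Float64, sigma::Float64, x::Float64) =
    -(abs2((x - mu) / sigma) + log(2pi)) / 2 - log(sigma)

# Generic: one definition that works for Float64, Float32, Float16, etc.,
# keeping the computation in the caller's element type.
mynormlogpdf(mu::T, sigma::T, x::T) where {T<:AbstractFloat} =
    -(abs2((x - mu) / sigma) + log(2 * T(pi))) / 2 - log(sigma)

# Broadcasting the generic version over Float32 data keeps the element type:
# mynormlogpdf.(0f0, 1f0, rand(Float32, 10)) isa Vector{Float32}
```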