From 7f6e9d30804f0fa2c3b695a08315ec2246ca7315 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 17 Jul 2023 00:03:36 +0200
Subject: [PATCH] gpu docstring

---
 src/functor.jl | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/functor.jl b/src/functor.jl
index d4d310a508..c3531a187c 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -225,8 +225,10 @@ end
 Copies `m` to the current GPU device (using current GPU backend), if one is available.
 If no GPU is available, it does nothing (but prints a warning the first time).
 
-On arrays, this calls CUDA's `cu`, which also changes arrays
-with Float64 elements to Float32 while copying them to the device (same for AMDGPU).
+When the backend is set to "CUDA", this calls `CUDA.cu` on arrays,
+which also converts arrays with Float64 elements to Float32 while copying them to the device.
+Similar conversions happen for the "AMDGPU" and "Metal" backends.
+
 To act on arrays within a struct, the struct type must be marked with [`@functor`](@ref).
 
 Use [`cpu`](@ref) to copy back to ordinary `Array`s.
@@ -235,6 +237,8 @@ See also [`f32`](@ref) and [`f16`](@ref) to change element type only.
 See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
 to help identify the current device.
 
+See [`Flux.gpu_backend!`](@ref) for setting the backend.
+
 # Example
 ```julia-repl
 julia> m = Dense(rand(2, 3))  # constructed with Float64 weight matrix
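
For reference, a minimal sketch of the workflow the revised docstring describes, assuming Flux's documented `gpu_backend!`, `gpu`, and `cpu` functions and a machine with a working CUDA device (variable names are illustrative):

```julia
using Flux

# Persist the backend preference; per Flux's docs this takes
# effect after restarting the Julia session.
Flux.gpu_backend!("CUDA")

m = Dense(rand(2, 3))   # weight matrix has Float64 elements
m_gpu = gpu(m)          # with the CUDA backend, weights become Float32 CuArrays

x = gpu(rand(3))        # move input data to the device as well
y = m_gpu(x)            # the layer now runs on the GPU

y_cpu = cpu(y)          # copy the result back to an ordinary Array
```

Note that `gpu` silently changes the element type from Float64 to Float32 here, which is why the docstring points to `f32`/`f16` for changing element type without moving data.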