diff --git a/Project.toml b/Project.toml index 92e45311c8..09e7c05dce 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ Crayons = "4" DataFrames = "1" EnzymeCore = "0.8.2" ExprTools = "0.1" -GPUArrays = "11.1" +GPUArrays = "11.2" GPUCompiler = "0.24, 0.25, 0.26, 0.27, 1" KernelAbstractions = "0.9.2" LLVM = "9.1" diff --git a/src/array.jl b/src/array.jl index 30657625b2..3213a91652 100644 --- a/src/array.jl +++ b/src/array.jl @@ -71,9 +71,13 @@ mutable struct CuArray{T,N,M} <: AbstractGPUArray{T,N} else maxsize end - data = DataRef(pool_free, pool_alloc(M, bufsize)) - obj = new{T,N,M}(data, maxsize, 0, dims) - finalizer(unsafe_free!, obj) + + GPUArrays.cached_alloc((CuArray, CUDA.device(), T, bufsize, M)) do + data = DataRef(pool_free, pool_alloc(M, bufsize)) + obj = new{T,N,M}(data, maxsize, 0, dims) + finalizer(unsafe_free!, obj) + return obj + end::CuArray{T, N, M} end function CuArray{T,N}(data::DataRef{Managed{M}}, dims::Dims{N};