From a6674df5dfb6f6384f33a486fac9d90c0f424391 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 20:44:40 -0700 Subject: [PATCH] fix: partial application --- src/partial.jl | 53 ++++++++---------------------------------------- test/runtests.jl | 8 ++++---- 2 files changed, 12 insertions(+), 49 deletions(-) diff --git a/src/partial.jl b/src/partial.jl index 7e2c499..a4d34b0 100644 --- a/src/partial.jl +++ b/src/partial.jl @@ -22,49 +22,12 @@ function Base.show( print(io, ")") end -# ::Type{T} is already specified -function (f::PartialWeightInitializationFunction{T, F, <:AbstractRNG})( - dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) -end -function (f::PartialWeightInitializationFunction{T, F, Nothing})( - rng::AbstractRNG; kwargs...) where {T <: Number, F} - return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{T, F, Nothing})( - rng::AbstractRNG, dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(rng, T, dims...; f.kwargs..., kwargs...) -end - -# ::Type{T} is not needed -function (f::PartialWeightInitializationFunction{Missing, F, <:AbstractRNG})( - dims::Integer...; kwargs...) where {F} - return f.f(f.rng, dims...; f.kwargs..., kwargs...) -end -function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( - rng::AbstractRNG; kwargs...) where {F} - return PartialWeightInitializationFunction{Missing}( - f.f, rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( - rng::AbstractRNG, dims::Integer...; kwargs...) where {F} - return f.f(rng, dims...; f.kwargs..., kwargs...) -end - -# ::Type{T} is not specified -function (f::PartialWeightInitializationFunction{Nothing, F, Union{<:AbstractRNG, Nothing}})( - ::Type{T}; kwargs...) where {T <: Number, F} - return PartialWeightInitializationFunction{T}(f.f, f.rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{Nothing, F, <:AbstractRNG})( - ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) -end -function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( - rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number, F} - return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) -end -function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( - rng::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} - return f.f(rng, T, dims...; f.kwargs..., kwargs...) +function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( + args...; kwargs...) + f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...) + return f.f(f.rng, args...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number} + f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...) + return f.f(f.rng, T, args...; f.kwargs..., kwargs...) end diff --git a/test/runtests.jl b/test/runtests.jl index 994df2b..08c5712 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,10 +4,10 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) const EXTRA_PKGS = String[] -BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" && push!(EXTRA_PKGS, "CUDA") -BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" && push!(EXTRA_PKGS, "AMDGPU") -BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" && push!(EXTRA_PKGS, "Metal") -BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" && push!(EXTRA_PKGS, "oneAPI") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "CUDA") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") if !isempty(EXTRA_PKGS) @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS