Skip to content

Commit

Permalink
fix: partial application
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 12, 2024
1 parent 1a16c64 commit a6674df
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 49 deletions.
53 changes: 8 additions & 45 deletions src/partial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,12 @@ function Base.show(
print(io, ")")
end

# ::Type{T} is already specified
function (f::PartialWeightInitializationFunction{T, F, <:AbstractRNG})(
dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(f.rng, T, dims...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{T, F, Nothing})(
rng::AbstractRNG; kwargs...) where {T <: Number, F}
return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{T, F, Nothing})(
rng::AbstractRNG, dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(rng, T, dims...; f.kwargs..., kwargs...)
end

# ::Type{T} is not needed
function (f::PartialWeightInitializationFunction{Missing, F, <:AbstractRNG})(
dims::Integer...; kwargs...) where {F}
return f.f(f.rng, dims...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{Missing, F, Nothing})(
rng::AbstractRNG; kwargs...) where {F}
return PartialWeightInitializationFunction{Missing}(
f.f, rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{Missing, F, Nothing})(
rng::AbstractRNG, dims::Integer...; kwargs...) where {F}
return f.f(rng, dims...; f.kwargs..., kwargs...)
end

# ::Type{T} is not specified
function (f::PartialWeightInitializationFunction{Nothing, F, Union{<:AbstractRNG, Nothing}})(
::Type{T}; kwargs...) where {T <: Number, F}
return PartialWeightInitializationFunction{T}(f.f, f.rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{Nothing, F, <:AbstractRNG})(
::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(f.rng, T, dims...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})(
rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number, F}
return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})(
rng::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(rng, T, dims...; f.kwargs..., kwargs...)
function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})(
args...; kwargs...)
f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...)
return f.f(f.rng, args...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number}
f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...)
return f.f(f.rng, T, args...; f.kwargs..., kwargs...)
end
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))

const EXTRA_PKGS = String[]

BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" && push!(EXTRA_PKGS, "CUDA")
BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" && push!(EXTRA_PKGS, "AMDGPU")
BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" && push!(EXTRA_PKGS, "Metal")
BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" && push!(EXTRA_PKGS, "oneAPI")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "CUDA")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI")

if !isempty(EXTRA_PKGS)
@info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS
Expand Down

0 comments on commit a6674df

Please sign in to comment.