Skip to content

Commit

Permalink
more generic ZScoreTransform, UnitRangeTransform to support CuArrays (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sdewaele authored Dec 19, 2020
1 parent 854a541 commit 83ccc37
Showing 1 changed file with 27 additions and 44 deletions.
71 changes: 27 additions & 44 deletions src/transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ reconstruct(t::AbstractDataTransform, y::AbstractVector{<:Real}) =
"""
Standardization (Z-score transformation)
"""
struct ZScoreTransform{T<:Real} <: AbstractDataTransform
struct ZScoreTransform{T<:Real, U<:AbstractVector{T}} <: AbstractDataTransform
len::Int
dims::Int
mean::Vector{T}
scale::Vector{T}
mean::U
scale::U

function ZScoreTransform(l::Int, dims::Int, m::Vector{T}, s::Vector{T}) where T
function ZScoreTransform(l::Int, dims::Int, m::U, s::U) where {T<:Real, U<:AbstractVector{T}}
lenm = length(m)
lens = length(s)
lenm == l || lenm == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
lens == l || lens == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
new{T}(l, dims, m, s)
new{T, U}(l, dims, m, s)
end
end

Expand Down Expand Up @@ -123,9 +123,8 @@ function fit(::Type{ZScoreTransform}, X::AbstractMatrix{<:Real};
else
throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
end
T = eltype(X)
return ZScoreTransform(l, dims, (center ? vec(m) : zeros(T, 0)),
(scale ? vec(s) : zeros(T, 0)))
return ZScoreTransform(l, dims, (center ? vec(m) : similar(m, 0)),
(scale ? vec(s) : similar(s, 0)))
end

function fit(::Type{ZScoreTransform}, X::AbstractVector{<:Real};
Expand All @@ -134,10 +133,7 @@ function fit(::Type{ZScoreTransform}, X::AbstractVector{<:Real};
throw(DomainError(dims, "fit only accepts dims=1 over a vector. Try fit(t, x, dims=1)."))
end

T = eltype(X)
m, s = mean_and_std(X)
return ZScoreTransform(1, dims, (center ? [m] : zeros(T, 0)),
(scale ? [s] : zeros(T, 0)))
return fit(ZScoreTransform, reshape(X, :, 1); dims=dims, center=center, scale=scale)
end

function transform!(y::AbstractMatrix{<:Real}, t::ZScoreTransform, x::AbstractMatrix{<:Real})
Expand Down Expand Up @@ -207,19 +203,19 @@ end
"""
Unit range normalization
"""
struct UnitRangeTransform{T<:Real} <: AbstractDataTransform
struct UnitRangeTransform{T<:Real, U<:AbstractVector} <: AbstractDataTransform
len::Int
dims::Int
unit::Bool
min::Vector{T}
scale::Vector{T}
min::U
scale::U

function UnitRangeTransform(l::Int, dims::Int, unit::Bool, min::Vector{T}, max::Vector{T}) where {T}
function UnitRangeTransform(l::Int, dims::Int, unit::Bool, min::U, max::U) where {T, U<:AbstractVector{T}}
lenmin = length(min)
lenmax = length(max)
lenmin == l || lenmin == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
lenmax == l || lenmax == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
new{T}(l, dims, unit, min, max)
new{T, U}(l, dims, unit, min, max)
end
end

Expand Down Expand Up @@ -270,45 +266,32 @@ function fit(::Type{UnitRangeTransform}, X::AbstractMatrix{<:Real};
Base.depwarn("fit(t, x) is deprecated: use fit(t, x, dims=2) instead", :fit)
dims = 2
end
if dims == 1
l, tmin, tmax = _compute_extrema(X)
elseif dims == 2
l, tmin, tmax = _compute_extrema(X')
else
throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
end

for i = 1:l
@inbounds tmax[i] = 1 / (tmax[i] - tmin[i])
end
dims ∈ (1, 2) || throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
tmin, tmax = _compute_extrema(X, dims)
@. tmax = 1 / (tmax - tmin)
l = length(tmin)
return UnitRangeTransform(l, dims, unit, tmin, tmax)
end

function _compute_extrema(X::AbstractMatrix{<:Real})
n, l = size(X)
tmin = X[1, :]
tmax = X[1, :]
for j = 1:l
@inbounds for i = 2:n
if X[i, j] < tmin[j]
tmin[j] = X[i, j]
elseif X[i, j] > tmax[j]
tmax[j] = X[i, j]
end
end
function _compute_extrema(X::AbstractMatrix, dims::Integer)
dims == 2 && return _compute_extrema(X', 1)
l = size(X, 2)
tmin = similar(X, l)
tmax = similar(X, l)
for i in 1:l
@inbounds tmin[i], tmax[i] = extrema(@view(X[:, i]))
end
return l, tmin, tmax
return tmin, tmax
end

function fit(::Type{UnitRangeTransform}, X::AbstractVector{<:Real};
dims::Integer=1, unit::Bool=true)
if dims != 1
throw(DomainError(dims, "fit only accept dims=1 over a vector. Try fit(t, x, dims=1)."))
end

l, tmin, tmax = _compute_extrema(reshape(X, :, 1))
tmin, tmax = extrema(X)
tmax = 1 / (tmax - tmin)
return UnitRangeTransform(1, dims, unit, vec(tmin), vec(tmax))
return UnitRangeTransform(1, dims, unit, [tmin], [tmax])
end

function transform!(y::AbstractMatrix{<:Real}, t::UnitRangeTransform, x::AbstractMatrix{<:Real})
Expand Down

0 comments on commit 83ccc37

Please sign in to comment.