From 67a0774144dde4ac0098c8fc47663d2e0ea526f1 Mon Sep 17 00:00:00 2001 From: Yuto Horikawa Date: Sat, 10 Sep 2022 16:42:41 +0900 Subject: [PATCH] Faster knotvector addition (#279) * improve KnotVector addition speed * add comments * faster *(::Integer, ::KnotVector) * remove unnecessary copy --- src/_EmptyKnotVector.jl | 5 +++++ src/_KnotVector.jl | 44 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/_EmptyKnotVector.jl b/src/_EmptyKnotVector.jl index ab78b7296..364ee38f5 100644 --- a/src/_EmptyKnotVector.jl +++ b/src/_EmptyKnotVector.jl @@ -45,6 +45,11 @@ Base.:+(::EmptyKnotVector{T1}, ::EmptyKnotVector{T2}) where {T1<:Real, T2<:Real} # + swap Base.:+(k1::EmptyKnotVector, k2::AbstractKnotVector) = k2 + k1 +function Base.:*(m::Integer, k::EmptyKnotVector) where T + m < 0 && throw(DomainError(m, "The number to be multiplied must be non-negative.")) + return k +end + function Base.show(io::IO, ::T) where T<:EmptyKnotVector print(io, "$(T)()") end diff --git a/src/_KnotVector.jl b/src/_KnotVector.jl index 79cf7242c..1380286e9 100644 --- a/src/_KnotVector.jl +++ b/src/_KnotVector.jl @@ -84,7 +84,33 @@ julia> k1 + k2 KnotVector([1, 2, 3, 4, 5, 5, 8]) ``` """ -Base.:+(k1::KnotVector{T}, k2::KnotVector{T}) where T = unsafe_knotvector(T,sort!(vcat(k1.vector,k2.vector))) +function Base.:+(k1::KnotVector{T}, k2::KnotVector{T}) where T + v1, v2 = k1.vector, k2.vector + n1, n2 = length(v1), length(v2) + iszero(n1) && return k2 + iszero(n2) && return k1 + n = n1 + n2 + v = Vector{T}(undef, n) + i1 = i2 = 1 + for i in 1:n + if isless(v1[i1], v2[i2]) + v[i] = v1[i1] + i1 += 1 + if i1 > n1 + v[i+1:n] = view(v2, i2:n2) + break + end + else + v[i] = v2[i2] + i2 += 1 + if i2 > n2 + v[i+1:n] = view(v1, i1:n1) + break + end + end + end + return BasicBSpline.unsafe_knotvector(T,v) +end Base.:+(k1::AbstractKnotVector, k2::AbstractKnotVector) = +(promote(k1,k2)...) @doc raw""" @@ -106,15 +132,25 @@ julia> 2 * k KnotVector([1, 1, 2, 2, 2, 2, 5, 5]) ``` """ -function Base.:*(m::Integer, k::AbstractKnotVector) +function Base.:*(m::Integer, k::KnotVector{T}) where T if m == 0 return zero(k) - elseif m > 0 - return sum(k for _ in 1:m) + elseif m == 1 + return k + elseif m > 1 + n = length(k) + v = Vector{T}(undef, m*n) + for i in 1:m + v[i:m:m*(n-1)+i] .= k.vector + end + return BasicBSpline.unsafe_knotvector(T,v) else throw(DomainError(m, "The number to be multiplied must be non-negative.")) end end +function Base.:*(m::Integer, k::AbstractKnotVector) where T + return m*KnotVector(k) +end Base.:*(k::AbstractKnotVector, m::Integer) = m*k Base.in(r::Real, k::AbstractKnotVector) = in(r, _vec(k))