Skip to content

Commit

Permalink
Faster knotvector addition (#279)
Browse files Browse the repository at this point in the history
* improve KnotVector addition speed

* add comments

* faster *(::Integer, ::KnotVector)

* remove unnecessary copy
  • Loading branch information
hyrodium authored Sep 10, 2022
1 parent 1ea42cc commit 67a0774
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/_EmptyKnotVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ Base.:+(::EmptyKnotVector{T1}, ::EmptyKnotVector{T2}) where {T1<:Real, T2<:Real}
# + swap
Base.:+(k1::EmptyKnotVector, k2::AbstractKnotVector) = k2 + k1

function Base.:*(m::Integer, k::EmptyKnotVector) where T
m < 0 && throw(DomainError(m, "The number to be multiplied must be non-negative."))
return k
end

function Base.show(io::IO, ::T) where T<:EmptyKnotVector
print(io, "$(T)()")
end
Expand Down
44 changes: 40 additions & 4 deletions src/_KnotVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,33 @@ julia> k1 + k2
KnotVector([1, 2, 3, 4, 5, 5, 8])
```
"""
Base.:+(k1::KnotVector{T}, k2::KnotVector{T}) where T = unsafe_knotvector(T,sort!(vcat(k1.vector,k2.vector)))
function Base.:+(k1::KnotVector{T}, k2::KnotVector{T}) where T
v1, v2 = k1.vector, k2.vector
n1, n2 = length(v1), length(v2)
iszero(n1) && return k2
iszero(n2) && return k1
n = n1 + n2
v = Vector{T}(undef, n)
i1 = i2 = 1
for i in 1:n
if isless(v1[i1], v2[i2])
v[i] = v1[i1]
i1 += 1
if i1 > n1
v[i+1:n] = view(v2, i2:n2)
break
end
else
v[i] = v2[i2]
i2 += 1
if i2 > n2
v[i+1:n] = view(v1, i1:n1)
break
end
end
end
return BasicBSpline.unsafe_knotvector(T,v)
end
Base.:+(k1::AbstractKnotVector, k2::AbstractKnotVector) = +(promote(k1,k2)...)

@doc raw"""
Expand All @@ -106,15 +132,25 @@ julia> 2 * k
KnotVector([1, 1, 2, 2, 2, 2, 5, 5])
```
"""
function Base.:*(m::Integer, k::AbstractKnotVector)
function Base.:*(m::Integer, k::KnotVector{T}) where T
if m == 0
return zero(k)
elseif m > 0
return sum(k for _ in 1:m)
elseif m == 1
return k
elseif m > 1
n = length(k)
v = Vector{T}(undef, m*n)
for i in 1:m
v[i:m:m*(n-1)+i] .= k.vector
end
return BasicBSpline.unsafe_knotvector(T,v)
else
throw(DomainError(m, "The number to be multiplied must be non-negative."))
end
end
function Base.:*(m::Integer, k::AbstractKnotVector) where T
return m*KnotVector(k)
end
Base.:*(k::AbstractKnotVector, m::Integer) = m*k

Base.in(r::Real, k::AbstractKnotVector) = in(r, _vec(k))
Expand Down

0 comments on commit 67a0774

Please sign in to comment.