Skip to content

Commit

Permalink
add type asserts for MA.operate! return values
Browse files Browse the repository at this point in the history
This helps ensure the MutableArithmetics contract is respected.
Specifically, `MA.operate!` should always return its first argument.

Why this instead of performing a stronger check with `===`? Because:

* That would be more verbose
* The `===` would sooner have a run time cost than a type assertion
  • Loading branch information
nsajko committed Nov 26, 2023
1 parent fcb1b29 commit d7df123
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
18 changes: 9 additions & 9 deletions src/default_polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ function MA.operate!(
p::Polynomial,
q::Union{AbstractTermLike,Polynomial},
)
return _polynomial_merge!(op, p, q)
return _polynomial_merge!(op, p, q)::typeof(p)
end

function MA.operate!(
Expand All @@ -554,7 +554,7 @@ function MA.operate!(
q::Polynomial,
args::AbstractTermLike...,
)
return _polynomial_merge!(op, p, q, args...)
return _polynomial_merge!(op, p, q, args...)::typeof(p)
end

function MA.operate!(
Expand All @@ -563,7 +563,7 @@ function MA.operate!(
t::AbstractTermLike,
q::Polynomial,
)
return _polynomial_merge!(op, p, t, q)
return _polynomial_merge!(op, p, t, q)::typeof(p)
end

function MA.buffer_for(
Expand All @@ -581,7 +581,7 @@ function MA.buffered_operate!(
t::AbstractTermLike,
q::Polynomial,
)
return _polynomial_merge!(op, p, t, q, buffer)
return _polynomial_merge!(op, p, t, q, buffer)::typeof(p)
end

function MA.operate_to!(
Expand All @@ -596,13 +596,13 @@ function MA.operate_to!(
return output
end
function MA.operate!(::typeof(*), p::Polynomial, q::Polynomial)
if iszero(q)
return MA.operate!(zero, p)
return if iszero(q)
MA.operate!(zero, p)
elseif nterms(q) == 1
return MA.operate!(*, p, leading_term(q))
MA.operate!(*, p, leading_term(q))
else
return MA.operate_to!(p, *, MA.mutable_copy(p), q)
end
MA.operate_to!(p, *, MA.mutable_copy(p), q)
end::typeof(p)
end
function MA.operate!(::typeof(*), p::Polynomial, t::AbstractTermLike)
for i in eachindex(p.terms)
Expand Down
6 changes: 3 additions & 3 deletions src/division.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ function MA.operate!(
g::_APL,
algo,
)
return MA.buffered_operate!(nothing, op, f, g, algo)
return MA.buffered_operate!(nothing, op, f, g, algo)::typeof(f)
end

# TODO As suggested in [Knu14, Algorithm R, p. 426] (univariate case only), if
Expand Down Expand Up @@ -305,7 +305,7 @@ function MA.operate!(
f,
g,
algo,
)
)::typeof(f)
end

function MA.buffered_operate!(
Expand All @@ -321,7 +321,7 @@ function MA.buffered_operate!(
f,
g,
algo,
)
)::typeof(f)
end

function MA.buffer_for(
Expand Down
20 changes: 10 additions & 10 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ right_constant_function(::typeof(+)) = right_constant_plus
right_constant_function(::typeof(-)) = right_constant_minus
right_constant_function(::typeof(*)) = right_constant_mult
function MA.operate!(op::Union{typeof(+),typeof(-),typeof(*)}, p::_APL, α)
return MA.operate!(right_constant_function(op), p, α)
return MA.operate!(right_constant_function(op), p, α)::typeof(p)
end

MA.operate!(op::typeof(*), α, p::_APL) = MA.operate!(left_constant_mult, α, p)
MA.operate!(op::typeof(*), p::_APL, α) = MA.operate!(right_constant_mult, p, α)
MA.operate!(op::typeof(/), p::_APL, α) = map_coefficients!(Base.Fix2(op, α), p)
MA.operate!(op::typeof(*), α, p::_APL) = MA.operate!(left_constant_mult, α, p)::typeof(α)
MA.operate!(op::typeof(*), p::_APL, α) = MA.operate!(right_constant_mult, p, α)::typeof(p)
MA.operate!(op::typeof(/), p::_APL, α) = map_coefficients!(Base.Fix2(op, α), p)::typeof(p)
function MA.operate_to!(output::AbstractPolynomial, op::typeof(*), α, p::_APL)
return MA.operate_to!(output, left_constant_mult, α, p)
end
Expand Down Expand Up @@ -240,7 +240,7 @@ function MA.operate!(
p::AbstractPolynomial,
q::AbstractPolynomialLike,
)
return MA.operate!(op, p, polynomial(q))
return MA.operate!(op, p, polynomial(q))::typeof(p)
end

function mul_to_terms!(ts::Vector{<:AbstractTerm}, p1::_APL, p2::_APL)
Expand Down Expand Up @@ -327,7 +327,7 @@ function MA.operate!(::typeof(left_constant_mult), α, p::_APL)
return map_coefficients!(Base.Fix1(*, α), p)
end
function MA.operate!(::typeof(right_constant_mult), p::_APL, α)
return map_coefficients!(Base.Fix2(MA.mul!!, α), p)
return map_coefficients!(Base.Fix2(MA.mul!!, α), p)::typeof(p)
end

function MA.operate_to!(
Expand All @@ -343,7 +343,7 @@ function MA.operate!(
m1::AbstractMonomial,
m2::AbstractMonomialLike,
)
return map_exponents!(+, m1, m2)
return map_exponents!(+, m1, m2)::typeof(m1)
end
function Base.:*(m1::AbstractMonomialLike, m2::AbstractMonomialLike)
return map_exponents(+, m1, m2)
Expand All @@ -364,7 +364,7 @@ function Base.:*(t1::AbstractTermLike, t2::AbstractTermLike)
end

function MA.operate!(::typeof(*), p::_APL, t::AbstractMonomialLike)
return map_exponents!(+, p, t)
return map_exponents!(+, p, t)::typeof(p)
end
Base.:*(p::_APL, t::AbstractMonomialLike) = map_exponents(+, p, t)
Base.:*(t::AbstractTermLike, p::_APL) = polynomial!(map(te -> t * te, terms(p)))
Expand Down Expand Up @@ -515,7 +515,7 @@ function MA.operate!(
z,
args::Vararg{Any,N},
) where {N}
return MA.operate!(MA.add_sub_op(op), x, *(y, z, args...))
return MA.operate!(MA.add_sub_op(op), x, *(y, z, args...))::typeof(x)
end
function MA.buffer_for(
::MA.AddSubMul,
Expand Down Expand Up @@ -545,5 +545,5 @@ function MA.buffered_operate!(
args::Vararg{Any,N},
) where {N}
product = MA.operate_to!!(buffer, *, y, z, args...)
return MA.operate!(MA.add_sub_op(op), x, product)
return MA.operate!(MA.add_sub_op(op), x, product)::typeof(x)
end
2 changes: 1 addition & 1 deletion src/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ function MA.operate!(
t::AbstractTermLike,
)
# `MA.add!` will copy the coefficient of `t` so `Polynomial` redefines this
return MA.add!!(p, t)
return MA.add!!(p, t)::typeof(p)
end

#$(SIGNATURES)
Expand Down

0 comments on commit d7df123

Please sign in to comment.