From d7df123d5455f4ff76629641ac1a7c9f21dfecfd Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Sun, 26 Nov 2023 10:05:02 +0100 Subject: [PATCH] add type asserts for `MA.operate!` return values This helps ensure the MutableArithmetics contract is respected. Specifically, `MA.operate!` should always return its first argument. Why this instead of performing a stronger check with `===`? Because: * That would be more verbose * The `===` would sooner have a run time cost than a type assertion --- src/default_polynomial.jl | 18 +++++++++--------- src/division.jl | 6 +++--- src/operators.jl | 20 ++++++++++---------- src/polynomial.jl | 2 +- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/default_polynomial.jl b/src/default_polynomial.jl index a7b45bf6..1eb77769 100644 --- a/src/default_polynomial.jl +++ b/src/default_polynomial.jl @@ -545,7 +545,7 @@ function MA.operate!( p::Polynomial, q::Union{AbstractTermLike,Polynomial}, ) - return _polynomial_merge!(op, p, q) + return _polynomial_merge!(op, p, q)::typeof(p) end function MA.operate!( @@ -554,7 +554,7 @@ function MA.operate!( q::Polynomial, args::AbstractTermLike..., ) - return _polynomial_merge!(op, p, q, args...) + return _polynomial_merge!(op, p, q, args...)::typeof(p) end function MA.operate!( @@ -563,7 +563,7 @@ function MA.operate!( t::AbstractTermLike, q::Polynomial, ) - return _polynomial_merge!(op, p, t, q) + return _polynomial_merge!(op, p, t, q)::typeof(p) end function MA.buffer_for( @@ -581,7 +581,7 @@ function MA.buffered_operate!( t::AbstractTermLike, q::Polynomial, ) - return _polynomial_merge!(op, p, t, q, buffer) + return _polynomial_merge!(op, p, t, q, buffer)::typeof(p) end function MA.operate_to!( @@ -596,13 +596,13 @@ function MA.operate_to!( return output end function MA.operate!(::typeof(*), p::Polynomial, q::Polynomial) - if iszero(q) - return MA.operate!(zero, p) + return if iszero(q) + MA.operate!(zero, p) elseif nterms(q) == 1 - return MA.operate!(*, p, leading_term(q)) + MA.operate!(*, p, leading_term(q)) else - return MA.operate_to!(p, *, MA.mutable_copy(p), q) - end + MA.operate_to!(p, *, MA.mutable_copy(p), q) + end::typeof(p) end function MA.operate!(::typeof(*), p::Polynomial, t::AbstractTermLike) for i in eachindex(p.terms) diff --git a/src/division.jl b/src/division.jl index feedeb47..934e073b 100644 --- a/src/division.jl +++ b/src/division.jl @@ -228,7 +228,7 @@ function MA.operate!( g::_APL, algo, ) - return MA.buffered_operate!(nothing, op, f, g, algo) + return MA.buffered_operate!(nothing, op, f, g, algo)::typeof(f) end # TODO As suggested in [Knu14, Algorithm R, p. 426] (univariate case only), if @@ -305,7 +305,7 @@ function MA.operate!( f, g, algo, - ) + )::typeof(f) end function MA.buffered_operate!( @@ -321,7 +321,7 @@ function MA.buffered_operate!( f, g, algo, - ) + )::typeof(f) end function MA.buffer_for( diff --git a/src/operators.jl b/src/operators.jl index 22dbf3d9..eee83943 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -94,12 +94,12 @@ right_constant_function(::typeof(+)) = right_constant_plus right_constant_function(::typeof(-)) = right_constant_minus right_constant_function(::typeof(*)) = right_constant_mult function MA.operate!(op::Union{typeof(+),typeof(-),typeof(*)}, p::_APL, α) - return MA.operate!(right_constant_function(op), p, α) + return MA.operate!(right_constant_function(op), p, α)::typeof(p) end -MA.operate!(op::typeof(*), α, p::_APL) = MA.operate!(left_constant_mult, α, p) -MA.operate!(op::typeof(*), p::_APL, α) = MA.operate!(right_constant_mult, p, α) -MA.operate!(op::typeof(/), p::_APL, α) = map_coefficients!(Base.Fix2(op, α), p) +MA.operate!(op::typeof(*), α, p::_APL) = MA.operate!(left_constant_mult, α, p)::typeof(α) +MA.operate!(op::typeof(*), p::_APL, α) = MA.operate!(right_constant_mult, p, α)::typeof(p) +MA.operate!(op::typeof(/), p::_APL, α) = map_coefficients!(Base.Fix2(op, α), p)::typeof(p) function MA.operate_to!(output::AbstractPolynomial, op::typeof(*), α, p::_APL) return MA.operate_to!(output, left_constant_mult, α, p) end @@ -240,7 +240,7 @@ function MA.operate!( p::AbstractPolynomial, q::AbstractPolynomialLike, ) - return MA.operate!(op, p, polynomial(q)) + return MA.operate!(op, p, polynomial(q))::typeof(p) end function mul_to_terms!(ts::Vector{<:AbstractTerm}, p1::_APL, p2::_APL) @@ -327,7 +327,7 @@ function MA.operate!(::typeof(left_constant_mult), α, p::_APL) return map_coefficients!(Base.Fix1(*, α), p) end function MA.operate!(::typeof(right_constant_mult), p::_APL, α) - return map_coefficients!(Base.Fix2(MA.mul!!, α), p) + return map_coefficients!(Base.Fix2(MA.mul!!, α), p)::typeof(p) end function MA.operate_to!( @@ -343,7 +343,7 @@ function MA.operate!( m1::AbstractMonomial, m2::AbstractMonomialLike, ) - return map_exponents!(+, m1, m2) + return map_exponents!(+, m1, m2)::typeof(m1) end function Base.:*(m1::AbstractMonomialLike, m2::AbstractMonomialLike) return map_exponents(+, m1, m2) @@ -364,7 +364,7 @@ function Base.:*(t1::AbstractTermLike, t2::AbstractTermLike) end function MA.operate!(::typeof(*), p::_APL, t::AbstractMonomialLike) - return map_exponents!(+, p, t) + return map_exponents!(+, p, t)::typeof(p) end Base.:*(p::_APL, t::AbstractMonomialLike) = map_exponents(+, p, t) Base.:*(t::AbstractTermLike, p::_APL) = polynomial!(map(te -> t * te, terms(p))) @@ -515,7 +515,7 @@ function MA.operate!( z, args::Vararg{Any,N}, ) where {N} - return MA.operate!(MA.add_sub_op(op), x, *(y, z, args...)) + return MA.operate!(MA.add_sub_op(op), x, *(y, z, args...))::typeof(x) end function MA.buffer_for( ::MA.AddSubMul, @@ -545,5 +545,5 @@ function MA.buffered_operate!( args::Vararg{Any,N}, ) where {N} product = MA.operate_to!!(buffer, *, y, z, args...) - return MA.operate!(MA.add_sub_op(op), x, product) + return MA.operate!(MA.add_sub_op(op), x, product)::typeof(x) end diff --git a/src/polynomial.jl b/src/polynomial.jl index e3830d0f..44c9061c 100644 --- a/src/polynomial.jl +++ b/src/polynomial.jl @@ -443,7 +443,7 @@ function MA.operate!( t::AbstractTermLike, ) # `MA.add!` will copy the coefficient of `t` so `Polynomial` redefines this - return MA.add!!(p, t) + return MA.add!!(p, t)::typeof(p) end #$(SIGNATURES)