Skip to content

Commit

Permalink
More arithmetic (#35)
Browse files Browse the repository at this point in the history
* More arithmetic

* Fix format

* Fixes

* Fix format

* up
  • Loading branch information
blegat authored Jul 2, 2024
1 parent 4ba5138 commit c041a60
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 30 deletions.
14 changes: 5 additions & 9 deletions src/MultivariateBases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ end
SA.basis(a::Algebra) = a.basis

#Base.:(==)(::Algebra{BT1,B1,M}, ::Algebra{BT2,B2,M}) where {BT1,B1,BT2,B2,M} = true
#Base.:(==)(::Algebra, ::Algebra) = false
function Base.:(==)(a::Algebra, b::Algebra)
# `===` is a shortcut for speedup
return a.basis === b.basis || a.basis == b.basis
end

function Base.show(io::IO, ::Algebra{BT,B}) where {BT,B}
ioc = IOContext(io, :limit => true, :compact => true)
Expand Down Expand Up @@ -72,13 +75,6 @@ function MA.promote_operation(
return Algebra{BT,B,M}
end

const _APL = MP.AbstractPolynomialLike
# We don't define it for all `AlgebraElement` as this would be type piracy
const _AE = SA.AlgebraElement{<:Algebra}

Base.:(+)(p::_APL, q::_AE) = +(p, MP.polynomial(q))
Base.:(+)(p::_AE, q::_APL) = +(MP.polynomial(p), q)
Base.:(-)(p::_APL, q::_AE) = -(p, MP.polynomial(q))
Base.:(-)(p::_AE, q::_APL) = -(MP.polynomial(p), q)
include("arithmetic.jl")

end # module
27 changes: 27 additions & 0 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
const _APL = MP.AbstractPolynomialLike
# We don't define it for all `AlgebraElement` as this would be type piracy
const _AE = SA.AlgebraElement{<:Algebra}

Base.:(+)(p::_APL, q::_AE) = +(p, MP.polynomial(q))
Base.:(+)(p::_AE, q::_APL) = +(MP.polynomial(p), q)
Base.:(-)(p::_APL, q::_AE) = -(p, MP.polynomial(q))
Base.:(-)(p::_AE, q::_APL) = -(MP.polynomial(p), q)

Base.:(+)(p, q::_AE) = +(constant_algebra_element(typeof(SA.basis(q)), p), q)
function Base.:(+)(p::_AE, q)
return +(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q))
end
function Base.:(-)(p, q::_AE)
return -(constant_algebra_element(typeof(SA.basis(q)), p), MP.polynomial(q))
end
function Base.:(-)(p::_AE, q)
return -(MP.polynomial(p), constant_algebra_element(typeof(SA.basis(p)), q))
end

function Base.:(+)(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output(+, p, q), +, p, q)
end

function Base.:(-)(p::_AE, q::_AE)
return MA.operate_to!(SA._preallocate_output(-, p, q), -, p, q)
end
2 changes: 1 addition & 1 deletion src/chebyshev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct ChebyshevFirstKind <: AbstractChebyshev end
const Chebyshev = ChebyshevFirstKind

# https://en.wikipedia.org/wiki/Chebyshev_polynomials#Properties
# T_n * T_m = T_{n + m} / 2 + T_{|n - m|} / 2
# `T_n * T_m = T_{n + m} / 2 + T_{|n - m|} / 2`
function (::Mul{Chebyshev})(a::MP.AbstractMonomial, b::MP.AbstractMonomial)
terms = [MP.term(1 // 1, MP.constant_monomial(a * b))]
vars_a = MP.variables(a)
Expand Down
20 changes: 10 additions & 10 deletions src/monomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,23 +198,23 @@ function algebra_element(f::Function, basis::SubBasis)
return algebra_element(map(f, eachindex(basis)), basis)
end

function constant_algebra_element(
::Type{FullBasis{B,M}},
::Type{T},
) where {B,M,T}
_one_if_type(α) = α
_one_if_type(::Type{T}) where {T} = one(T)

function constant_algebra_element(::Type{FullBasis{B,M}}, α) where {B,M}
return algebra_element(
sparse_coefficients(
MP.polynomial(MP.term(one(T), MP.constant_monomial(M))),
MP.polynomial(MP.term(_one_if_type), MP.constant_monomial(M))),
),
FullBasis{B,M}(),
)
end

function constant_algebra_element(
::Type{<:SubBasis{B,M}},
::Type{T},
) where {B,M,T}
return algebra_element([one(T)], SubBasis{B}([MP.constant_monomial(M)]))
function constant_algebra_element(::Type{<:SubBasis{B,M}}, α) where {B,M}
return algebra_element(
[_one_if_type(α)],
SubBasis{B}([MP.constant_monomial(M)]),
)
end

function _show(io::IO, mime::MIME, basis::SubBasis{B}) where {B}
Expand Down
12 changes: 7 additions & 5 deletions src/orthogonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,15 @@ function SA.coeffs(
end

function SA.coeffs(
p::Polynomial{B},
p::Polynomial{B,M},
::FullBasis{Monomial},
) where {B<:AbstractMultipleOrthogonal}
) where {B<:AbstractMultipleOrthogonal,M}
return sparse_coefficients(
prod(
univariate_orthogonal_basis(B, var, deg)[deg+1] for
(var, deg) in MP.powers(p.monomial)
),
MP.powers(p.monomial);
init = MP.constant_monomial(M),
) do (var, deg)
return univariate_orthogonal_basis(B, var, deg)[deg+1]
end,
)
end
4 changes: 3 additions & 1 deletion src/scaled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ function Base.promote_rule(
end

function scaling(m::MP.AbstractMonomial)
return (factorial(MP.degree(m)) / prod(factorial, MP.exponents(m)))
return (
factorial(MP.degree(m)) / prod(factorial, MP.exponents(m); init = 1),
)
end
unscale_coef(t::MP.AbstractTerm) = MP.coefficient(t) / scaling(MP.monomial(t))
function SA.coeffs(
Expand Down
12 changes: 8 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,14 @@ function api_test(B::Type{<:MB.AbstractMonomialIndexed}, degree)
_wrap(MB.SA.trim_LaTeX(mime, sprint(show, mime, p.monomial))) *
" \$\$"
const_mono = constant_monomial(prod(x))
@test const_mono + MB.algebra_element(MB.Polynomial{B}(const_mono)) == 2
@test MB.algebra_element(MB.Polynomial{B}(const_mono)) + const_mono == 2
@test iszero(const_mono - MB.algebra_element(MB.Polynomial{B}(const_mono)))
@test iszero(MB.algebra_element(MB.Polynomial{B}(const_mono)) - const_mono)
const_poly = MB.Polynomial{B}(const_mono)
const_alg_el = MB.algebra_element(const_poly)
for other in (const_mono, 1, const_alg_el)
@test other + const_alg_el 2 * other
@test const_alg_el + other 2 * other
@test iszero(other - const_alg_el)
@test iszero(const_alg_el - other)
end
@test typeof(MB.sparse_coefficients(sum(x))) ==
MA.promote_operation(MB.sparse_coefficients, typeof(sum(x)))
@test typeof(MB.algebra_element(sum(x))) ==
Expand Down

0 comments on commit c041a60

Please sign in to comment.