Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overhaul mutating arithmetics #1784

Merged
merged 18 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AbstractAlgebra"
uuid = "c3fe647b-3220-5bb0-a1ea-a7954cac585d"
version = "0.42.5"
version = "0.42.6"

[deps]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down
7 changes: 3 additions & 4 deletions docs/src/ring.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,22 +167,21 @@ right, respectively, of `a`.

## Unsafe ring operators

To speed up polynomial arithmetic, various unsafe operators are provided, which
To speed up polynomial arithmetic, various unsafe operators are provided, which may
mutate the output rather than create a new object.

```julia
zero!(a::NCRingElement)
mul!(a::T, b::T, c::T) where T <: NCRingElement
add!(a::T, b::T, c::T) where T <: NCRingElement
addeq!(a::T, b::T) where T <: NCRingElement
addmul!(a::T, b::T, c::T, t::T) where T <: NCRingElement
```

In each case the mutated object is the leftmost parameter.

The `addeq!(a, b)` operation does the same thing as `add!(a, a, b)`. The
The `add!(a, b)` operation does the same thing as `add!(a, a, b)`. The
optional `addmul!(a, b, c, t)` operation does the same thing as
`mul!(t, b, c); addeq!(a, t)` where `t` is a temporary which can be mutated so
`mul!(t, b, c); add!(a, t)` where `t` is a temporary which can be mutated so
that an addition allocation is not needed.

## Random generation
Expand Down
13 changes: 1 addition & 12 deletions docs/src/ring_interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,6 @@ add!(c::MyElem, a::MyElem, b::MyElem)

Set $c$ to the value $a + b$ in place. Return the mutated value. Aliasing is permitted.

```julia
addeq!(a::MyElem, b::MyElem)
```

Set $a$ to $a + b$ in place. Return the mutated value. Aliasing is permitted.

### Random generation

The random functions are only used for test code to generate test data. They therefore
Expand Down Expand Up @@ -813,7 +807,7 @@ using Random: Random, SamplerTrivial, GLOBAL_RNG
using RandomExtensions: RandomExtensions, Make2, AbstractRNG

import AbstractAlgebra: parent_type, elem_type, base_ring, base_ring_type, parent, is_domain_type,
is_exact_type, canonical_unit, isequal, divexact, zero!, mul!, add!, addeq!,
is_exact_type, canonical_unit, isequal, divexact, zero!, mul!, add!,
get_cached!, is_unit, characteristic, Ring, RingElem, expressify

import Base: show, +, -, *, ^, ==, inv, isone, iszero, one, zero, rand,
Expand Down Expand Up @@ -980,11 +974,6 @@ function add!(f::ConstPoly{T}, g::ConstPoly{T}, h::ConstPoly{T}) where T <: Ring
return f
end

function addeq!(f::ConstPoly{T}, g::ConstPoly{T}) where T <: RingElement
f.c += g.c
return f
end

# Random generation

RandomExtensions.maketype(R::ConstPolyRing, _) = elem_type(R)
Expand Down
12 changes: 6 additions & 6 deletions src/AbsSeries.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
if lenz > i
for j = 2:min(lenb, lenz - i + 1)
t = mul!(t, coeff(a, i - 1), coeff(b, j - 1))
d[i + j - 1] = addeq!(d[i + j - 1], t)
d[i + j - 1] = add!(d[i + j - 1], t)

Check warning on line 302 in src/AbsSeries.jl

View check run for this annotation

Codecov / codecov/patch

src/AbsSeries.jl#L302

Added line #L302 was not covered by tests
end
end
end
Expand Down Expand Up @@ -740,7 +740,7 @@
x = set_precision!(x, la[n])
y = set_precision!(y, la[n])
y = mul!(y, minus_a, x)
y = addeq!(y, two)
y = add!(y, two)
x = mul!(x, x, y)
n -= 1
end
Expand Down Expand Up @@ -886,13 +886,13 @@
for i = 1:div(n - 1, 2)
j = n - i
p = mul!(p, coeff(asqrt, aval2 + i), coeff(asqrt, aval2 + j))
c = addeq!(c, p)
c = add!(c, p)
end
c *= 2
if (n % 2) == 0
i = div(n, 2)
p = mul!(p, coeff(asqrt, aval2 + i), coeff(asqrt, aval2 + i))
c = addeq!(c, p)
c = add!(c, p)
end
c = coeff(a, n + aval) - c
if check
Expand Down Expand Up @@ -1039,8 +1039,8 @@
x = set_precision!(x, la[n])
one1 = set_precision!(one1, la[n])
t = -log(x)
t = addeq!(t, one1)
t = addeq!(t, a)
t = add!(t, one1)
t = add!(t, a)
x = mul!(x, x, t)
n -= 1
end
Expand Down
3 changes: 3 additions & 0 deletions src/Deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
@alias FreeAssAlgebra FreeAssociativeAlgebra
@alias FreeAssAlgElem FreeAssociativeAlgebraElem

# renamed in 0.42.6
@alias addeq! add!

###############################################################################
#
# Deprecated bindings
Expand Down
10 changes: 5 additions & 5 deletions src/Fraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -831,13 +831,13 @@ function mul!(c::FracElem{T}, a::FracElem{T}, b::FracElem{T}) where {T <: RingEl
return c
end

function addeq!(a::FracElem{T}, b::FracElem{T}) where {T <: RingElem}
function add!(a::FracElem{T}, b::FracElem{T}) where {T <: RingElem}
d1 = denominator(a, false)
d2 = denominator(b, false)
n1 = numerator(a, false)
n2 = numerator(b, false)
if d1 == d2
a.num = addeq!(a.num, b.num)
a.num = add!(a.num, b.num)
if !isone(d1)
gd = gcd(a.num, d1)
if !isone(gd)
Expand All @@ -848,20 +848,20 @@ function addeq!(a::FracElem{T}, b::FracElem{T}) where {T <: RingElem}
elseif isone(d1)
if n1 !== n2
a.num = mul!(a.num, a.num, d2)
a.num = addeq!(a.num, n2)
a.num = add!(a.num, n2)
else
a.num = n1*d2 + n2
end
a.den = deepcopy(d2)
elseif isone(d2)
a.num = addeq!(a.num, n2*d1)
a.num = add!(a.num, n2*d1)
a.den = deepcopy(d1)
else
gd = gcd(d1, d2)
if isone(gd)
if n1 !== n2
a.num = mul!(a.num, a.num, d2)
a.num = addeq!(a.num, n2*d1)
a.num = add!(a.num, n2*d1)
else
a.num = n1*d2 + n2*d1
end
Expand Down
2 changes: 1 addition & 1 deletion src/FreeAssociativeAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ function evaluate(a::FreeAssociativeAlgebraElem{T}, vals::Vector{U}) where {T <:
r = zero(S)
o = one(S)
for (c, v) in zip(coefficients(a), exponent_words(a))
r = addeq!(r, c*prod((vals[i] for i in v), init = o))
r = add!(r, c*prod((vals[i] for i in v), init = o))
end
return r
end
Expand Down
20 changes: 4 additions & 16 deletions src/Groups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,29 +236,17 @@ end
# Mutable API where modifications are recommended for performance reasons
################################################################################

# further mutable functions in fundamental_interface.jl:
# mul!(out::T, g::T, h::T) where {T<:GroupElem} = g * h
# inv!(out::T, g::T) where {T<:GroupElem} = inv(g)

"""
one!(g::GroupElem)

Return `one(g)`, possibly modifying `g`.
"""
one!(g::GroupElem) = one(parent(g))

"""
inv!(out::T, g::T) where {GEl <: GroupElem}

Return `inv(g)`, possibly modifying `out`. Aliasing of `g` with `out` is
allowed.
"""
inv!(out::T, g::T) where {T<:GroupElem} = inv(g)

"""
mul!(out::T, g::T, h::T) where {GEl <: GroupElem}

Return `g*h`, possibly modifying `out`. Aliasing of `g` or `h` with `out` is
allowed.
"""
mul!(out::T, g::T, h::T) where {T<:GroupElem} = g * h

"""
div_right!(out::T, g::T, h::T) where {GEl <: GroupElem}

Expand Down
6 changes: 3 additions & 3 deletions src/MPoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -905,13 +905,13 @@ function evaluate(a::MPolyRingElem{T}, vals::Vector{U}) where {T <: RingElement,
j = i = i + 1
while iseven(j) && length(r) > 1
top = pop!(r)
r[end] = addeq!(r[end], top)
r[end] = add!(r[end], top)
j >>= 1
end
end
while length(r) > 1
top = pop!(r)
r[end] = addeq!(r[end], top)
r[end] = add!(r[end], top)
end
return r[1]
end
Expand Down Expand Up @@ -1037,7 +1037,7 @@ function __evaluate(a, vars, vals, powers)
end
M = Generic.MPolyBuildCtx(S)
push_term!(M, c, v)
addeq!(r, t*finish(M))
add!(r, t*finish(M))
end
return r
end
Expand Down
2 changes: 1 addition & 1 deletion src/MatRing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ function *(x::MatRingElem{T}, y::MatRingElem{T}) where {T <: NCRingElement}
A[i, j] = base_ring(x)()
for k = 1:ncols(x)
C = mul!(C, x[i, k], y[k, j])
A[i, j] = addeq!(A[i, j], C)
A[i, j] = add!(A[i, j], C)
end
end
end
Expand Down
20 changes: 10 additions & 10 deletions src/Matrix-Strassen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@ argument "cutoff" to indicate when the base case should be used.

The speedup depends on the ring and the entry sizes.

#Examples:
# Examples

```jldoctest; setup = :(using AbstractAlgebra)
julia> m = matrix(ZZ, rand(-10:10, 1000, 1000));

julia> n = similar(m);

julia> mul!(n, m, m);
julia> n = mul!(n, m, m);

julia> Strassen.mul!(n, m, m);

julia> Strassen.mul!(n, m, m; cutoff = 100);
julia> n = Strassen.mul!(n, m, m);

julia> n = Strassen.mul!(n, m, m; cutoff = 100);
```
"""
module Strassen
Expand All @@ -49,8 +48,7 @@ function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff)
@assert a == sC[1] && b == sB[1] && c == sC[2]

if (a <= cutoff || b <= cutoff || c <= cutoff)
AbstractAlgebra.mul!(C, A, B)
return
return AbstractAlgebra.mul!(C, A, B)
end

anr = div(a, 2)
Expand Down Expand Up @@ -142,7 +140,7 @@ function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff)
#nmod_mat_window_init(Cc, C, 0, 2*bnc, a, c);
Cc = view(C, 1:a, 2*bnc+1:c)
#nmod_mat_mul(Cc, A, Bc);
AbstractAlgebra.mul!(Cc, A, Bc)
Cc = AbstractAlgebra.mul!(Cc, A, Bc)
end

if a > 2*anr #last row of A by B -> last row of C
Expand All @@ -151,7 +149,7 @@ function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff)
#nmod_mat_window_init(Cr, C, 2*anr, 0, a, c);
Cr = view(C, 2*anr+1:a, 1:c)
#nmod_mat_mul(Cr, Ar, B);
AbstractAlgebra.mul!(Cr, Ar, B)
Cr = AbstractAlgebra.mul!(Cr, Ar, B)
end

if b > 2*anc # last col of A by last row of B -> C
Expand All @@ -162,8 +160,10 @@ function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff)
#nmod_mat_window_init(Cb, C, 0, 0, 2*anr, 2*bnc);
Cb = view(C, 1:2*anr, 1:2*bnc)
#nmod_mat_addmul(Cb, Cb, Ac, Br);
AbstractAlgebra.mul!(Cb, Ac, Br, true)
Cb = AbstractAlgebra.mul!(Cb, Ac, Br)
end

return C
end

#solve_tril fast, recursive
Expand Down
Loading
Loading