Skip to content

Commit a692b2a

Browse files
arghhhhandreasnoack
authored andcommitted
Removing requirement for promotions from 0,1 in lufact (#22146)
* Removing requirement for promotions from 0,1 * Using iszero() and added test * Tidy up spaces, tabs etc
1 parent 54158fe commit a692b2a

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

base/linalg/generic.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@ true
957957
function istriu(A::AbstractMatrix)
958958
m, n = size(A)
959959
for j = 1:min(n,m-1), i = j+1:m
960-
if A[i,j] != 0
960+
if !iszero(A[i,j])
961961
return false
962962
end
963963
end
@@ -992,7 +992,7 @@ true
992992
function istril(A::AbstractMatrix)
993993
m, n = size(A)
994994
for j = 2:n, i = 1:min(j-1,m)
995-
if A[i,j] != 0
995+
if !iszero(A[i,j])
996996
return false
997997
end
998998
end

base/linalg/lu.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function generic_lufact!(A::StridedMatrix{T}, ::Type{Val{Pivot}} = Val{true}) wh
3939
# find index max
4040
kp = k
4141
if Pivot
42-
amax = real(zero(T))
42+
amax = abs(zero(T))
4343
for i = k:m
4444
absi = abs(A[i,k])
4545
if absi > amax
@@ -49,7 +49,7 @@ function generic_lufact!(A::StridedMatrix{T}, ::Type{Val{Pivot}} = Val{true}) wh
4949
end
5050
end
5151
ipiv[k] = kp
52-
if A[kp,k] != 0
52+
if !iszero(A[kp,k])
5353
if k != kp
5454
# Interchange
5555
for i = 1:n

test/linalg/generic.jl

+30
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,33 @@ end
310310
@test [[1, 2], [3, 4]] [[1.0-eps(), 2.0+eps()], [3.0+2eps(), 4.0-1e8eps()]]
311311
@test [[1, 2], [3, 4]] [[1.0-eps(), 2.0+eps()], [3.0+2eps(), 4.0-1e9eps()]]
312312
@test [[1,2, [3,4]], 5.0, [6im, [7.0, 8.0]]] [[1,2, [3,4]], 5.0, [6im, [7.0, 8.0]]]
313+
314+
# Issue 22042
315+
# Minimal modulo number type - but not subtyping Number
316+
struct ModInt{n}
317+
k
318+
ModInt{n}(k) where {n} = new(mod(k,n))
319+
end
320+
321+
Base.:+(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k + b.k)
322+
Base.:-(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k - b.k)
323+
Base.:*(a::ModInt{n}, b::ModInt{n}) where {n} = ModInt{n}(a.k * b.k)
324+
Base.:-(a::ModInt{n}) where {n} = ModInt{n}(-a.k)
325+
Base.inv(a::ModInt{n}) where {n} = ModInt{n}(invmod(a.k, n))
326+
Base.:/(a::ModInt{n}, b::ModInt{n}) where {n} = a*inv(b)
327+
328+
Base.zero(::Type{ModInt{n}}) where {n} = ModInt{n}(0)
329+
Base.zero(::ModInt{n}) where {n} = ModInt{n}(0)
330+
Base.one(::Type{ModInt{n}}) where {n} = ModInt{n}(1)
331+
Base.one(::ModInt{n}) where {n} = ModInt{n}(1)
332+
333+
# Needed for pivoting:
334+
Base.abs(a::ModInt{n}) where {n} = a
335+
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k
336+
Base.transpose(a::ModInt{n}) where {n} = a # see Issue 20978
337+
338+
A = [ ModInt{2}(1) ModInt{2}(0) ; ModInt{2}(1) ModInt{2}(1) ]
339+
b = [ ModInt{2}(1), ModInt{2}(0) ]
340+
341+
@test A*(A\b) == b
342+
@test_nowarn lufact( A, Val{true} )

0 commit comments

Comments
 (0)