Skip to content

Commit ddf79a8

Browse files
authored
Add rot! to BLAS in stdlib/LinearAlgebra, add generic rotate! and reflect! (#35124)
1 parent bf8aae8 commit ddf79a8

File tree

6 files changed

+126
-0
lines changed

6 files changed

+126
-0
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ Standard library changes
115115
* `normalize` now supports multidimensional arrays ([#34239])
116116
* `lq` factorizations can now be used to compute the minimum-norm solution to under-determined systems ([#34350]).
117117
* The BLAS submodule now supports the level-2 BLAS subroutine `spmv!` ([#34320]).
118+
* The BLAS submodule now supports the level-1 BLAS subroutine `rot!` ([#35124]).
119+
* New generic `rotate!(x, y, c, s)` and `reflect!(x, y, c, s)` functions ([#35124]).
118120

119121
#### Markdown
120122

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ export
124124
opnorm,
125125
rank,
126126
rdiv!,
127+
reflect!,
128+
rotate!,
127129
schur,
128130
schur!,
129131
svd,

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export
1717
blascopy!,
1818
dotc,
1919
dotu,
20+
rot!,
2021
scal!,
2122
scal,
2223
nrm2,
@@ -198,6 +199,37 @@ for (fname, elty) in ((:dcopy_,:Float64),
198199
end
199200
end
200201

202+
203+
## rot
204+
205+
"""
206+
rot!(n, X, incx, Y, incy, c, s)
207+
208+
Overwrite `X` with `c*X + s*Y` and `Y` with `-conj(s)*X + c*Y` for the first `n` elements of array `X` with stride `incx` and
209+
first `n` elements of array `Y` with stride `incy`. Returns `X` and `Y`.
210+
211+
!!! compat "Julia 1.5"
212+
`rot!` requires at least Julia 1.5.
213+
"""
214+
function rot! end
215+
216+
for (fname, elty, cty, sty, lib) in ((:drot_, :Float64, :Float64, :Float64, libblas),
217+
(:srot_, :Float32, :Float32, :Float32, libblas),
218+
(:zdrot_, :ComplexF64, :Float64, :Float64, libblas),
219+
(:csrot_, :ComplexF32, :Float32, :Float32, libblas),
220+
(:zrot_, :ComplexF64, :Float64, :ComplexF64, liblapack),
221+
(:crot_, :ComplexF32, :Float32, :ComplexF32, liblapack))
222+
@eval begin
223+
# SUBROUTINE DROT(N,DX,INCX,DY,INCY,C,S)
224+
function rot!(n::Integer, DX::Union{Ptr{$elty},AbstractArray{$elty}}, incx::Integer, DY::Union{Ptr{$elty},AbstractArray{$elty}}, incy::Integer, C::$cty, S::$sty)
225+
ccall((@blasfunc($fname), $lib), Cvoid,
226+
(Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ref{$cty}, Ref{$sty}),
227+
n, DX, incx, DY, incy, C, S)
228+
DX, DY
229+
end
230+
end
231+
end
232+
201233
## scal
202234

203235
"""

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,6 +1416,51 @@ function axpby!(α, x::AbstractArray, β, y::AbstractArray)
14161416
y
14171417
end
14181418

1419+
"""
1420+
rotate!(x, y, c, s)
1421+
1422+
Overwrite `x` with `c*x + s*y` and `y` with `-conj(s)*x + c*y`.
1423+
Returns `x` and `y`.
1424+
1425+
!!! compat "Julia 1.5"
1426+
`rotate!` requires at least Julia 1.5.
1427+
"""
1428+
function rotate!(x::AbstractVector, y::AbstractVector, c, s)
1429+
require_one_based_indexing(x, y)
1430+
n = length(x)
1431+
if n != length(y)
1432+
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
1433+
end
1434+
@inbounds for i = 1:n
1435+
xi, yi = x[i], y[i]
1436+
x[i] = c *xi + s*yi
1437+
y[i] = -conj(s)*xi + c*yi
1438+
end
1439+
return x, y
1440+
end
1441+
1442+
"""
1443+
reflect!(x, y, c, s)
1444+
1445+
Overwrite `x` with `c*x + s*y` and `y` with `conj(s)*x - c*y`.
1446+
Returns `x` and `y`.
1447+
1448+
!!! compat "Julia 1.5"
1449+
`reflect!` requires at least Julia 1.5.
1450+
"""
1451+
function reflect!(x::AbstractVector, y::AbstractVector, c, s)
1452+
require_one_based_indexing(x, y)
1453+
n = length(x)
1454+
if n != length(y)
1455+
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
1456+
end
1457+
@inbounds for i = 1:n
1458+
xi, yi = x[i], y[i]
1459+
x[i] = c *xi + s*yi
1460+
y[i] = conj(s)*xi - c*yi
1461+
end
1462+
return x, y
1463+
end
14191464

14201465
# Elementary reflection similar to LAPACK. The reflector is not Hermitian but
14211466
# ensures that tridiagonalization of Hermitian matrices become real. See lawn72

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,32 @@ Random.seed!(100)
7878
@test BLAS.iamax(z) == argmax(map(x -> abs(real(x)) + abs(imag(x)), z))
7979
end
8080
end
81+
@testset "rot!" begin
82+
if elty <: Real
83+
x = convert(Vector{elty}, randn(n))
84+
y = convert(Vector{elty}, randn(n))
85+
c = rand(elty)
86+
s = rand(elty)
87+
x2 = copy(x)
88+
y2 = copy(y)
89+
BLAS.rot!(n, x, 1, y, 1, c, s)
90+
@test x c*x2 + s*y2
91+
@test y -s*x2 + c*y2
92+
else
93+
x = convert(Vector{elty}, complex.(randn(n),rand(n)))
94+
y = convert(Vector{elty}, complex.(randn(n),rand(n)))
95+
cty = (elty == ComplexF32) ? Float32 : Float64
96+
c = rand(cty)
97+
for sty in [cty, elty]
98+
s = rand(sty)
99+
x2 = copy(x)
100+
y2 = copy(y)
101+
BLAS.rot!(n, x, 1, y, 1, c, s)
102+
@test x c*x2 + s*y2
103+
@test y -conj(s)*x2 + c*y2
104+
end
105+
end
106+
end
81107
@testset "axp(b)y" begin
82108
if elty <: Real
83109
x1 = convert(Vector{elty}, randn(n))

stdlib/LinearAlgebra/test/generic.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,25 @@ end
217217
@test norm(x, 3) cbrt(5^3 +sqrt(5)^3)
218218
end
219219

220+
@testset "rotate! and reflect!" begin
221+
x = rand(ComplexF64, 10)
222+
y = rand(ComplexF64, 10)
223+
c = rand(Float64)
224+
s = rand(ComplexF64)
225+
226+
x2 = copy(x)
227+
y2 = copy(y)
228+
rotate!(x, y, c, s)
229+
@test x c*x2 + s*y2
230+
@test y -conj(s)*x2 + c*y2
231+
232+
x3 = copy(x)
233+
y3 = copy(y)
234+
reflect!(x, y, c, s)
235+
@test x c*x3 + s*y3
236+
@test y conj(s)*x3 - c*y3
237+
end
238+
220239
@testset "LinearAlgebra.axp(b)y! for element type without commutative multiplication" begin
221240
α = [1 2; 3 4]
222241
β = [5 6; 7 8]

0 commit comments

Comments
 (0)