Skip to content

Commit deb38b1

Browse files
kshyattamontoison
authored andcommitted
Wrap the Givens rotation methods
1 parent f3b3f8b commit deb38b1

File tree

2 files changed

+127
-1
lines changed

2 files changed

+127
-1
lines changed

lib/cublas/wrappers.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,62 @@ for (fname, fname_64, elty, sty) in ((:cublasSrot_v2, :cublasSrot_v2_64, :Float3
295295
end
296296
end
297297

298+
## rotg
299+
for (fname, elty) in ((:cublasSrotg_v2, :Float32),
300+
(:cublasDrotg_v2, :Float64),
301+
(:cublasCrotg_v2, :ComplexF32),
302+
(:cublasZrotg_v2, :ComplexF64),
303+
)
304+
@eval begin
305+
function rotg!(a::$elty, b::$elty)
306+
c = Ref{real($elty)}(zero(real($elty)))
307+
s = Ref{$elty}(zero($elty))
308+
ref_a = Ref(a)
309+
ref_b = Ref(b)
310+
$fname(handle(), ref_a, ref_b, c, s)
311+
ref_a[], ref_b[], c[], s[]
312+
end
313+
end
314+
end
315+
316+
## rotm
317+
for (fname, fname_64, elty) in ((:cublasSrotm_v2, :cublasSrotm_v2_64, :Float32),
318+
(:cublasDrotm_v2, :cublasDrotm_v2_64, :Float64),
319+
)
320+
@eval begin
321+
function rotm!(n::Integer,
322+
x::StridedCuVecOrDenseMat{$elty},
323+
y::StridedCuVecOrDenseMat{$elty},
324+
param::AbstractVector{$elty})
325+
if CUBLAS.version() >= v"12.0"
326+
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), param)
327+
else
328+
$fname(handle(), n, x, stride(x, 1), y, stride(y, 1), param)
329+
end
330+
x, y
331+
end
332+
end
333+
end
334+
335+
## rotmg
336+
for (fname, elty) in ((:cublasSrotmg_v2, :Float32),
337+
(:cublasDrotmg_v2, :Float64))
338+
@eval begin
339+
function rotmg!(d1::$elty,
340+
d2::$elty,
341+
x::$elty,
342+
y::$elty,
343+
param::AbstractVector{$elty})
344+
ref_d1 = Ref(d1)
345+
ref_d2 = Ref(d2)
346+
ref_x = Ref(x)
347+
ref_y = Ref(y)
348+
$fname(handle(), ref_d1, ref_d2, ref_x, ref_y, param)
349+
ref_d1[], ref_d2[], ref_x[], ref_y[], param
350+
end
351+
end
352+
end
353+
298354
## swap
299355
for (fname, fname_64, elty) in ((:cublasSswap_v2, :cublasSswap_v2_64, :Float32),
300356
(:cublasDswap_v2, :cublasDswap_v2_64, :Float64),

test/libraries/cublas/level1.jl

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using CUDA.CUBLAS
2-
32
using LinearAlgebra
43

54
using BFloat16s
@@ -49,6 +48,77 @@ k = 13
4948
@test testf(reflect!, rand(T, m), rand(T, m), rand(real(T)), rand(real(T)))
5049
@test testf(reflect!, rand(T, m), rand(T, m), rand(real(T)), rand(T))
5150
end
51+
52+
@testset "rotg!" begin
53+
a = rand(T)
54+
b = rand(T)
55+
a_copy = copy(a)
56+
b_copy = copy(b)
57+
a, b, c, s = CUBLAS.rotg!(a, b)
58+
rot = [c s; -conj(s) c] * [a_copy; b_copy]
59+
@test rot [a; 0]
60+
if T <: Real
61+
@test a^2 a_copy^2 + b_copy^2
62+
end
63+
@test c^2 + abs2(s) one(T)
64+
end
65+
66+
if T <: Real
67+
H = rand(T, 2, 2)
68+
@testset "flag $flag" for (flag, flag_H) in ((T(-2), [one(T) zero(T); zero(T) one(T)]),
69+
(-one(T), H),
70+
(zero(T), [one(T) H[1,2]; H[2, 1] one(T)]),
71+
(one(T), [H[1,1] one(T); -one(T) H[2, 2]]),
72+
)
73+
@testset "rotm!" begin
74+
rot_n = 2
75+
x = rand(T, rot_n)
76+
y = rand(T, rot_n)
77+
dx = CuArray(x)
78+
dy = CuArray(y)
79+
dx, dy = CUBLAS.rotm!(rot_n, dx, dy, vcat(flag, H...))
80+
h_x = collect(dx)
81+
h_y = collect(dy)
82+
@test h_x [x[1] * flag_H[1,1] + y[1] * flag_H[1,2]; x[2] * flag_H[1, 1] + y[2] * flag_H[1, 2]]
83+
@test h_y [x[1] * flag_H[2,1] + y[1] * flag_H[2,2]; x[2] * flag_H[2, 1] + y[2] * flag_H[2, 2]]
84+
end
85+
end
86+
@testset "rotmg!" begin
87+
param = zeros(T, 5)
88+
x1 = rand(T)
89+
y1 = rand(T)
90+
d1 = zero(T)
91+
d2 = zero(T)
92+
x1_copy = copy(x1)
93+
y1_copy = copy(y1)
94+
d1, d2, x1, y1, param = CUBLAS.rotmg!(d1, d2, x1, y1, param)
95+
flag = param[1]
96+
H = zeros(T, 2, 2)
97+
if flag == -2
98+
H[1, 1] = one(T)
99+
H[1, 2] = zero(T)
100+
H[2, 1] = zero(T)
101+
H[2, 2] = one(T)
102+
elseif flag == -1
103+
H[1, 1] = param[2]
104+
H[1, 2] = param[3]
105+
H[2, 1] = param[4]
106+
H[2, 2] = param[5]
107+
elseif iszero(flag)
108+
H[1, 1] = one(T)
109+
H[1, 2] = param[3]
110+
H[2, 1] = param[4]
111+
H[2, 2] = one(T)
112+
elseif flag == 1
113+
H[1, 1] = param[2]
114+
H[1, 2] = one(T)
115+
H[2, 1] = -one(T)
116+
H[2, 2] = param[5]
117+
end
118+
out = H * [(d1) * x1_copy; (d2) * y1_copy]
119+
@test out[2] zero(T)
120+
end
121+
end
52122

53123
@testset "swap!" begin
54124
# swap is an extension

0 commit comments

Comments
 (0)