Skip to content

Commit 539a485

Browse files
committed
feat: lapack integration working 🎉
1 parent f6bedb6 commit 539a485

File tree

4 files changed: +18 additions, −13 deletions

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "657c2f62e4cc3dd2b52ff0587f644b64858d7c0b"
12+
ENZYMEXLA_COMMIT = "abd48236a7466bd0db7d2c8ffa56d1d724463c91"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

src/Compiler.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module Compiler
22

33
using Reactant_jll
44
using Libdl: dlsym
5+
using LinearAlgebra: BLAS
56

67
import ..Reactant:
78
Reactant,
@@ -1274,6 +1275,10 @@ function compile_mlir!(
12741275
"canonicalize"
12751276
end
12761277

1278+
blas_int_width = sizeof(BLAS.BlasInt) * 8
1279+
lower_factorization_pass = "lower-factorization{backend=$backend \
1280+
blas_int_width=$blas_int_width}"
1281+
12771282
if optimize === :all
12781283
run_pass_pipeline!(
12791284
mod,
@@ -1292,7 +1297,7 @@ function compile_mlir!(
12921297
"remove-unnecessary-enzyme-ops",
12931298
"enzyme-simplify-math",
12941299
opt_passes2,
1295-
"lower-factorization{backend=$backend}",
1300+
lower_factorization_pass,
12961301
jit,
12971302
]
12981303
else
@@ -1309,7 +1314,7 @@ function compile_mlir!(
13091314
opt_passes2,
13101315
kern,
13111316
raise_passes,
1312-
"lower-factorization{backend=$backend}",
1317+
lower_factorization_pass,
13131318
jit,
13141319
]
13151320
end,
@@ -1456,7 +1461,7 @@ function compile_mlir!(
14561461
"remove-unnecessary-enzyme-ops",
14571462
"enzyme-simplify-math",
14581463
opt_passes2,
1459-
"lower-factorization{backend=$backend}",
1464+
lower_factorization_pass,
14601465
jit,
14611466
]
14621467
else
@@ -1470,7 +1475,7 @@ function compile_mlir!(
14701475
opt_passes2,
14711476
kern,
14721477
raise_passes,
1473-
"lower-factorization{backend=$backend}",
1478+
lower_factorization_pass,
14741479
jit,
14751480
]
14761481
end,
@@ -1492,7 +1497,7 @@ function compile_mlir!(
14921497
opt_passes2,
14931498
enzyme_pass,
14941499
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
1495-
"lower-factorization{backend=$backend}",
1500+
lower_factorization_pass,
14961501
jit,
14971502
]
14981503
else
@@ -1505,7 +1510,7 @@ function compile_mlir!(
15051510
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
15061511
kern,
15071512
raise_passes,
1508-
"lower-factorization{backend=$backend}",
1513+
lower_factorization_pass,
15091514
jit,
15101515
]
15111516
end,

src/mlir/Dialects/EnzymeXLA.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ function kernel_call(
188188
)
189189
end
190190

191-
function lu_factorization(
191+
function linalg_lu(
192192
input::Value; output::IR.Type, pivots::IR.Type, info::IR.Type, location=Location()
193193
)
194194
op_ty_results = IR.Type[output, pivots, info]
@@ -198,7 +198,7 @@ function lu_factorization(
198198
attributes = NamedAttribute[]
199199

200200
return create_operation(
201-
"enzymexla.lu_factorization",
201+
"enzymexla.linalg.lu",
202202
location;
203203
operands,
204204
owned_regions,

src/stdlibs/LinearAlgebra.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,21 +498,21 @@ function LinearAlgebra.lu!(
498498
# TODO: stop ignoring `check` and `allowsingular`
499499
m, n = size(A)
500500

501-
lu_op = MLIR.Dialects.enzymexla.lu_factorization(
501+
lu_op = MLIR.Dialects.enzymexla.linalg_lu(
502502
materialize_traced_array(A).mlir_data;
503503
output=MLIR.IR.TensorType(
504504
collect(Int64, size(A)), MLIR.IR.Type(unwrapped_eltype(A))
505505
),
506-
pivots=MLIR.IR.TensorType(Int64[max(m, n)], MLIR.IR.Type(Int32)),
506+
pivots=MLIR.IR.TensorType(Int64[min(m, n)], MLIR.IR.Type(Int32)),
507507
info=MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Int32)),
508508
)
509509
set_mlir_data!(A, MLIR.IR.result(lu_op, 1))
510-
pivots = TracedRArray{Int32,1}((), MLIR.IR.result(lu_op, 2), (max(m, n),))
510+
pivots = TracedRArray{Int32,1}((), MLIR.IR.result(lu_op, 2), (min(m, n),))
511511
info = TracedRNumber{Int32}((), MLIR.IR.result(lu_op, 3))
512512

513513
# XXX: `info` needs to be a BLASInt for LU
514514
# return LinearAlgebra.LU{T, typeof(A), typeof(pivots)}(A, pivots, -1)
515-
return (A, pivots, info)
515+
return (UnitLowerTriangular(A), UpperTriangular(A), pivots, info)
516516
end
517517

518518
end

Comments (0)