Skip to content

Commit 9f6cf21

Browse files
authored
Make Reactant loadable when the JLL isn't available (#1003)
* Make Reactant loadable when the JLL isn't available * Make formatter happy
1 parent f5c23a0 commit 9f6cf21

File tree

5 files changed

+79
-64
lines changed

5 files changed

+79
-64
lines changed

deps/build_local.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,6 @@ using Preferences
194194
set_preferences!(
195195
joinpath(dirname(@__DIR__), "LocalPreferences.toml"),
196196
"Reactant_jll",
197-
"libReactantExtra_path" => lib_path,
197+
"libReactantExtra_path" => lib_path;
198198
force=true,
199199
)

src/Precompile.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,29 +57,31 @@ function precompiling()
5757
return (@ccall jl_generating_output()::Cint) == 1
5858
end
5959

60-
@setup_workload begin
61-
initialize_dialect()
60+
if Reactant_jll.is_available()
61+
@setup_workload begin
62+
initialize_dialect()
6263

63-
if XLA.REACTANT_XLA_RUNTIME == "PJRT"
64-
client = XLA.PJRT.CPUClient(; checkcount=false)
65-
elseif XLA.REACTANT_XLA_RUNTIME == "IFRT"
66-
client = XLA.IFRT.CPUClient(; checkcount=false)
67-
else
68-
error("Unsupported runtime: $(XLA.REACTANT_XLA_RUNTIME)")
69-
end
64+
if XLA.REACTANT_XLA_RUNTIME == "PJRT"
65+
client = XLA.PJRT.CPUClient(; checkcount=false)
66+
elseif XLA.REACTANT_XLA_RUNTIME == "IFRT"
67+
client = XLA.IFRT.CPUClient(; checkcount=false)
68+
else
69+
error("Unsupported runtime: $(XLA.REACTANT_XLA_RUNTIME)")
70+
end
7071

71-
@compile_workload begin
72-
@static if precompilation_supported()
73-
x = ConcreteRNumber(2.0; client)
74-
Reactant.compile(sin, (x,); client, optimize=:all)
72+
@compile_workload begin
73+
@static if precompilation_supported()
74+
x = ConcreteRNumber(2.0; client)
75+
Reactant.compile(sin, (x,); client, optimize=:all)
7576

76-
y = ConcreteRArray([2.0]; client)
77-
Reactant.compile(Base.sum, (y,); client, optimize=:all)
77+
y = ConcreteRArray([2.0]; client)
78+
Reactant.compile(Base.sum, (y,); client, optimize=:all)
79+
end
7880
end
79-
end
8081

81-
XLA.free_client(client)
82-
client.client = C_NULL
83-
deinitialize_dialect()
84-
clear_oc_cache()
82+
XLA.free_client(client)
83+
client.client = C_NULL
84+
deinitialize_dialect()
85+
clear_oc_cache()
86+
end
8587
end

src/Reactant.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,12 @@ function initialize_ptrs()
252252
end
253253

254254
function __init__()
255-
initialize_ptrs()
256-
initialize_dialect()
255+
if Reactant_jll.is_available()
256+
initialize_ptrs()
257+
initialize_dialect()
258+
else
259+
@warn "Reactant_jll isn't availble for your platform $(Reactant_jll.host_platform)"
260+
end
257261
return nothing
258262
end
259263

src/mlir/MLIR.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ module API
77
using Preferences
88
using Reactant_jll
99

10-
const mlir_c = Reactant_jll.libReactantExtra
10+
const mlir_c = if Reactant_jll.is_available()
11+
Reactant_jll.libReactantExtra
12+
else
13+
""
14+
end
1115

1216
# MLIR C API
1317
let

src/xla/XLA.jl

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -104,53 +104,58 @@ function update_global_state!(args...; kwargs...)
104104
end
105105

106106
function __init__()
107-
# This must be the very first thing initialized (otherwise we can't throw errors)
108-
errptr = cglobal((:ReactantThrowError, MLIR.API.mlir_c), Ptr{Ptr{Cvoid}})
109-
unsafe_store!(errptr, @cfunction(reactant_err, Cvoid, (Cstring,)))
110-
111-
initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs")
112-
ccall(initLogs, Cvoid, ())
113-
# Add most log level
114-
# SetLogLevel(0)
115-
116-
if haskey(ENV, "XLA_REACTANT_GPU_MEM_FRACTION")
117-
XLA_REACTANT_GPU_MEM_FRACTION[] = parse(
118-
Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"]
119-
)
120-
@debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[]
121-
if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0
122-
error("XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1")
107+
if Reactant_jll.is_available()
108+
# This must be the very first thing initialized (otherwise we can't throw errors)
109+
errptr = cglobal((:ReactantThrowError, MLIR.API.mlir_c), Ptr{Ptr{Cvoid}})
110+
unsafe_store!(errptr, @cfunction(reactant_err, Cvoid, (Cstring,)))
111+
112+
initLogs = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "InitializeLogs")
113+
ccall(initLogs, Cvoid, ())
114+
# Add most log level
115+
# SetLogLevel(0)
116+
117+
if haskey(ENV, "XLA_REACTANT_GPU_MEM_FRACTION")
118+
XLA_REACTANT_GPU_MEM_FRACTION[] = parse(
119+
Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"]
120+
)
121+
@debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[]
122+
if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0
123+
error("XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1")
124+
end
123125
end
124-
end
125126

126-
if haskey(ENV, "XLA_REACTANT_GPU_PREALLOCATE")
127-
XLA_REACTANT_GPU_PREALLOCATE[] = parse(Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"])
128-
@debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
129-
end
127+
if haskey(ENV, "XLA_REACTANT_GPU_PREALLOCATE")
128+
XLA_REACTANT_GPU_PREALLOCATE[] = parse(
129+
Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"]
130+
)
131+
@debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
132+
end
130133

131-
if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES")
132-
global_state.local_gpu_device_ids =
133-
parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ","))
134-
@debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids
135-
end
134+
if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES")
135+
global_state.local_gpu_device_ids =
136+
parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ","))
137+
@debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids
138+
end
136139

137-
@debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME
140+
@debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME
138141

139-
@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
140-
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid
142+
@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
143+
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid
141144

142-
@static if !Sys.isapple()
143-
lljit = Enzyme.LLVM.JuliaOJIT()
144-
jd_main = Enzyme.LLVM.JITDylib(lljit)
145+
@static if !Sys.isapple()
146+
lljit = Enzyme.LLVM.JuliaOJIT()
147+
jd_main = Enzyme.LLVM.JITDylib(lljit)
145148

146-
for name in ("XLAExecute", "XLAExecuteSharded", "ifrt_loaded_executable_execute")
147-
ptr = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, name)
148-
Enzyme.LLVM.define(
149-
jd_main,
150-
Enzyme.Compiler.JIT.absolute_symbol_materialization(
151-
Enzyme.LLVM.mangle(lljit, name), ptr
152-
),
153-
)
149+
for name in
150+
("XLAExecute", "XLAExecuteSharded", "ifrt_loaded_executable_execute")
151+
ptr = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, name)
152+
Enzyme.LLVM.define(
153+
jd_main,
154+
Enzyme.Compiler.JIT.absolute_symbol_materialization(
155+
Enzyme.LLVM.mangle(lljit, name), ptr
156+
),
157+
)
158+
end
154159
end
155160
end
156161

0 commit comments

Comments
 (0)