@@ -104,53 +104,58 @@ function update_global_state!(args...; kwargs...)
104
104
end
105
105
106
106
function __init__ ()
107
- # This must be the very first thing initialized (otherwise we can't throw errors)
108
- errptr = cglobal ((:ReactantThrowError , MLIR. API. mlir_c), Ptr{Ptr{Cvoid}})
109
- unsafe_store! (errptr, @cfunction (reactant_err, Cvoid, (Cstring,)))
110
-
111
- initLogs = Libdl. dlsym (Reactant_jll. libReactantExtra_handle, " InitializeLogs" )
112
- ccall (initLogs, Cvoid, ())
113
- # Add most log level
114
- # SetLogLevel(0)
115
-
116
- if haskey (ENV , " XLA_REACTANT_GPU_MEM_FRACTION" )
117
- XLA_REACTANT_GPU_MEM_FRACTION[] = parse (
118
- Float64, ENV [" XLA_REACTANT_GPU_MEM_FRACTION" ]
119
- )
120
- @debug " XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[]
121
- if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0
122
- error (" XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1" )
107
+ if Reactant_jll. is_available ()
108
+ # This must be the very first thing initialized (otherwise we can't throw errors)
109
+ errptr = cglobal ((:ReactantThrowError , MLIR. API. mlir_c), Ptr{Ptr{Cvoid}})
110
+ unsafe_store! (errptr, @cfunction (reactant_err, Cvoid, (Cstring,)))
111
+
112
+ initLogs = Libdl. dlsym (Reactant_jll. libReactantExtra_handle, " InitializeLogs" )
113
+ ccall (initLogs, Cvoid, ())
114
+ # Add most log level
115
+ # SetLogLevel(0)
116
+
117
+ if haskey (ENV , " XLA_REACTANT_GPU_MEM_FRACTION" )
118
+ XLA_REACTANT_GPU_MEM_FRACTION[] = parse (
119
+ Float64, ENV [" XLA_REACTANT_GPU_MEM_FRACTION" ]
120
+ )
121
+ @debug " XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[]
122
+ if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0
123
+ error (" XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1" )
124
+ end
123
125
end
124
- end
125
126
126
- if haskey (ENV , " XLA_REACTANT_GPU_PREALLOCATE" )
127
- XLA_REACTANT_GPU_PREALLOCATE[] = parse (Bool, ENV [" XLA_REACTANT_GPU_PREALLOCATE" ])
128
- @debug " XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
129
- end
127
+ if haskey (ENV , " XLA_REACTANT_GPU_PREALLOCATE" )
128
+ XLA_REACTANT_GPU_PREALLOCATE[] = parse (
129
+ Bool, ENV [" XLA_REACTANT_GPU_PREALLOCATE" ]
130
+ )
131
+ @debug " XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
132
+ end
130
133
131
- if haskey (ENV , " REACTANT_VISIBLE_GPU_DEVICES" )
132
- global_state. local_gpu_device_ids =
133
- parse .(Int, split (ENV [" REACTANT_VISIBLE_GPU_DEVICES" ], " ," ))
134
- @debug " REACTANT_VISIBLE_GPU_DEVICES: " global_state. local_gpu_device_ids
135
- end
134
+ if haskey (ENV , " REACTANT_VISIBLE_GPU_DEVICES" )
135
+ global_state. local_gpu_device_ids =
136
+ parse .(Int, split (ENV [" REACTANT_VISIBLE_GPU_DEVICES" ], " ," ))
137
+ @debug " REACTANT_VISIBLE_GPU_DEVICES: " global_state. local_gpu_device_ids
138
+ end
136
139
137
- @debug " REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME
140
+ @debug " REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME
138
141
139
- @ccall MLIR. API. mlir_c. RegisterEnzymeXLACPUHandler ():: Cvoid
140
- @ccall MLIR. API. mlir_c. RegisterEnzymeXLAGPUHandler ():: Cvoid
142
+ @ccall MLIR. API. mlir_c. RegisterEnzymeXLACPUHandler ():: Cvoid
143
+ @ccall MLIR. API. mlir_c. RegisterEnzymeXLAGPUHandler ():: Cvoid
141
144
142
- @static if ! Sys. isapple ()
143
- lljit = Enzyme. LLVM. JuliaOJIT ()
144
- jd_main = Enzyme. LLVM. JITDylib (lljit)
145
+ @static if ! Sys. isapple ()
146
+ lljit = Enzyme. LLVM. JuliaOJIT ()
147
+ jd_main = Enzyme. LLVM. JITDylib (lljit)
145
148
146
- for name in (" XLAExecute" , " XLAExecuteSharded" , " ifrt_loaded_executable_execute" )
147
- ptr = Libdl. dlsym (Reactant_jll. libReactantExtra_handle, name)
148
- Enzyme. LLVM. define (
149
- jd_main,
150
- Enzyme. Compiler. JIT. absolute_symbol_materialization (
151
- Enzyme. LLVM. mangle (lljit, name), ptr
152
- ),
153
- )
149
+ for name in
150
+ (" XLAExecute" , " XLAExecuteSharded" , " ifrt_loaded_executable_execute" )
151
+ ptr = Libdl. dlsym (Reactant_jll. libReactantExtra_handle, name)
152
+ Enzyme. LLVM. define (
153
+ jd_main,
154
+ Enzyme. Compiler. JIT. absolute_symbol_materialization (
155
+ Enzyme. LLVM. mangle (lljit, name), ptr
156
+ ),
157
+ )
158
+ end
154
159
end
155
160
end
156
161
0 commit comments