Skip to content

Commit 2dc75bb

Browse files
wsmosesgiordano
authored andcommitted
Add backend tls
1 parent 0f2b278 commit 2dc75bb

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ function ka_with_reactant(ndrange, workgroupsize, obj, args...)
352352

353353
# figure out the optimal workgroupsize automatically
354354
if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
355-
if !Reactant.Compiler.PartitionKA[] || raising()
355+
if !Reactant.Compiler.PartitionKA[] || raising() || backend() in ("cpu", "tpu")
356356
threads = prod(ndrange)
357357
else
358358
config = CUDA.launch_configuration(kernel.fun; max_threads=prod(ndrange))

src/Compiler.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,38 @@ function raising!(f, is_raising::Bool)
700700
end
701701
end
702702

703+
function activate_backend!(backend::String)
704+
stack = get!(task_local_storage(), :reactant_backend) do
705+
String[]
706+
end
707+
push!(stack, backend)
708+
return nothing
709+
end
710+
711+
function deactivate_backend!(backend::String)
712+
key = :reactant_backend
713+
backend === last(task_local_storage(key)) ||
714+
error("Deactivating wrong Reactant backend context")
715+
return pop!(task_local_storage(key))
716+
end
717+
718+
function backend(; throw_error::Bool=true)
719+
key = :reactant_backend
720+
if !(haskey(task_local_storage(), key) && !Base.isempty(task_local_storage(key)))
721+
throw_error && error("No Reactant backend context")
722+
end
723+
return last(task_local_storage(key)::Vector{Bool})
724+
end
725+
726+
function backend!(f, backend::String)
727+
activate_backend!(backend)
728+
try
729+
return f()
730+
finally
731+
deactivate_backend!(backend)
732+
end
733+
end
734+
703735
function compile_mlir!(
704736
mod,
705737
f,
@@ -747,12 +779,14 @@ function compile_mlir!(
747779
end
748780
is_raising = raise isa String || raise
749781
activate_raising!(is_raising)
782+
activate_backend!(backend)
750783

751784
mlir_fn_res = try
752785
Reactant.TracedUtils.make_mlir_fn(
753786
f, args, fn_kwargs, "main", true; input_shardings, runtime
754787
)
755788
finally
789+
deactivate_backend!(backend)
756790
deactivate_raising!(is_raising)
757791
deactivate_sdycache!(sdycache)
758792
deactivate_callcache!(callcache)

0 commit comments

Comments
 (0)