@@ -700,6 +700,38 @@ function raising!(f, is_raising::Bool)
700
700
end
701
701
end
702
702
703
+ function activate_backend! (backend:: String )
704
+ stack = get! (task_local_storage (), :reactant_backend ) do
705
+ String[]
706
+ end
707
+ push! (stack, backend)
708
+ return nothing
709
+ end
710
+
711
+ function deactivate_backend! (backend:: String )
712
+ key = :reactant_backend
713
+ backend === last (task_local_storage (key)) ||
714
+ error (" Deactivating wrong Reactant backend context" )
715
+ return pop! (task_local_storage (key))
716
+ end
717
+
718
+ function backend (; throw_error:: Bool = true )
719
+ key = :reactant_backend
720
+ if ! (haskey (task_local_storage (), key) && ! Base. isempty (task_local_storage (key)))
721
+ throw_error && error (" No Reactant backend context" )
722
+ end
723
+ return last (task_local_storage (key):: Vector{Bool} )
724
+ end
725
+
726
+ function backend! (f, backend:: String )
727
+ activate_backend! (backend)
728
+ try
729
+ return f ()
730
+ finally
731
+ deactivate_backend! (backend)
732
+ end
733
+ end
734
+
703
735
function compile_mlir! (
704
736
mod,
705
737
f,
@@ -747,12 +779,14 @@ function compile_mlir!(
747
779
end
748
780
is_raising = raise isa String || raise
749
781
activate_raising! (is_raising)
782
+ activate_backend! (backend)
750
783
751
784
mlir_fn_res = try
752
785
Reactant. TracedUtils. make_mlir_fn (
753
786
f, args, fn_kwargs, " main" , true ; input_shardings, runtime
754
787
)
755
788
finally
789
+ deactivate_backend! (backend)
756
790
deactivate_raising! (is_raising)
757
791
deactivate_sdycache! (sdycache)
758
792
deactivate_callcache! (callcache)
0 commit comments