From d6fdbbb75206496b71912b68ffa5e7d75db1d0e7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 23 Mar 2025 00:39:29 +0000 Subject: [PATCH 1/3] fix creation of traced values --- src/TracedUtils.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index a7f4d4997d..1bb9479d38 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -312,23 +312,29 @@ function make_mlir_fn( seen_results = OrderedIdDict() - traced_result = Reactant.make_tracer( - seen_results, - result, - (:result,), - concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath; - runtime, - ) - - # marks buffers to be donated - for i in 1:N - Reactant.make_tracer( + MLIR.IR.activate!(fnbody) + traced_result = try + traced_result = Reactant.make_tracer( seen_results, - traced_args[i], - concretein ? (:resargs, i) : (), - Reactant.NoStopTracedTrack; + result, + (:result,), + concretein ? Reactant.NoStopTracedTrack : Reactant.TracedSetPath; runtime, ) + + # marks buffers to be donated + for i in 1:N + Reactant.make_tracer( + seen_results, + traced_args[i], + concretein ? (:resargs, i) : (), + Reactant.NoStopTracedTrack; + runtime, + ) + end + traced_result + finally + MLIR.IR.deactivate!(fnbody) end linear_results = Reactant.TracedType[] From fcbe709010b6fbf938375cd8964a7f78593a0681 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 23 Mar 2025 01:45:29 +0000 Subject: [PATCH 2/3] make function before creating the constants in args --- src/TracedUtils.jl | 47 ++++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 1bb9479d38..28e590d32d 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -211,18 +211,36 @@ function make_mlir_fn( num_partitions, num_replicas = 1, 1 + func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name * "_tmp", + function_type=MLIR.IR.FunctionType(in_tys, []), + body=MLIR.IR.Region(), + ) + end + + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) + push!(MLIR.IR.region(func, 1), fnbody) + Ops.activate_constant_context!(fnbody) + N = length(args) seen_args = OrderedIdDict() traced_args = Vector{Any}(undef, N) - for i in 1:N - @inbounds traced_args[i] = Reactant.make_tracer( - seen_args, - args[i], - (:args, i), - concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; - toscalar, - runtime, - ) + + try + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( + seen_args, + args[i], + (:args, i), + concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; + toscalar, + runtime, + ) + end + finally + MLIR.IR.deactivate!(fnbody) + Ops.deactivate_constant_context!(fnbody) end linear_args = Reactant.TracedType[] @@ -264,17 +282,6 @@ function make_mlir_fn( end end - func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name * "_tmp", - function_type=MLIR.IR.FunctionType(in_tys, []), - body=MLIR.IR.Region(), - ) - end - - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) - push!(MLIR.IR.region(func, 1), fnbody) - Ops.activate_constant_context!(fnbody) @assert MLIR.IR._has_block() From f29de0c78f9b152b04b5b78d6d903e398b88228d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Mar 2025 03:03:17 +0000 Subject: [PATCH 3/3] Update TracedUtils.jl --- src/TracedUtils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 28e590d32d..c005190cc7 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -211,6 +211,9 @@ function make_mlir_fn( num_partitions, num_replicas = 1, 1 + ctx = MLIR.IR.context() + mod = MLIR.IR.mmodule() + func = MLIR.IR.block!(MLIR.IR.body(mod)) do return MLIR.Dialects.func.func_(; sym_name=name * "_tmp", @@ -265,9 +268,6 @@ function make_mlir_fn( sym_visibility = MLIR.IR.Attribute("private") end - ctx = MLIR.IR.context() - mod = MLIR.IR.mmodule() - # Insert meshes for the sharded arguments traced_args_to_shardings = OrderedIdDict() for (k, v) in seen_args