Skip to content

Commit

Permalink
respect scopping rules in for (#310)
Browse files Browse the repository at this point in the history
* respect scopping rules in for

* `@isdefined`

* Update ReactantCore.jl

* fix
  • Loading branch information
Pangoraw authored Jan 18, 2025
1 parent 01a5646 commit 3481d1d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
18 changes: 15 additions & 3 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,16 @@ function trace_for(mod, expr)
external_syms...,
)

cond_val(s) = :(@isdefined($s) ? $s : nothing)

while_defined = gensym(:while_defined)
locals = Expr[
[Expr(:(=), s, cond_val(s)) for s in external_syms]..., :(args = $(args_init))
]

var_syms = all_syms.args[(begin + 1):end]
reactant_code_block = quote
let args = $(args_init)
let $(locals...)
cond_fn =
$(all_syms) -> begin
local num_iters = div($limit - $start, $step, RoundDown)
Expand All @@ -170,19 +178,23 @@ function trace_for(mod, expr)
end
body_fn =
$(all_syms) -> begin
local isdefined_before = isnothing.(Any[$(var_syms...)])
local step_ = $step
local start_ = $start
local $induction = start_ + $counter * step_
$body
($counter + 1, $(all_syms.args[(begin + 1):end]...))
local results_ = Any[
s for (d, s) in zip(isdefined_before, Any[$(var_syms...)]) if !d
]
($counter + 1, results_...)
end

$(ReactantCore).traced_while(cond_fn, body_fn, args)
end
end

return quote
if any($(is_traced), $(Expr(:tuple, all_syms.args[(begin + 1):end]...)))
if any($(is_traced), $(Expr(:tuple, cond_val.(all_syms.args[(begin + 1):end])...)))
$(reactant_code_block)
else
$(expr)
Expand Down
13 changes: 9 additions & 4 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,21 @@ function ReactantCore.traced_while(

result_0 = in_tys

operands = MLIR.IR.Value[v.mlir_data for v in traced_args]
operands = MLIR.IR.Value[v.mlir_data for v in traced_args if v isa TracedType]

while_compiled = MLIR.Dialects.stablehlo.while_(
operands; result_0, cond=cond_reg, body=body_reg
)

return map(enumerate(traced_args)) do (i, res)
res.mlir_data = MLIR.IR.result(while_compiled, i)
return res
residx = 1
for res in traced_args
if res isa TracedType
res.mlir_data = MLIR.IR.result(while_compiled, residx)
residx += 1
end
end

return traced_args
end

function take_region(compiled_fn)
Expand Down
15 changes: 15 additions & 0 deletions test/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,18 @@ end

@test @jit(for_ref_outer(x_ra)) for_ref_outer(x)
end

function for_inner_scope(x)
@trace for i in 1:10
s = sum(x)
x = x / s
end
return x
end

@testset "for: inner scope" begin
x = randn(Float64, 10)
x_ra = Reactant.to_rarray(x)

@test @jit(for_inner_scope(x_ra)) for_inner_scope(x)
end

0 comments on commit 3481d1d

Please sign in to comment.