Skip to content

Commit c705436

Browse files
Constant cse (EnzymeAD#865)
* Constant cse * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix test * traced if constants * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8e1d45d commit c705436

File tree

4 files changed

+98
-9
lines changed

4 files changed

+98
-9
lines changed

src/Ops.jl

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,56 @@ struct Token
6767
mlir_data::MLIR.IR.Value
6868
end
6969

70+
function activate_constant_context!(blk::MLIR.IR.Block)
71+
stack = get!(task_local_storage(), :entry_block) do
72+
return Tuple{MLIR.IR.Block,Dict{MLIR.IR.Attribute,TracedRArray}}[]
73+
end
74+
Base.push!(stack, (blk, Dict{MLIR.IR.Attribute,TracedRArray}()))
75+
return nothing
76+
end
77+
78+
function constant_context(; throw_error::Core.Bool=true)
79+
return last(task_local_storage(:entry_block))
80+
end
81+
82+
function deactivate_constant_context!(blk::MLIR.IR.Block)
83+
constant_context()[1] == blk || error("Deactivating wrong block")
84+
return Base.pop!(task_local_storage(:entry_block))
85+
end
86+
7087
# constant ops
7188
@noinline function constant(
7289
x::DenseArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
7390
) where {T,N}
7491
value = MLIR.IR.DenseElementsAttribute(x)
75-
output = mlir_type(TracedRArray{T,N}, size(x))
76-
res = MLIR.IR.result(stablehlo.constant(; output, value, location))
77-
return TracedRArray{T,N}((), res, size(x))
92+
constants = constant_context()[2]
93+
if haskey(constants, value)
94+
return constants[value]
95+
else
96+
output = mlir_type(TracedRArray{T,N}, size(x))
97+
98+
op_ty_results = MLIR.IR.Type[output]
99+
operands = MLIR.IR.Value[]
100+
owned_regions = MLIR.IR.Region[]
101+
successors = MLIR.IR.Block[]
102+
attributes = MLIR.IR.NamedAttribute[MLIR.Dialects.namedattribute("value", value),]
103+
104+
cstop = MLIR.IR.create_operation(
105+
"stablehlo.constant",
106+
location;
107+
operands,
108+
owned_regions,
109+
successors,
110+
attributes,
111+
results=op_ty_results,
112+
result_inference=false,
113+
)
114+
115+
res = MLIR.IR.result(cstop)
116+
tres = TracedRArray{T,N}((), res, size(x))
117+
constants[value] = tres
118+
return tres
119+
end
78120
end
79121

80122
@noinline function constant(
@@ -1764,6 +1806,7 @@ end
17641806
true_fn_args = true_fn_names[1]
17651807

17661808
MLIR.IR.activate!(true_fn_body)
1809+
Ops.activate_constant_context!(true_fn_body)
17671810
tb_result = try
17681811
for (i, arg) in enumerate(tb_linear_args)
17691812
# find the right path to index the traced arg.
@@ -1787,6 +1830,7 @@ end
17871830
end
17881831
Reactant.call_with_reactant(true_fn, tb_traced_args...)
17891832
finally
1833+
Ops.deactivate_constant_context!(true_fn_body)
17901834
MLIR.IR.deactivate!(true_fn_body)
17911835
end
17921836

@@ -1827,6 +1871,7 @@ end
18271871

18281872
false_fn_args = false_fn_names[1]
18291873
MLIR.IR.activate!(false_fn_body)
1874+
Ops.activate_constant_context!(false_fn_body)
18301875
fb_result = try
18311876
for (i, arg) in enumerate(fb_linear_args)
18321877
# find the right path to index the traced arg.
@@ -1850,6 +1895,7 @@ end
18501895
end
18511896
Reactant.call_with_reactant(false_fn, fb_traced_args...)
18521897
finally
1898+
Ops.deactivate_constant_context!(false_fn_body)
18531899
MLIR.IR.deactivate!(false_fn_body)
18541900
end
18551901

@@ -1928,6 +1974,7 @@ end
19281974

19291975
# finalize the true branch by adding the missing values
19301976
MLIR.IR.activate!(true_fn_body)
1977+
Ops.activate_constant_context!(true_fn_body)
19311978
tb_corrected_linear_results = Reactant.TracedType[]
19321979
try
19331980
for (i, path) in enumerate(tb_paths)
@@ -1939,10 +1986,12 @@ end
19391986
end
19401987
finally
19411988
MLIR.IR.deactivate!(true_fn_body)
1989+
Ops.deactivate_constant_context!(true_fn_body)
19421990
end
19431991

19441992
# finalize the false branch by adding the missing values
19451993
MLIR.IR.activate!(false_fn_body)
1994+
Ops.activate_constant_context!(false_fn_body)
19461995
fb_corrected_linear_results = Reactant.TracedType[]
19471996
try
19481997
for (i, path) in enumerate(fb_paths)
@@ -1954,6 +2003,7 @@ end
19542003
end
19552004
finally
19562005
MLIR.IR.deactivate!(false_fn_body)
2006+
Ops.deactivate_constant_context!(false_fn_body)
19572007
end
19582008

19592009
# All MissingTracedValues must be replaced with zeroes
@@ -1968,19 +2018,23 @@ end
19682018
res = if tr isa MissingTracedValue
19692019
@assert !(fr isa MissingTracedValue)
19702020
MLIR.IR.activate!(true_fn_body)
2021+
Ops.activate_constant_context!(true_fn_body)
19712022
try
19722023
tb_corrected_linear_results[i] = zero(fr)
19732024
finally
19742025
MLIR.IR.deactivate!(true_fn_body)
2026+
Ops.deactivate_constant_context!(true_fn_body)
19752027
end
19762028
fr
19772029
elseif fr isa MissingTracedValue
19782030
@assert !(tr isa MissingTracedValue)
19792031
MLIR.IR.activate!(false_fn_body)
2032+
Ops.activate_constant_context!(false_fn_body)
19802033
try
19812034
fb_corrected_linear_results[i] = zero(tr)
19822035
finally
19832036
MLIR.IR.deactivate!(false_fn_body)
2037+
Ops.deactivate_constant_context!(false_fn_body)
19842038
end
19852039
tr
19862040
else
@@ -1993,6 +2047,7 @@ end
19932047
end
19942048

19952049
MLIR.IR.activate!(true_fn_body)
2050+
Ops.activate_constant_context!(true_fn_body)
19962051
try
19972052
vals = MLIR.IR.Value[
19982053
Reactant.TracedUtils.get_mlir_data(res) for
@@ -2001,9 +2056,11 @@ end
20012056
MLIR.Dialects.stablehlo.return_(vals)
20022057
finally
20032058
MLIR.IR.deactivate!(true_fn_body)
2059+
Ops.deactivate_constant_context!(true_fn_body)
20042060
end
20052061

20062062
MLIR.IR.activate!(false_fn_body)
2063+
Ops.activate_constant_context!(false_fn_body)
20072064
try
20082065
vals = MLIR.IR.Value[
20092066
Reactant.TracedUtils.get_mlir_data(res) for
@@ -2012,6 +2069,7 @@ end
20122069
MLIR.Dialects.stablehlo.return_(vals)
20132070
finally
20142071
MLIR.IR.deactivate!(false_fn_body)
2072+
Ops.deactivate_constant_context!(false_fn_body)
20152073
end
20162074

20172075
# With the corrected results, we can compile the true and false branches

src/TracedUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ function make_mlir_fn(
244244

245245
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
246246
push!(MLIR.IR.region(func, 1), fnbody)
247+
Ops.activate_constant_context!(fnbody)
247248

248249
@assert MLIR.IR._has_block()
249250

@@ -265,6 +266,7 @@ function make_mlir_fn(
265266
end
266267
finally
267268
MLIR.IR.deactivate!(fnbody)
269+
Ops.deactivate_constant_context!(fnbody)
268270
end
269271

270272
# check which arguments have been mutated

src/mlir/IR/Operation.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ This will return true if the dialect is loaded and the operation is registered w
274274
is_registered(opname; context::Context=context()) =
275275
API.mlirContextIsRegisteredOperation(context, opname)
276276

277-
function create_operation(
277+
function create_operation_common(
278278
name,
279279
loc;
280280
results=nothing,
@@ -320,10 +320,20 @@ function create_operation(
320320
if mlirIsNull(op)
321321
error("Create Operation '$name' failed")
322322
end
323-
res = Operation(op, true)
324-
if _has_block()
325-
push!(block(), res)
326-
end
327-
return res
323+
return Operation(op, true)
328324
end
329325
end
326+
327+
function create_operation(args...; kwargs...)
328+
res = create_operation_common(args...; kwargs...)
329+
if _has_block()
330+
push!(block(), res)
331+
end
332+
return res
333+
end
334+
335+
function create_operation_at_front(args...; kwargs...)
336+
res = create_operation_common(args...; kwargs...)
337+
Base.pushfirst!(block(), res)
338+
return res
339+
end

test/ops.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,3 +1101,22 @@ end
11011101
r = reduce(+, A; dims=dims, init=init)
11021102
@test r_hlo squeeze_dims(r)
11031103
end
1104+
1105+
@testset "const dedup" begin
1106+
x = Reactant.to_rarray([11, 12, 13, 14])
1107+
function const_dedup(x)
1108+
c1 = [1, 2, 3, 4]
1109+
y1 = (x .+ c1)
1110+
c2 = [1, 2, 3, 4]
1111+
y2 = (x .+ c2)
1112+
c1[1] = 6
1113+
return y1 .* y2 .* c1
1114+
end
1115+
1116+
mod = @code_hlo optimize = false const_dedup(x)
1117+
hlo_ir = repr(mod)
1118+
csts = collect(x for x in eachsplit(hlo_ir, "\n") if occursin("stablehlo.constant", x))
1119+
@test length(csts) == 2
1120+
@test occursin("1, 2, 3, 4", csts[1])
1121+
@test occursin("6, 2, 3, 4", csts[2])
1122+
end

0 commit comments

Comments
 (0)