@@ -67,14 +67,56 @@ struct Token
67
67
mlir_data:: MLIR.IR.Value
68
68
end
69
69
70
+ function activate_constant_context! (blk:: MLIR.IR.Block )
71
+ stack = get! (task_local_storage (), :entry_block ) do
72
+ return Tuple{MLIR. IR. Block,Dict{MLIR. IR. Attribute,TracedRArray}}[]
73
+ end
74
+ Base. push! (stack, (blk, Dict {MLIR.IR.Attribute,TracedRArray} ()))
75
+ return nothing
76
+ end
77
+
78
+ function constant_context (; throw_error:: Core.Bool = true )
79
+ return last (task_local_storage (:entry_block ))
80
+ end
81
+
82
+ function deactivate_constant_context! (blk:: MLIR.IR.Block )
83
+ constant_context ()[1 ] == blk || error (" Deactivating wrong block" )
84
+ return Base. pop! (task_local_storage (:entry_block ))
85
+ end
86
+
70
87
# constant ops
71
88
@noinline function constant (
72
89
x:: DenseArray{T,N} ; location= mlir_stacktrace (" constant" , @__FILE__ , @__LINE__ )
73
90
) where {T,N}
74
91
value = MLIR. IR. DenseElementsAttribute (x)
75
- output = mlir_type (TracedRArray{T,N}, size (x))
76
- res = MLIR. IR. result (stablehlo. constant (; output, value, location))
77
- return TracedRArray {T,N} ((), res, size (x))
92
+ constants = constant_context ()[2 ]
93
+ if haskey (constants, value)
94
+ return constants[value]
95
+ else
96
+ output = mlir_type (TracedRArray{T,N}, size (x))
97
+
98
+ op_ty_results = MLIR. IR. Type[output]
99
+ operands = MLIR. IR. Value[]
100
+ owned_regions = MLIR. IR. Region[]
101
+ successors = MLIR. IR. Block[]
102
+ attributes = MLIR. IR. NamedAttribute[MLIR. Dialects. namedattribute (" value" , value),]
103
+
104
+ cstop = MLIR. IR. create_operation (
105
+ " stablehlo.constant" ,
106
+ location;
107
+ operands,
108
+ owned_regions,
109
+ successors,
110
+ attributes,
111
+ results= op_ty_results,
112
+ result_inference= false ,
113
+ )
114
+
115
+ res = MLIR. IR. result (cstop)
116
+ tres = TracedRArray {T,N} ((), res, size (x))
117
+ constants[value] = tres
118
+ return tres
119
+ end
78
120
end
79
121
80
122
@noinline function constant (
@@ -1764,6 +1806,7 @@ end
1764
1806
true_fn_args = true_fn_names[1 ]
1765
1807
1766
1808
MLIR. IR. activate! (true_fn_body)
1809
+ Ops. activate_constant_context! (true_fn_body)
1767
1810
tb_result = try
1768
1811
for (i, arg) in enumerate (tb_linear_args)
1769
1812
# find the right path to index the traced arg.
@@ -1787,6 +1830,7 @@ end
1787
1830
end
1788
1831
Reactant. call_with_reactant (true_fn, tb_traced_args... )
1789
1832
finally
1833
+ Ops. deactivate_constant_context! (true_fn_body)
1790
1834
MLIR. IR. deactivate! (true_fn_body)
1791
1835
end
1792
1836
@@ -1827,6 +1871,7 @@ end
1827
1871
1828
1872
false_fn_args = false_fn_names[1 ]
1829
1873
MLIR. IR. activate! (false_fn_body)
1874
+ Ops. activate_constant_context! (false_fn_body)
1830
1875
fb_result = try
1831
1876
for (i, arg) in enumerate (fb_linear_args)
1832
1877
# find the right path to index the traced arg.
@@ -1850,6 +1895,7 @@ end
1850
1895
end
1851
1896
Reactant. call_with_reactant (false_fn, fb_traced_args... )
1852
1897
finally
1898
+ Ops. deactivate_constant_context! (false_fn_body)
1853
1899
MLIR. IR. deactivate! (false_fn_body)
1854
1900
end
1855
1901
@@ -1928,6 +1974,7 @@ end
1928
1974
1929
1975
# finalize the true branch by adding the missing values
1930
1976
MLIR. IR. activate! (true_fn_body)
1977
+ Ops. activate_constant_context! (true_fn_body)
1931
1978
tb_corrected_linear_results = Reactant. TracedType[]
1932
1979
try
1933
1980
for (i, path) in enumerate (tb_paths)
@@ -1939,10 +1986,12 @@ end
1939
1986
end
1940
1987
finally
1941
1988
MLIR. IR. deactivate! (true_fn_body)
1989
+ Ops. deactivate_constant_context! (true_fn_body)
1942
1990
end
1943
1991
1944
1992
# finalize the false branch by adding the missing values
1945
1993
MLIR. IR. activate! (false_fn_body)
1994
+ Ops. activate_constant_context! (false_fn_body)
1946
1995
fb_corrected_linear_results = Reactant. TracedType[]
1947
1996
try
1948
1997
for (i, path) in enumerate (fb_paths)
@@ -1954,6 +2003,7 @@ end
1954
2003
end
1955
2004
finally
1956
2005
MLIR. IR. deactivate! (false_fn_body)
2006
+ Ops. deactivate_constant_context! (false_fn_body)
1957
2007
end
1958
2008
1959
2009
# All MissingTracedValues must be replaced with zeroes
@@ -1968,19 +2018,23 @@ end
1968
2018
res = if tr isa MissingTracedValue
1969
2019
@assert ! (fr isa MissingTracedValue)
1970
2020
MLIR. IR. activate! (true_fn_body)
2021
+ Ops. activate_constant_context! (true_fn_body)
1971
2022
try
1972
2023
tb_corrected_linear_results[i] = zero (fr)
1973
2024
finally
1974
2025
MLIR. IR. deactivate! (true_fn_body)
2026
+ Ops. deactivate_constant_context! (true_fn_body)
1975
2027
end
1976
2028
fr
1977
2029
elseif fr isa MissingTracedValue
1978
2030
@assert ! (tr isa MissingTracedValue)
1979
2031
MLIR. IR. activate! (false_fn_body)
2032
+ Ops. activate_constant_context! (false_fn_body)
1980
2033
try
1981
2034
fb_corrected_linear_results[i] = zero (tr)
1982
2035
finally
1983
2036
MLIR. IR. deactivate! (false_fn_body)
2037
+ Ops. deactivate_constant_context! (false_fn_body)
1984
2038
end
1985
2039
tr
1986
2040
else
@@ -1993,6 +2047,7 @@ end
1993
2047
end
1994
2048
1995
2049
MLIR. IR. activate! (true_fn_body)
2050
+ Ops. activate_constant_context! (true_fn_body)
1996
2051
try
1997
2052
vals = MLIR. IR. Value[
1998
2053
Reactant. TracedUtils. get_mlir_data (res) for
@@ -2001,9 +2056,11 @@ end
2001
2056
MLIR. Dialects. stablehlo. return_ (vals)
2002
2057
finally
2003
2058
MLIR. IR. deactivate! (true_fn_body)
2059
+ Ops. deactivate_constant_context! (true_fn_body)
2004
2060
end
2005
2061
2006
2062
MLIR. IR. activate! (false_fn_body)
2063
+ Ops. activate_constant_context! (false_fn_body)
2007
2064
try
2008
2065
vals = MLIR. IR. Value[
2009
2066
Reactant. TracedUtils. get_mlir_data (res) for
@@ -2012,6 +2069,7 @@ end
2012
2069
MLIR. Dialects. stablehlo. return_ (vals)
2013
2070
finally
2014
2071
MLIR. IR. deactivate! (false_fn_body)
2072
+ Ops. deactivate_constant_context! (false_fn_body)
2015
2073
end
2016
2074
2017
2075
# With the corrected results, we can compile the true and false branches
0 commit comments