Skip to content

Commit ef36486

Browse files
committed
feat: more constant checks
1 parent d7f841b commit ef36486

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/Ops.jl

+12
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ for (T, mlir_func) in (
206206

207207
splatattr = MLIR.API.$mlir_func(tt, number)
208208
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
209+
210+
parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
211+
if parent_func_op == C_NULL
212+
error("Constant must be created inside a Function Op.")
213+
end
214+
209215
cst = MLIR.IR.result(cst_op)
210216
ta = TracedRArray{$T,length(shape)}((), cst, shape)
211217
return ta
@@ -226,6 +232,12 @@ end
226232
tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
227233
splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element))
228234
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
235+
236+
parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
237+
if parent_func_op == C_NULL
238+
error("Constant must be created inside a Function Op.")
239+
end
240+
229241
cst = MLIR.IR.result(cst_op)
230242
ta = TracedRArray{T,length(shape)}((), cst, shape)
231243
return ta

0 commit comments

Comments
 (0)