Skip to content

Commit d7f841b

Browse files
committed
feat: early fail if not correct region
1 parent 4c0e77f commit d7f841b

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

deps/ReactantExtra/API.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -2404,3 +2404,7 @@ extern "C" void dump_operation(Operation *op, const char *filename) {
24042404
extern "C" bool pjrt_device_is_addressable(PjRtDevice *device) {
24052405
return device->IsAddressable();
24062406
}
2407+
2408+
extern "C" mlir::Operation *mlirGetParentOfTypeFunctionOp(mlir::Operation *op) {
2409+
return op->getParentOfType<mlir::FunctionOpInterface>();
2410+
}

src/Ops.jl

+5
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ end
126126
result_inference=false,
127127
)
128128

129+
parent_func_op = MLIR.IR.get_parent_of_type_function_op(cstop)
130+
if parent_func_op == C_NULL
131+
error("Constant must be created inside a Function Op.")
132+
end
133+
129134
res = MLIR.IR.result(cstop)
130135
tres = TracedRArray{T,N}((), res, size(x))
131136
constants[value] = tres

src/mlir/IR/Operation.jl

+24-2
Original file line numberDiff line numberDiff line change
@@ -331,16 +331,38 @@ function create_operation_common(
331331
end
332332
end
333333

334+
function create_operation_common_with_checks(args...; operands=nothing, kwargs...)
335+
op = create_operation_common(args...; operands, kwargs...)
336+
# if !isnothing(operands)
337+
# parent_function_op = get_parent_of_type_function_op(op)
338+
# if parent_function_op != C_NULL
339+
# function_op_region = parent_region(parent_function_op)
340+
# # TODO: add the checks
341+
# end
342+
# end
343+
return op
344+
end
345+
334346
function create_operation(args...; kwargs...)
335-
res = create_operation_common(args...; kwargs...)
347+
res = create_operation_common_with_checks(args...; kwargs...)
336348
if _has_block()
337349
push!(block(), res)
338350
end
339351
return res
340352
end
341353

342354
function create_operation_at_front(args...; kwargs...)
343-
res = create_operation_common(args...; kwargs...)
355+
res = create_operation_common_with_checks(args...; kwargs...)
344356
Base.pushfirst!(block(), res)
345357
return res
346358
end
359+
360+
function get_parent_of_type_function_op(op::Operation)
361+
GC.@preserve op begin
362+
funcop = @ccall API.mlir_c.mlirGetParentOfTypeFunctionOp(
363+
op::API.MlirOperation
364+
)::API.MlirOperation
365+
end
366+
funcop.ptr == C_NULL && return C_NULL
367+
return Operation(funcop, false)
368+
end

0 commit comments

Comments
 (0)