Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid excessive inlining by moving YaoBlocks.mat to func.func defs #344

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 161 additions & 34 deletions ext/ReactantYaoBlocksExt.jl
Original file line number Diff line number Diff line change
@@ -1,42 +1,169 @@
module ReactantYaoBlocksExt

using Reactant
using Reactant: TracedRArray, TracedRNumber
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: func
using YaoBlocks

function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:XGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
c = cos(R.theta / 2)
s = -im * sin(R.theta / 2)
M[1, 1] = c
M[2, 2] = c
M[1, 2] = s
M[2, 1] = s
return M
end

function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:YGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
c = cos(R.theta / 2)
s = sin(R.theta / 2)
M[1, 1] = c
M[2, 2] = c
M[1, 2] = -s
M[2, 1] = s
return M
end

function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:ZGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
x = exp(im * R.theta / 2)
M[1, 1] = conj(x)
M[2, 2] = x
return M
function module_top()
if !haskey(task_local_storage(), :mlir_module)
error("No MLIR module is active")
end
return first(task_local_storage(:mlir_module))
end

function symname(name, ::Type{ParamType}, ::Type{OutElType}) where {ParamType,OutElType}
return name * "_" * string(ParamType) * "_" * string(OutElType)
end

function codegen!(
::Val{:rx}, ::Type{ParamType}, ::Type{OutElType}
) where {ParamType,OutElType}
in_tys = [IR.TensorType((), IR.Type(ParamType))]
out_tys = [IR.TensorType((2, 2), IR.Type(OutElType))]

mod = module_top()
IR.block!(IR.body(mod)) do
f = func.func_(;
sym_name=symname("rx", ParamType, OutElType),
function_type=IR.FunctionType(in_tys, out_tys),
body=IR.Region(),
)

fbody = IR.Block(in_tys, [IR.Location()])
push!(IR.region(f, 1), fbody)

IR.block!(fbody) do
θ = TracedRNumber{ParamType}((), IR.argument(fbody, 1))
M = Reactant.broadcast_to_size(zero(OutElType), (2, 2))
c = cos(θ / 2)
s = -im * sin(θ / 2)
M[1, 1] = c
M[2, 2] = c
M[1, 2] = s
M[2, 1] = s
func.return_([M.mlir_data])
end

return f
end
end

function codegen!(
::Val{:ry}, ::Type{ParamType}, ::Type{OutElType}
) where {ParamType,OutElType}
in_tys = [IR.TensorType((), IR.Type(ParamType))]
out_tys = [IR.TensorType((2, 2), IR.Type(OutElType))]

mod = module_top()
IR.block!(IR.body(mod)) do
f = func.func_(;
sym_name=symname("ry", ParamType, OutElType),
function_type=IR.FunctionType(in_tys, out_tys),
body=IR.Region(),
)

fbody = IR.Block(in_tys, [IR.Location()])
push!(IR.region(f, 1), fbody)

IR.block!(fbody) do
θ = TracedRNumber{ParamType}((), IR.argument(fbody, 1))
M = Reactant.broadcast_to_size(zero(OutElType), (2, 2))
c = cos(θ / 2)
s = sin(θ / 2)
M[1, 1] = c
M[2, 2] = c
M[1, 2] = -s
M[2, 1] = s
func.return_([M.mlir_data])
end

return f
end
end

function codegen!(
::Val{:rz}, ::Type{ParamType}, ::Type{OutElType}
) where {ParamType,OutElType}
in_tys = [IR.TensorType((), IR.Type(ParamType))]
out_tys = [IR.TensorType((2, 2), IR.Type(OutElType))]

mod = module_top()
IR.block!(IR.body(mod)) do
f = func.func_(;
sym_name=symname("rz", ParamType, OutElType),
function_type=IR.FunctionType(in_tys, out_tys),
body=IR.Region(),
)

fbody = IR.Block(in_tys, [IR.Location()])
push!(IR.region(f, 1), fbody)

IR.block!(fbody) do
θ = TracedRNumber{ParamType}((), IR.argument(fbody, 1))
M = Reactant.broadcast_to_size(zero(OutElType), (2, 2))
x = exp(im * θ / 2)
M[1, 1] = conj(x)
M[2, 2] = x
func.return_([M.mlir_data])
end

return f
end
end

function hasfunc(name, ::Type{ParamType}, ::Type{OutElType}) where {ParamType,OutElType}
it = IR.OperationIterator(IR.body(module_top()))
return any(it) do op
IR.name(op) == "func.func" || return false

String(IR.attr(op, 2).named_attribute.name) == "sym_name" ||
error("expected sym_name attribute")

_symname = String(IR.Attribute(IR.attr(op, 2).named_attribute.attribute))
_symname == symname(name, ParamType, OutElType) || return false
return true
end
Comment on lines +117 to +127
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to use a SymbolTable for this: https://github.com/EnzymeAD/Reactant.jl/blob/main/src/mlir/IR/SymbolTable.jl

end

function YaoBlocks.mat(::Type{T}, R::RotationGate{D,TracedRNumber{S},<:XGate}) where {D,T,S}
hasfunc("rx", S, T) || codegen!(Val(:rx), S, T)

res = IR.result(
func.call(
[R.theta.mlir_data];
result_0=[IR.TensorType((2, 2), IR.Type(T))],
callee=IR.FlatSymbolRefAttribute(symname("rx", S, T)),
),
)
return TracedRArray{T,2}((), res, (2, 2))
end

function YaoBlocks.mat(::Type{T}, R::RotationGate{D,TracedRNumber{S},<:YGate}) where {D,T,S}
hasfunc("ry", S, T) || codegen!(Val(:ry), S, T)

res = IR.result(
func.call(
[R.theta.mlir_data];
result_0=[IR.TensorType((2, 2), IR.Type(T))],
callee=IR.FlatSymbolRefAttribute(symname("ry", S, T)),
),
)
return TracedRArray{T,2}((), res, (2, 2))
end

function YaoBlocks.mat(::Type{T}, R::RotationGate{D,TracedRNumber{S},<:ZGate}) where {D,T,S}
hasfunc("rz", S, T) || codegen!(Val(:rz), S, T)

res = IR.result(
func.call(
[R.theta.mlir_data];
result_0=[IR.TensorType((2, 2), IR.Type(T))],
callee=IR.FlatSymbolRefAttribute(symname("rz", S, T)),
),
)
return TracedRArray{T,2}((), res, (2, 2))
end

end
2 changes: 1 addition & 1 deletion src/mlir/IR/Operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ rmattr!(operation::Operation, name) =
API.mlirOperationRemoveAttributeByName(operation, name)

function lose_ownership!(operation::Operation)
@assert operation.owned
# @assert operation.owned
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is causing this ? This seems suspicious

@atomic operation.owned = false
return operation
end
Expand Down
Loading