Skip to content

Commit

Permalink
multi-bit equality branching
Browse files Browse the repository at this point in the history
  • Loading branch information
arulandu committed Dec 12, 2024
1 parent e1e53ec commit 3a6415b
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 31 deletions.
3 changes: 2 additions & 1 deletion src/pyqasm/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BinaryOperator,
BooleanLiteral,
DiscreteSet,
Expression,
FloatLiteral,
Identifier,
IndexedIdentifier,
Expand Down Expand Up @@ -225,7 +226,7 @@ def transform_gate_params(
gate_op.argument = Qasm3Transformer.transform_expression(gate_op.argument, param_map)

@staticmethod
def get_branch_params(condition: Any) -> tuple[Optional[int], str, Optional[bool]]:
def get_branch_params(condition: Expression) -> tuple[Optional[int], str, Optional[bool]]:
"""
Get the branch parameters from the branching condition
Expand Down
74 changes: 47 additions & 27 deletions src/pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,44 +1331,64 @@ def _visit_branching_statement(
# here, the lhs CAN only be a classical register as QCs won't have
# ability to evaluate expressions in the condition

reg_id, reg_name, rhs_value = Qasm3Transformer.get_branch_params(condition)
reg_idx, reg_name, rhs_value = Qasm3Transformer.get_branch_params(condition)

if reg_name not in self._global_creg_size_map:
raise_qasm3_error(
f"Missing register declaration for {reg_name} in {condition}",
span=statement.span,
)
if reg_id is not None:

assert isinstance(rhs_value, (bool, int))

if_block = self.visit_basic_block(statement.if_block)
else_block = self.visit_basic_block(statement.else_block)

if reg_idx is not None:
# single bit branch
Qasm3Validator.validate_register_index(
reg_id, self._global_creg_size_map[reg_name], qubit=False
reg_idx, self._global_creg_size_map[reg_name], qubit=False
)

new_lhs = (
qasm3_ast.IndexExpression(
collection=qasm3_ast.Identifier(name=reg_name),
index=[qasm3_ast.IntegerLiteral(reg_id)],
new_if_block = qasm3_ast.BranchingStatement(
condition=qasm3_ast.BinaryExpression(
op=qasm3_ast.BinaryOperator["=="],
lhs=qasm3_ast.IndexExpression(
collection=qasm3_ast.Identifier(name=reg_name),
index=[qasm3_ast.IntegerLiteral(reg_idx)],
),
rhs=(
qasm3_ast.BooleanLiteral(rhs_value)
if isinstance(rhs_value, bool)
else qasm3_ast.IntegerLiteral(rhs_value)
),
),
if_block=if_block,
else_block=else_block,
)
if reg_id is not None
else qasm3_ast.Identifier(name=reg_name)
)
assert isinstance(rhs_value, (bool, int))
new_rhs = (
qasm3_ast.BooleanLiteral(rhs_value)
if isinstance(rhs_value, bool)
else qasm3_ast.IntegerLiteral(rhs_value)
)

new_if_block = qasm3_ast.BranchingStatement(
condition=qasm3_ast.BinaryExpression(
op=qasm3_ast.BinaryOperator["=="],
lhs=new_lhs,
rhs=new_rhs,
),
if_block=self.visit_basic_block(statement.if_block),
else_block=self.visit_basic_block(statement.else_block),
)
result.append(new_if_block)
result.append(new_if_block)
else:
# unroll multi-bit branch
rhs_value_str = bin(int(rhs_value))[2:][::-1]
else_block = self.visit_basic_block(statement.else_block)

def ravel(i):
r = rhs_value_str[i] == "1"

return qasm3_ast.BranchingStatement(
condition=qasm3_ast.BinaryExpression(
op=qasm3_ast.BinaryOperator["=="],
lhs=qasm3_ast.IndexExpression(
collection=qasm3_ast.Identifier(name=reg_name),
index=[qasm3_ast.IntegerLiteral(i)],
),
rhs=qasm3_ast.BooleanLiteral(r),
),
if_block=if_block if i == len(rhs_value_str) - 1 else [ravel(i + 1)],
else_block=else_block,
)

result.extend(self.visit_basic_block([ravel(0)])) # type: ignore[arg-type]
else:
# here we can unroll the block depending on the condition
positive_branching = Qasm3ExprEvaluator.evaluate_expression(condition)[0] != 0
Expand Down
10 changes: 7 additions & 3 deletions tests/qasm3/test_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,20 @@ def test_simple_if():
if (c[1] == true) {
cx q[1], q[2];
}
if (c == 5) {
x q[3];
if (c[0] == true) {
if (c[1] == false) {
if (c[2] == true) {
x q[3];
}
}
}
"""

result = loads(qasm)
result.unroll()
assert result.num_clbits == 4
assert result.num_qubits == 4

print(dumps(result))
check_unrolled_qasm(dumps(result), expected_qasm)


Expand Down

0 comments on commit 3a6415b

Please sign in to comment.