Skip to content

Commit

Permalink
non-equality comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
arulandu committed Dec 12, 2024
1 parent 4715d38 commit 74581b9
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 35 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ In [4]: from pyqasm import dumps
In [5]: dumps(module).splitlines()
Out[5]: ['OPENQASM 3.0;', 'qubit[2] q;', 'h q;']
```
- Added support for unrolling multi-bit equality branching.
- Added support for unrolling multi-bit branching with `==`, `>=`, `<=`, `>`, and `<`.

### Improved / Modified
- Refactored the initialization of `QasmModule` to remove default include statements. Only user supplied include statements are now added to the generated QASM code ([#86](https://github.com/qBraid/pyqasm/pull/86))
Expand Down
24 changes: 18 additions & 6 deletions src/pyqasm/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,15 +226,18 @@ def transform_gate_params(
gate_op.argument = Qasm3Transformer.transform_expression(gate_op.argument, param_map)

@staticmethod
def get_branch_params(condition: Expression) -> tuple[Optional[int], str, Optional[bool]]:
def get_branch_params(
condition: Expression,
) -> tuple[Optional[int], str, Optional[Union[BinaryOperator, UnaryOperator]], Optional[bool]]:
"""
Get the branch parameters from the branching condition
Args:
condition (Any): The condition to analyze
Returns:
tuple[Optional[int], str, Any]: register_idx, register_name, value of RHS
tuple[Optional[int], str, Expression, Any]:
register_idx, register_name, op, value of RHS
"""
if isinstance(condition, Identifier):
raise_qasm3_error(
Expand All @@ -251,12 +254,14 @@ def get_branch_params(condition: Expression) -> tuple[Optional[int], str, Option
return (
condition.expression.index[0].value,
condition.expression.collection.name,
condition.op,
False,
)
if isinstance(condition, BinaryExpression):
if condition.op != BinaryOperator["=="]:
if condition.op not in [BinaryOperator[o] for o in ["==", ">=", "<=", ">", "<"]]:
raise_qasm3_error(
message="Only '==' supported in branching condition with classical register",
message="Only {==, >=, <=, >, <} supported in branching condition "
"with classical register",
span=condition.span,
)

Expand All @@ -265,6 +270,7 @@ def get_branch_params(condition: Expression) -> tuple[Optional[int], str, Option
return (
None,
condition.lhs.name,
condition.op,
# do not evaluate to bool
Qasm3ExprEvaluator.evaluate_expression(condition.rhs, reqd_type=Qasm3IntType)[
0
Expand All @@ -273,6 +279,7 @@ def get_branch_params(condition: Expression) -> tuple[Optional[int], str, Option
return (
condition.lhs.index[0].value,
condition.lhs.collection.name,
condition.op,
# evaluate to bool
Qasm3ExprEvaluator.evaluate_expression(condition.rhs)[0] != 0,
)
Expand All @@ -288,9 +295,14 @@ def get_branch_params(condition: Expression) -> tuple[Optional[int], str, Option
message="RangeDefinition not supported in branching condition",
span=condition.span,
)
return (condition.index[0].value, condition.collection.name, True) # eg. if(c[0])
return (
condition.index[0].value,
condition.collection.name,
BinaryOperator["=="],
True,
) # eg. if(c[0])
# default case
return None, "", None
return None, "", None, None

@classmethod
def transform_function_qubits(
Expand Down
49 changes: 34 additions & 15 deletions src/pyqasm/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def _visit_branching_statement(
# here, the lhs CAN only be a classical register as QCs won't have
# ability to evaluate expressions in the condition

reg_idx, reg_name, rhs_value = Qasm3Transformer.get_branch_params(condition)
reg_idx, reg_name, op, rhs_value = Qasm3Transformer.get_branch_params(condition)

if reg_name not in self._global_creg_size_map:
raise_qasm3_error(
Expand Down Expand Up @@ -1369,26 +1369,45 @@ def _visit_branching_statement(
result.append(new_if_block)
else:
# unroll multi-bit branch
rhs_value_str = bin(int(rhs_value))[2:][::-1]
assert isinstance(rhs_value, int) and op in [
qasm3_ast.BinaryOperator[o] for o in ["==", ">=", "<=", ">", "<"]
]

if op == qasm3_ast.BinaryOperator[">"]:
op = qasm3_ast.BinaryOperator[">="]
rhs_value += 1
elif op == qasm3_ast.BinaryOperator["<"]:
op = qasm3_ast.BinaryOperator["<="]
rhs_value -= 1

size = self._global_creg_size_map[reg_name]
rhs_value_str = bin(int(rhs_value))[2:].zfill(size)
else_block = self.visit_basic_block(statement.else_block)

def ravel(i):
r = rhs_value_str[i] == "1"

return qasm3_ast.BranchingStatement(
condition=qasm3_ast.BinaryExpression(
op=qasm3_ast.BinaryOperator["=="],
lhs=qasm3_ast.IndexExpression(
collection=qasm3_ast.Identifier(name=reg_name),
index=[qasm3_ast.IntegerLiteral(i)],
if (op == qasm3_ast.BinaryOperator[">="] and not r) or (
op == qasm3_ast.BinaryOperator["<="] and r
):
# ith-bit doesn't affect condition -> skip
return if_block if i == len(rhs_value_str) - 1 else ravel(i + 1)

return [
qasm3_ast.BranchingStatement(
condition=qasm3_ast.BinaryExpression(
op=qasm3_ast.BinaryOperator["=="],
lhs=qasm3_ast.IndexExpression(
collection=qasm3_ast.Identifier(name=reg_name),
index=[qasm3_ast.IntegerLiteral(size - i - 1)],
),
rhs=qasm3_ast.BooleanLiteral(r),
),
rhs=qasm3_ast.BooleanLiteral(r),
),
if_block=if_block if i == len(rhs_value_str) - 1 else [ravel(i + 1)],
else_block=else_block,
)
if_block=if_block if i == len(rhs_value_str) - 1 else ravel(i + 1),
else_block=else_block,
)
]

result.extend(self.visit_basic_block([ravel(0)])) # type: ignore[arg-type]
result.extend(self.visit_basic_block(ravel(0))) # type: ignore[arg-type]
else:
# here we can unroll the block depending on the condition
positive_branching = Qasm3ExprEvaluator.evaluate_expression(condition)[0] != 0
Expand Down
88 changes: 75 additions & 13 deletions tests/qasm3/test_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def test_simple_if():
if(c[1] == 1){
cx q[1], q[2];
}
if(c == 5){
x q[3];
}
"""
expected_qasm = """OPENQASM 3.0;
include "stdgates.inc";
Expand All @@ -57,20 +54,12 @@ def test_simple_if():
if (c[1] == true) {
cx q[1], q[2];
}
if (c[0] == true) {
if (c[1] == false) {
if (c[2] == true) {
x q[3];
}
}
}
"""

result = loads(qasm)
result.unroll()
assert result.num_clbits == 4
assert result.num_qubits == 4
print(dumps(result))
check_unrolled_qasm(dumps(result), expected_qasm)


Expand Down Expand Up @@ -140,6 +129,79 @@ def test_complex_if():
check_unrolled_qasm(dumps(result), expected_qasm)


def test_multi_bit_if():
qasm = """OPENQASM 3.0;
include "stdgates.inc";
qubit[1] q;
bit[4] c;
if(c == 3){
h q[0];
}
if(c >= 3){
h q[0];
} else {
x q[0];
}
if(c <= 3){
h q[0];
} else {
x q[0];
}
if(c < 4){
h q[0];
} else {
x q[0];
}
"""
expected_qasm = """OPENQASM 3.0;
include "stdgates.inc";
qubit[1] q;
bit[4] c;
if (c[3] == false) {
if (c[2] == false) {
if (c[1] == true) {
if (c[0] == true) {
h q[0];
}
}
}
}
if (c[1] == true) {
if (c[0] == true) {
h q[0];
} else {
x q[0];
}
} else {
x q[0];
}
if (c[3] == false) {
if (c[2] == false) {
h q[0];
} else {
x q[0];
}
} else {
x q[0];
}
if (c[3] == false) {
if (c[2] == false) {
h q[0];
} else {
x q[0];
}
} else {
x q[0];
}
"""

result = loads(qasm)
result.unroll()
assert result.num_clbits == 4
assert result.num_qubits == 1
check_unrolled_qasm(dumps(result), expected_qasm)


def test_incorrect_if():

with pytest.raises(ValidationError, match=r"Missing if block"):
Expand Down Expand Up @@ -191,7 +253,7 @@ def test_incorrect_if():
}
"""
).validate()
with pytest.raises(ValidationError, match=r"Only '==' supported .*"):
with pytest.raises(ValidationError, match=r"Only {==, >=, <=, >, <} supported in branching condition with classical register"):
loads(
"""
OPENQASM 3.0;
Expand All @@ -202,7 +264,7 @@ def test_incorrect_if():
h q;
measure q->c;
if(c[0] >= 1){
if(c[0] >> 1){
cx q;
}
"""
Expand Down

0 comments on commit 74581b9

Please sign in to comment.