Skip to content

Commit

Permalink
Initial switch implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Jul 29, 2024
1 parent 0aba78f commit ee35395
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 25 deletions.
149 changes: 127 additions & 22 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Span,
Statement,
SubroutineDefinition,
SwitchStatement,
UnaryExpression,
WhileLoop,
)
Expand Down Expand Up @@ -958,7 +959,7 @@ def _visit_classical_declaration(self, statement: ClassicalDeclaration) -> None:

self._update_scope(variable)

def _analyse_classical_indices(self, indices: list[list], var_name: str) -> None:
def _analyse_classical_indices(self, indices: list[IntegerLiteral], var_name: str) -> None:
"""Validate the indices for a classical variable.
Args:
Expand All @@ -975,18 +976,17 @@ def _analyse_classical_indices(self, indices: list[list], var_name: str) -> None
var_dimensions = self._get_scope()[var_name].dims

if not var_dimensions:
self._print_err_location(indices[0][0].span)
self._print_err_location(indices[0].span)
raise Qasm3ConversionError(f"Indexing error. Variable {var_name} is not an array")

if len(indices) != len(var_dimensions):
self._print_err_location(indices[0][0].span)
self._print_err_location(indices[0].span)
raise Qasm3ConversionError(
f"Invalid number of indices for variable {var_name}. "
f"Expected {len(var_dimensions)} but got {len(indices)}"
)

for i, index in enumerate(indices):
index = index[0]
if isinstance(index, RangeDefinition):
self._print_err_location(index.span)
raise Qasm3ConversionError(
Expand Down Expand Up @@ -1075,7 +1075,12 @@ def _visit_classical_assignment(self, statement: ClassicalAssignment) -> None:

# handle assignment for arrays
if isinstance(lvalue, IndexedIdentifier):
indices = lvalue.indices
# stupid indices structure in openqasm :/
if len(lvalue.indices[0]) > 1:
indices = lvalue.indices[0]
else:
indices = [idx[0] for idx in lvalue.indices]

validated_indices = self._analyse_classical_indices(indices, var_name)
self._update_array_element(var.value, validated_indices, var_value)
else:
Expand Down Expand Up @@ -1107,6 +1112,34 @@ def _evaluate_array_initialization(

return init_values

def _analyse_index_expression(self, index_expr: IndexExpression) -> tuple[str, list[list]]:
"""Analyse an index expression to get the variable name and indices.
Args:
index_expr (IndexExpression): The index expression to analyse.
Returns:
tuple[str, list[list]]: The variable name and indices.
"""
indices = []
var_name = None
comma_separated = False

if isinstance(index_expr.collection, IndexExpression):
while isinstance(index_expr, IndexExpression):
indices.append(index_expr.index[0])
index_expr = index_expr.collection
else:
comma_separated = True
indices = index_expr.index

var_name = index_expr.collection.name if comma_separated else index_expr.name
if not comma_separated:
indices = indices[::-1]

return var_name, indices

# pylint: disable-next=too-many-return-statements
def _evaluate_expression(self, expression, const_expr: bool = False):
"""Evaluate an expression. Scalar types are assigned by value.
Expand Down Expand Up @@ -1155,20 +1188,6 @@ def _get_var_value(var_name, indices=None):
)
return var_value

def _analyse_index_expression(index_expr):
indices = []
var_name = None

# Recursive structure for IndexExpression
while isinstance(index_expr, IndexExpression):
indices.append(index_expr.index)
index_expr = index_expr.collection

indices = indices[::-1] # reverse indices as outermost was present first
var_name = index_expr.name

return var_name, indices

def process_variable(var_name, indices=None):
_check_var_in_scope(var_name)
_check_var_constant(var_name)
Expand All @@ -1183,7 +1202,7 @@ def process_variable(var_name, indices=None):
return process_variable(var_name)

if isinstance(expression, IndexExpression):
var_name, indices = _analyse_index_expression(expression)
var_name, indices = self._analyse_index_expression(expression)
return process_variable(var_name, indices)

if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral)):
Expand Down Expand Up @@ -1462,6 +1481,90 @@ def _visit_alias_statement(self, statement: AliasStatement) -> None:

_log.debug("Added labels for aliasing '%s'", target)

def _visit_switch_statement(self, statement: SwitchStatement) -> None:
"""Visit a switch statement element.
Args:
statement (SwitchStatement): The switch statement to visit.
Returns:
None
"""

# 1. analyse the target - it should ONLY be int, not casted
switch_target = statement.target

# either identifier or indexed expression
# if isinstance(switch_target, Identifier):
# switch_target_name = switch_target.name
# else:
# switch_target_name, indices = self._analyse_index_expression(switch_target)

# TODO: self._validate_variable_type(switch_target_name, Qasm3IntType)
switch_target_val = self._evaluate_expression(switch_target)

if len(statement.cases) == 0:
self._print_err_location(statement.span)
raise Qasm3ConversionError("Switch statement must have at least one case")

Check warning on line 1508 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L1507-L1508

Added lines #L1507 - L1508 were not covered by tests

# 2. handle the cases of the switch stmt
# each element in the list of the values
# should be of const int type and no duplicates should be present
case_fulfilled = False
for case in statement.cases:
case_list = case[0]
seen_values = set()
for case_val in case_list:
# 3. evaluate and verify that it is a const_expression
# using vars only within the scope AND each component is either a
# literal OR type int
case_val = self._evaluate_expression(
case_val, const_expr=True
) # TODO: , reqd_type = Qasm3IntType)

if case_val in seen_values:
self._print_err_location(case.span)
raise Qasm3ConversionError(

Check warning on line 1527 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L1526-L1527

Added lines #L1526 - L1527 were not covered by tests
f"Duplicate case value {case_val} in switch statement"
)

seen_values.add(case_val)

if case_val == switch_target_val:
case_fulfilled = True
break

if case_fulfilled:
# 4. each case has its own scope
self._push_scope({})
self._curr_scope += 1
self._label_scope_level[self._curr_scope] = set()

case_stmts = case[1].statements
for stmt in case_stmts:
self.visit_statement(stmt)

# 5. remove the labels and pop the scope
del self._label_scope_level[self._curr_scope]
self._curr_scope -= 1
self._pop_scope()
break

if not case_fulfilled and statement.default:
# 6. visit the default case
default_stmts = statement.default.statements

self._push_scope({})
self._curr_scope += 1
self._label_scope_level[self._curr_scope] = set()

for stmt in default_stmts:
self.visit_statement(stmt)

del self._label_scope_level[self._curr_scope]
self._curr_scope -= 1
self._pop_scope()

# pylint: disable-next=too-many-branches
def visit_statement(self, statement: Statement) -> None:
"""Visit a statement element.
Expand Down Expand Up @@ -1495,10 +1598,12 @@ def visit_statement(self, statement: Statement) -> None:
self._visit_branching_statement(statement)
elif isinstance(statement, ForInLoop):
self._visit_forin_loop(statement)
elif isinstance(statement, SubroutineDefinition):
raise NotImplementedError("OpenQASM 3 subroutines not yet supported")
elif isinstance(statement, AliasStatement):
self._visit_alias_statement(statement)
elif isinstance(statement, SwitchStatement):
self._visit_switch_statement(statement)
elif isinstance(statement, SubroutineDefinition):

Check warning on line 1605 in qbraid_qir/qasm3/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/qasm3/visitor.py#L1605

Added line #L1605 was not covered by tests
raise NotImplementedError("OpenQASM 3 subroutines not yet supported")
elif isinstance(statement, IODeclaration):
raise NotImplementedError("OpenQASM 3 IO declarations not yet supported")
else:
Expand Down
15 changes: 12 additions & 3 deletions tests/qasm3_qir/converter/declarations/test_classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,26 @@ def test_array_assignments():
bool f = true;
arr_int[0][1] = a*a;
arr_int[0,1] = a*a;
arr_uint[0][1] = b*b;
arr_uint[0,1] = b*b;
arr_float32[0][1] = c*c;
arr_float32[0,1] = c*c;
arr_float64[0][1] = d*d;
arr_float64[0,1] = d*d;
arr_bool[0][1] = f;
arr_bool[0,1] = f;
qubit q;
rx(arr_int[0][1]) q;
rx(arr_int[0,1]) q;
rx(arr_uint[0][1]) q;
rx(arr_float32[0][1]) q;
rx(arr_float32[0,1]) q;
rx(arr_float64[0][1]) q;
rx(arr_bool[0][1]) q;
rx(arr_bool[0,1]) q;
"""

a = 2
Expand Down
107 changes: 107 additions & 0 deletions tests/qasm3_qir/converter/test_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (C) 2024 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module containing unit tests for converting OpenQASM 3 programs
with alias statements to QIR.
"""


from qbraid_qir.qasm3 import qasm3_to_qir
from tests.qir_utils import check_attributes, check_single_qubit_gate_op


def test_switch():
"""Test converting OpenQASM 3 program with openqasm3.ast.SwitchStatement."""

qasm3_switch_program = """
OPENQASM 3.0;
include "stdgates.inc";
const int i = 5;
qubit q;
switch(i) {
case 1,3,5,7 {
x q;
}
case 2,4,6,8 {
y q;
}
default {
z q;
}
}
"""

result = qasm3_to_qir(qasm3_switch_program, name="test")
generated_qir = str(result).splitlines()

check_attributes(generated_qir, 1)
check_single_qubit_gate_op(generated_qir, 1, [0], "x")


def test_switch_default():
"""Test converting OpenQASM 3 program with openqasm3.ast.SwitchStatement and default case."""

qasm3_switch_program = """
OPENQASM 3.0;
include "stdgates.inc";
const int i = 10;
qubit q;
switch(i) {
case 1,3,5,7 {
x q;
}
case 2,4,6,8 {
y q;
}
default {
z q;
}
}
"""

result = qasm3_to_qir(qasm3_switch_program, name="test")
generated_qir = str(result).splitlines()

check_attributes(generated_qir, 1)
check_single_qubit_gate_op(generated_qir, 1, [0], "z")


def test_switch_identifier_case():
"""Test converting OpenQASM 3 program with openqasm3.ast.SwitchStatement and identifier case."""

qasm3_switch_program = """
OPENQASM 3.0;
include "stdgates.inc";
const int i = 4;
const int j = 4;
qubit q;
switch(i) {
case 6, j {
x q;
}
default {
z q;
}
}
"""

result = qasm3_to_qir(qasm3_switch_program, name="test")
generated_qir = str(result).splitlines()

check_attributes(generated_qir, 1)
check_single_qubit_gate_op(generated_qir, 1, [0], "x")

0 comments on commit ee35395

Please sign in to comment.