From ee353954dbd4ffa52c7e23d87bfcbff15f0518fe Mon Sep 17 00:00:00 2001 From: TheGupta2012 Date: Mon, 29 Jul 2024 20:24:03 +0530 Subject: [PATCH] Initial switch implementation --- qbraid_qir/qasm3/visitor.py | 149 +++++++++++++++--- .../converter/declarations/test_classical.py | 15 +- tests/qasm3_qir/converter/test_switch.py | 107 +++++++++++++ 3 files changed, 246 insertions(+), 25 deletions(-) create mode 100644 tests/qasm3_qir/converter/test_switch.py diff --git a/qbraid_qir/qasm3/visitor.py b/qbraid_qir/qasm3/visitor.py index e896a15..9bdbc80 100644 --- a/qbraid_qir/qasm3/visitor.py +++ b/qbraid_qir/qasm3/visitor.py @@ -64,6 +64,7 @@ Span, Statement, SubroutineDefinition, + SwitchStatement, UnaryExpression, WhileLoop, ) @@ -958,7 +959,7 @@ def _visit_classical_declaration(self, statement: ClassicalDeclaration) -> None: self._update_scope(variable) - def _analyse_classical_indices(self, indices: list[list], var_name: str) -> None: + def _analyse_classical_indices(self, indices: list[IntegerLiteral], var_name: str) -> None: """Validate the indices for a classical variable. Args: @@ -975,18 +976,17 @@ def _analyse_classical_indices(self, indices: list[list], var_name: str) -> None var_dimensions = self._get_scope()[var_name].dims if not var_dimensions: - self._print_err_location(indices[0][0].span) + self._print_err_location(indices[0].span) raise Qasm3ConversionError(f"Indexing error. Variable {var_name} is not an array") if len(indices) != len(var_dimensions): - self._print_err_location(indices[0][0].span) + self._print_err_location(indices[0].span) raise Qasm3ConversionError( f"Invalid number of indices for variable {var_name}. " f"Expected {len(var_dimensions)} but got {len(indices)}" ) for i, index in enumerate(indices): - index = index[0] if isinstance(index, RangeDefinition): self._print_err_location(index.span) raise Qasm3ConversionError( @@ -1075,7 +1075,12 @@ def _visit_classical_assignment(self, statement: ClassicalAssignment) -> None: # handle assignment for arrays if isinstance(lvalue, IndexedIdentifier): - indices = lvalue.indices + # stupid indices structure in openqasm :/ + if len(lvalue.indices[0]) > 1: + indices = lvalue.indices[0] + else: + indices = [idx[0] for idx in lvalue.indices] + validated_indices = self._analyse_classical_indices(indices, var_name) self._update_array_element(var.value, validated_indices, var_value) else: @@ -1107,6 +1112,34 @@ def _evaluate_array_initialization( return init_values + def _analyse_index_expression(self, index_expr: IndexExpression) -> tuple[str, list[list]]: + """Analyse an index expression to get the variable name and indices. + + Args: + index_expr (IndexExpression): The index expression to analyse. + + Returns: + tuple[str, list[list]]: The variable name and indices. + + """ + indices = [] + var_name = None + comma_separated = False + + if isinstance(index_expr.collection, IndexExpression): + while isinstance(index_expr, IndexExpression): + indices.append(index_expr.index[0]) + index_expr = index_expr.collection + else: + comma_separated = True + indices = index_expr.index + + var_name = index_expr.collection.name if comma_separated else index_expr.name + if not comma_separated: + indices = indices[::-1] + + return var_name, indices + # pylint: disable-next=too-many-return-statements def _evaluate_expression(self, expression, const_expr: bool = False): """Evaluate an expression. Scalar types are assigned by value. @@ -1155,20 +1188,6 @@ def _get_var_value(var_name, indices=None): ) return var_value - def _analyse_index_expression(index_expr): - indices = [] - var_name = None - - # Recursive structure for IndexExpression - while isinstance(index_expr, IndexExpression): - indices.append(index_expr.index) - index_expr = index_expr.collection - - indices = indices[::-1] # reverse indices as outermost was present first - var_name = index_expr.name - - return var_name, indices - def process_variable(var_name, indices=None): _check_var_in_scope(var_name) _check_var_constant(var_name) @@ -1183,7 +1202,7 @@ def process_variable(var_name, indices=None): return process_variable(var_name) if isinstance(expression, IndexExpression): - var_name, indices = _analyse_index_expression(expression) + var_name, indices = self._analyse_index_expression(expression) return process_variable(var_name, indices) if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral)): @@ -1462,6 +1481,90 @@ def _visit_alias_statement(self, statement: AliasStatement) -> None: _log.debug("Added labels for aliasing '%s'", target) + def _visit_switch_statement(self, statement: SwitchStatement) -> None: + """Visit a switch statement element. + + Args: + statement (SwitchStatement): The switch statement to visit. + + Returns: + None + """ + + # 1. analyse the target - it should ONLY be int, not casted + switch_target = statement.target + + # either identifier or indexed expression + # if isinstance(switch_target, Identifier): + # switch_target_name = switch_target.name + # else: + # switch_target_name, indices = self._analyse_index_expression(switch_target) + + # TODO: self._validate_variable_type(switch_target_name, Qasm3IntType) + switch_target_val = self._evaluate_expression(switch_target) + + if len(statement.cases) == 0: + self._print_err_location(statement.span) + raise Qasm3ConversionError("Switch statement must have at least one case") + + # 2. handle the cases of the switch stmt + # each element in the list of the values + # should be of const int type and no duplicates should be present + case_fulfilled = False + for case in statement.cases: + case_list = case[0] + seen_values = set() + for case_val in case_list: + # 3. evaluate and verify that it is a const_expression + # using vars only within the scope AND each component is either a + # literal OR type int + case_val = self._evaluate_expression( + case_val, const_expr=True + ) # TODO: , reqd_type = Qasm3IntType) + + if case_val in seen_values: + self._print_err_location(case.span) + raise Qasm3ConversionError( + f"Duplicate case value {case_val} in switch statement" + ) + + seen_values.add(case_val) + + if case_val == switch_target_val: + case_fulfilled = True + break + + if case_fulfilled: + # 4. each case has its own scope + self._push_scope({}) + self._curr_scope += 1 + self._label_scope_level[self._curr_scope] = set() + + case_stmts = case[1].statements + for stmt in case_stmts: + self.visit_statement(stmt) + + # 5. remove the labels and pop the scope + del self._label_scope_level[self._curr_scope] + self._curr_scope -= 1 + self._pop_scope() + break + + if not case_fulfilled and statement.default: + # 6. visit the default case + default_stmts = statement.default.statements + + self._push_scope({}) + self._curr_scope += 1 + self._label_scope_level[self._curr_scope] = set() + + for stmt in default_stmts: + self.visit_statement(stmt) + + del self._label_scope_level[self._curr_scope] + self._curr_scope -= 1 + self._pop_scope() + # pylint: disable-next=too-many-branches def visit_statement(self, statement: Statement) -> None: """Visit a statement element. @@ -1495,10 +1598,12 @@ def visit_statement(self, statement: Statement) -> None: self._visit_branching_statement(statement) elif isinstance(statement, ForInLoop): self._visit_forin_loop(statement) - elif isinstance(statement, SubroutineDefinition): - raise NotImplementedError("OpenQASM 3 subroutines not yet supported") elif isinstance(statement, AliasStatement): self._visit_alias_statement(statement) + elif isinstance(statement, SwitchStatement): + self._visit_switch_statement(statement) + elif isinstance(statement, SubroutineDefinition): + raise NotImplementedError("OpenQASM 3 subroutines not yet supported") elif isinstance(statement, IODeclaration): raise NotImplementedError("OpenQASM 3 IO declarations not yet supported") else: diff --git a/tests/qasm3_qir/converter/declarations/test_classical.py b/tests/qasm3_qir/converter/declarations/test_classical.py index 913ea9b..3563730 100644 --- a/tests/qasm3_qir/converter/declarations/test_classical.py +++ b/tests/qasm3_qir/converter/declarations/test_classical.py @@ -209,17 +209,26 @@ def test_array_assignments(): bool f = true; arr_int[0][1] = a*a; + arr_int[0,1] = a*a; + arr_uint[0][1] = b*b; + arr_uint[0,1] = b*b; + arr_float32[0][1] = c*c; + arr_float32[0,1] = c*c; + arr_float64[0][1] = d*d; + arr_float64[0,1] = d*d; + arr_bool[0][1] = f; + arr_bool[0,1] = f; qubit q; - rx(arr_int[0][1]) q; + rx(arr_int[0,1]) q; rx(arr_uint[0][1]) q; - rx(arr_float32[0][1]) q; + rx(arr_float32[0,1]) q; rx(arr_float64[0][1]) q; - rx(arr_bool[0][1]) q; + rx(arr_bool[0,1]) q; """ a = 2 diff --git a/tests/qasm3_qir/converter/test_switch.py b/tests/qasm3_qir/converter/test_switch.py new file mode 100644 index 0000000..9b698a2 --- /dev/null +++ b/tests/qasm3_qir/converter/test_switch.py @@ -0,0 +1,107 @@ +# Copyright (C) 2024 qBraid +# +# This file is part of the qBraid-SDK +# +# The qBraid-SDK is free software released under the GNU General Public License v3 +# or later. You can redistribute and/or modify it under the terms of the GPL v3. +# See the LICENSE file in the project root or . +# +# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3. + +""" +Module containing unit tests for converting OpenQASM 3 programs +with alias statements to QIR. + +""" + + +from qbraid_qir.qasm3 import qasm3_to_qir +from tests.qir_utils import check_attributes, check_single_qubit_gate_op + + +def test_switch(): + """Test converting OpenQASM 3 program with openqasm3.ast.SwitchStatement.""" + + qasm3_switch_program = """ + OPENQASM 3.0; + include "stdgates.inc"; + + const int i = 5; + qubit q; + + switch(i) { + case 1,3,5,7 { + x q; + } + case 2,4,6,8 { + y q; + } + default { + z q; + } + } + """ + + result = qasm3_to_qir(qasm3_switch_program, name="test") + generated_qir = str(result).splitlines() + + check_attributes(generated_qir, 1) + check_single_qubit_gate_op(generated_qir, 1, [0], "x") + + +def test_switch_default(): + """Test converting OpenQASM 3 program with openqasm3.ast.SwitchStatement and default case.""" + + qasm3_switch_program = """ + OPENQASM 3.0; + include "stdgates.inc"; + + const int i = 10; + qubit q; + + switch(i) { + case 1,3,5,7 { + x q; + } + case 2,4,6,8 { + y q; + } + default { + z q; + } + } + """ + + result = qasm3_to_qir(qasm3_switch_program, name="test") + generated_qir = str(result).splitlines() + + check_attributes(generated_qir, 1) + check_single_qubit_gate_op(generated_qir, 1, [0], "z") + + +def test_switch_identifier_case(): + """Test converting OpenQASM 3 program with openqasm3.ast.SwitchStatement and identifier case.""" + + qasm3_switch_program = """ + OPENQASM 3.0; + include "stdgates.inc"; + + const int i = 4; + const int j = 4; + qubit q; + + switch(i) { + case 6, j { + x q; + } + default { + z q; + } + } + """ + + result = qasm3_to_qir(qasm3_switch_program, name="test") + generated_qir = str(result).splitlines() + + check_attributes(generated_qir, 1) + check_single_qubit_gate_op(generated_qir, 1, [0], "x")