From b14c97db3879a754a31b84346d3a613bae0363c2 Mon Sep 17 00:00:00 2001 From: TheGupta2012 Date: Thu, 5 Dec 2024 12:04:26 +0530 Subject: [PATCH] add remove includes --- src/pyqasm/modules/base.py | 34 ++++++++++++++++++++++++++++---- tests/test_include.py | 40 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index fc08b29..488ced7 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -120,7 +120,7 @@ def unrolled_ast(self, value: Program): """Setter for the unrolled AST""" self._unrolled_ast = value - def has_measurements(self): + def has_measurements(self) -> bool: """Check if the module has any measurement operations.""" if self._has_measurements is None: self._has_measurements = False @@ -137,7 +137,7 @@ def has_measurements(self): break return self._has_measurements - def remove_measurements(self, in_place: bool = True): + def remove_measurements(self, in_place: bool = True) -> Optional["QasmModule"]: """Remove the measurement operations Args: @@ -172,7 +172,7 @@ def remove_measurements(self, in_place: bool = True): return curr_module - def has_barriers(self): + def has_barriers(self) -> bool: """Check if the module has any barrier operations. Args: @@ -196,7 +196,7 @@ def has_barriers(self): break return self._has_barriers - def remove_barriers(self, in_place: bool = True): + def remove_barriers(self, in_place: bool = True) -> Optional["QasmModule"]: """Remove the barrier operations Args: @@ -226,6 +226,32 @@ def remove_barriers(self, in_place: bool = True): return curr_module + def remove_includes(self, in_place=True) -> Optional["QasmModule"]: + """Remove the include statements from the module + + Args: + in_place (bool): Flag to indicate if the removal should be done in place. + + Returns: + QasmModule: The module with the includes removed if in_place is False, None otherwise + """ + stmt_list = ( + self._statements + if len(self._unrolled_ast.statements) == 0 + else self._unrolled_ast.statements + ) + stmts_without_includes = [ + stmt for stmt in stmt_list if not isinstance(stmt, qasm3_ast.Include) + ] + curr_module = self + if not in_place: + curr_module = self.copy() + + curr_module._statements = stmts_without_includes + curr_module._unrolled_ast.statements = stmts_without_includes + + return curr_module + def depth(self): """Calculate the depth of the unrolled openqasm program. diff --git a/tests/test_include.py b/tests/test_include.py index aab0779..ef7c29e 100644 --- a/tests/test_include.py +++ b/tests/test_include.py @@ -63,3 +63,43 @@ def test_repeated_include_raises_error(): with pytest.raises(ValidationError): module = loads(qasm_str) module.validate() + + +def test_remove_includes(): + qasm_str = """ + OPENQASM 3.0; + include "stdgates.inc"; + include "random.qasm"; + + qubit[2] q; + h q; + """ + expected_qasm_str = """ + OPENQASM 3.0; + qubit[2] q; + h q[0]; + h q[1]; + """ + module = loads(qasm_str) + module.remove_includes() + module.unroll() + check_unrolled_qasm(dumps(module), expected_qasm_str) + + +def test_remove_includes_without_include(): + qasm_str = """ + OPENQASM 3.0; + + qubit[2] q; + h q; + """ + expected_qasm_str = """ + OPENQASM 3.0; + qubit[2] q; + h q[0]; + h q[1]; + """ + module = loads(qasm_str) + module = module.remove_includes(in_place=False) + module.unroll() + check_unrolled_qasm(dumps(module), expected_qasm_str)