From ce135987133c2a6f627deeeca61c5156d69683f5 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 12 Nov 2024 16:05:39 -0500 Subject: [PATCH 1/2] visualization bug + better error message for emulator code generation. --- .../compiler/codegen/python/emulator_ir.py | 19 ++++++++++++------- src/bloqade/ir/analog_circuit.py | 7 ++++--- src/bloqade/ir/control/field.py | 8 +++++++- tests/test_batch.py | 4 ++-- tests/test_field.py | 2 +- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/bloqade/compiler/codegen/python/emulator_ir.py b/src/bloqade/compiler/codegen/python/emulator_ir.py index 1917817db..914e8daa9 100644 --- a/src/bloqade/compiler/codegen/python/emulator_ir.py +++ b/src/bloqade/compiler/codegen/python/emulator_ir.py @@ -338,6 +338,12 @@ def visit_field_RunTimeVector( self, node: field.RunTimeVector ) -> Dict[int, Decimal]: value = self.assignments[node.name] + for new_index, original_index in enumerate(self.original_index): + if original_index >= len(value): + raise ValueError( + f"Index {original_index} is out of bounds for the runtime vector {node.name}" + ) + return { new_index: Decimal(str(value[original_index])) for new_index, original_index in enumerate(self.original_index) @@ -347,6 +353,12 @@ def visit_field_RunTimeVector( def visit_field_AssignedRunTimeVector( self, node: field.AssignedRunTimeVector ) -> Dict[int, Decimal]: + for new_index, original_index in enumerate(self.original_index): + if original_index >= len(node.value): + raise ValueError( + f"Index {original_index} is out of bounds for the mask vector." + ) + return { new_index: Decimal(str(node.value[original_index])) for new_index, original_index in enumerate(self.original_index) @@ -358,13 +370,6 @@ def visit_field_ScaledLocations( ) -> Dict[int, Decimal]: target_atoms = {} - for location in node.value.keys(): - if location.value >= self.n_sites or location.value < 0: - raise ValueError( - f"Location {location.value} is out of bounds for register with " - f"{self.n_sites} sites." - ) - for new_index, original_index in enumerate(self.original_index): value = node.value.get(field.Location(original_index)) if value is not None and value != 0: diff --git a/src/bloqade/ir/analog_circuit.py b/src/bloqade/ir/analog_circuit.py index 6da9c3bb1..caa6fafbc 100644 --- a/src/bloqade/ir/analog_circuit.py +++ b/src/bloqade/ir/analog_circuit.py @@ -81,6 +81,9 @@ def figure(self, **assignments): # analysis the SpatialModulation information spmod_extracted_data: Dict[str, Tuple[List[int], List[float]]] = {} + def process_names(x): + return int(x.split("[")[-1].split("]")[0]) + for tab in fig_seq.tabs: pulse_name = tab.title field_plots = tab.child.children @@ -101,9 +104,7 @@ def figure(self, **assignments): for ch in channels: ch_data = Spmod_raw_data[Spmod_raw_data.d0 == ch] - sites = list( - map(lambda x: int(x.split("[")[-1].split("]")[0]), ch_data.d1) - ) + sites = list(map(process_names, ch_data.d1)) values = list(ch_data.px.astype(float)) key = f"{pulse_name}.{field_name}.{ch}" diff --git a/src/bloqade/ir/control/field.py b/src/bloqade/ir/control/field.py index 2a0b65ca5..2e362baf5 100644 --- a/src/bloqade/ir/control/field.py +++ b/src/bloqade/ir/control/field.py @@ -122,7 +122,13 @@ def figure(self, **assginment): return get_ir_figure(self, **assginment) def _get_data(self, **assignment): - return [self.name], ["vec"] + locs = [] + values = [] + for i, v in enumerate(self.value): + locs.append(f"{self.name or 'value'}[{i}]") + values.append(str(v)) + + return locs, values def show(self, **assignment): display_ir(self, **assignment) diff --git a/tests/test_batch.py b/tests/test_batch.py index 9fe95937e..d14ca80d8 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -172,7 +172,7 @@ def test_metadata_filter_scalar(): assert filtered_batch.tasks.keys() == {0, 1, 4} - with pytest.raises(ValueError): + with pytest.raises(Exception): filtered_batch = batch.filter_metadata(d=[1, 2, 16, 1j]) @@ -198,7 +198,7 @@ def test_metadata_filter_vector(): filters = dict(d=[1, 8], m=[[0, 1], [1, 0], (0, 0)]) - with pytest.raises(ValueError): + with pytest.raises(Exception): filtered_batch_all = batch.filter_metadata(**filters) diff --git a/tests/test_field.py b/tests/test_field.py index d428c4122..16c8b2f07 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -82,7 +82,7 @@ def test_assigned_runtime_vec(): ) assert x.print_node() == "AssignedRunTimeVector: sss" assert x.children() == cast([Decimal("1.0"), Decimal("2.0")]) - assert x._get_data() == (["sss"], ["vec"]) + assert x._get_data() == (["sss[0]", "sss[1]"], ["1.0", "2.0"]) mystdout = StringIO() p = PP(mystdout) From e3eac1379883b22528573acb577efdcb2019f41a Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 12 Nov 2024 16:12:01 -0500 Subject: [PATCH 2/2] adding back test for location ir. --- src/bloqade/compiler/codegen/python/emulator_ir.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/bloqade/compiler/codegen/python/emulator_ir.py b/src/bloqade/compiler/codegen/python/emulator_ir.py index 914e8daa9..f250a671e 100644 --- a/src/bloqade/compiler/codegen/python/emulator_ir.py +++ b/src/bloqade/compiler/codegen/python/emulator_ir.py @@ -369,6 +369,12 @@ def visit_field_ScaledLocations( self, node: field.ScaledLocations ) -> Dict[int, Decimal]: target_atoms = {} + for location in node.value.keys(): + if location.value >= self.n_sites or location.value < 0: + raise ValueError( + f"Location {location.value} is out of bounds for register with " + f"{self.n_sites} sites." + ) for new_index, original_index in enumerate(self.original_index): value = node.value.get(field.Location(original_index))