diff --git a/src/bloqade/ir/location/location.py b/src/bloqade/ir/location/location.py index 6a5278195..25cbbe73b 100644 --- a/src/bloqade/ir/location/location.py +++ b/src/bloqade/ir/location/location.py @@ -8,6 +8,7 @@ from beartype import beartype from enum import Enum from numpy.typing import NDArray +from bloqade.submission.ir.capabilities import QuEraCapabilities from bloqade.visualization import get_atom_arrangement_figure from bloqade.visualization import display_ir @@ -517,6 +518,34 @@ def n_dims(self): def __str__(self): return "ParallelRegister:\n" + self.atom_arrangement.__str__() + def _compile_to_list( + self, __capabilities: Optional[QuEraCapabilities] = None, **assignments + ): + from bloqade.compiler.rewrite.common import AssignBloqadeIR + from bloqade.compiler.codegen.hardware import GenerateLattice + from bloqade.submission.capabilities import get_capabilities + + lattice_data = GenerateLattice(__capabilities or get_capabilities()).emit( + AssignBloqadeIR(assignments).emit(self) + ) + + list_of_locations = ListOfLocations() + for site, filling in zip(lattice_data.sites, lattice_data.filling): + list_of_locations = list_of_locations.add_position(site, filling == 1) + + return list_of_locations + + def figure( + self, + fig_kwargs=None, + capabilities: Optional[QuEraCapabilities] = None, + **assignments, + ): + return self._compile_to_list(capabilities).figure(fig_kwargs) + + def show(self, **assignments) -> None: + display_ir(self, assignments) + @dataclass(init=False) class ParallelRegisterInfo: diff --git a/tests/test_lattice.py b/tests/test_lattice.py index b94e1eb9d..aa93a1a1e 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -1,11 +1,14 @@ +from decimal import Decimal import bloqade.ir as ir -from bloqade.ir.location import ListOfLocations, AtomArrangement +from bloqade.ir.location import ListOfLocations, AtomArrangement, ParallelRegister from bloqade.ir.location import Square from bloqade.constants import RB_C6 from bloqade import cast import pytest import numpy as np +from bloqade.submission.capabilities import get_capabilities + def test_rydberg_interactions(): geometry = ListOfLocations([(0, 0), (1, 0), (0, 1), (1, 1)]).scale(5.0) @@ -113,3 +116,18 @@ def test_internal_base_listofloc(): with pytest.raises(NotImplementedError): lattice.n_dims + + +def test_parallel_register(): + lat = ListOfLocations([(0, 0)]) + reg = ParallelRegister(lat, cast(5)) + + capabilities = get_capabilities() + capabilities.capabilities.lattice.area.width = Decimal("1e-5") + capabilities.capabilities.lattice.area.height = Decimal("1e-5") + + list_of_locations = reg._compile_to_list(capabilities) + expected = ListOfLocations().add_position( + [(0, 0), (0, 5), (0, 10), (5, 0), (5, 5), (5, 10), (10, 0), (10, 5), (10, 10)] + ) + assert set(expected.enumerate()) == set(list_of_locations.enumerate())