Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch assign list and dict #640

Merged
merged 6 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 98 additions & 63 deletions src/bloqade/builder/assign.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import repeat, starmap
from beartype.typing import Optional, List
from itertools import repeat
from beartype.typing import Optional, List, Dict, Set, Sequence, Union
from bloqade.builder.typing import ParamType
from bloqade.builder.base import Builder
from bloqade.builder.pragmas import Parallelizable, AddArgs, BatchAssignable
Expand All @@ -9,50 +9,56 @@
import numpy as np


def cast_scalar_param(value: ParamType, name: str) -> Decimal:
if isinstance(value, (Real, Decimal)):
return Decimal(str(value))
class CastParams:
def __init__(self, n_sites: int, scalar_vars: Set[str], vector_vars: Set[str]):
self.n_sites = n_sites
self.scalar_vars = scalar_vars
self.vector_vars = vector_vars

raise TypeError(
f"assign parameter '{name}' must be a real number, "
f"found type: {type(value)}"
)
def cast_scalar_param(self, value: ParamType, name: str) -> Decimal:
if isinstance(value, (Real, Decimal)):
return Decimal(str(value))

raise TypeError(
f"assign parameter '{name}' must be a real number, "
f"found type: {type(value)}"
)

def cast_vector_param(value: List[ParamType], name: str) -> List[Decimal]:
if isinstance(value, np.ndarray):
value = value.tolist()
def cast_vector_param(
self,
value: Union[np.ndarray, List[ParamType]],
name: str,
) -> List[Decimal]:
if isinstance(value, np.ndarray):
value = value.tolist()

if isinstance(value, (list, tuple)):
return list(starmap(cast_scalar_param, zip(value, repeat(name))))

raise TypeError(
f"assign parameter '{name}' must be a list of real numbers, "
f"found type: {type(value)}"
)


def cast_batch_scalar_param(value: List[ParamType], name: str) -> List[Decimal]:
if isinstance(value, np.ndarray):
value = value.tolist()

if isinstance(value, (list, tuple)):
return list(starmap(cast_scalar_param, zip(value, repeat(name))))
if isinstance(value, (list, tuple)):
if len(value) != self.n_sites:
raise ValueError(
f"assign parameter '{name}' must be a list of length "
f"{self.n_sites}, found length: {len(value)}"
)
return list(map(self.cast_scalar_param, value, repeat(name, len(value))))

raise TypeError(
f"batch_assign parameter '{name}' must be a list of real numbers, "
f"found type: {type(value)}"
)
raise TypeError(
f"assign parameter '{name}' must be a list of real numbers, "
f"found type: {type(value)}"
)

def cast_params(self, params: Dict[str, ParamType]) -> Dict[str, ParamType]:
checked_params = {}

def cast_batch_vector_param(value: List[ParamType], name: str) -> List[List[Decimal]]:
if isinstance(value, (list, tuple)):
return list(starmap(cast_vector_param, zip(value, repeat(name))))
for name, value in params.items():
if name not in self.scalar_vars and name not in self.vector_vars:
raise ValueError(
f"assign parameter '{name}' is not found in analog circuit."
)
if name in self.vector_vars:
checked_params[name] = self.cast_vector_param(value, name)
else:
checked_params[name] = self.cast_scalar_param(value, name)

raise TypeError(
f"batch_assign parameter '{name}' must be a list of lists of real numbers, "
f"found type: {type(value)}"
)
return checked_params


class AssignBase(Builder):
Expand All @@ -62,49 +68,78 @@ class AssignBase(Builder):
class Assign(BatchAssignable, AddArgs, Parallelizable, BackendRoute, AssignBase):
__match_args__ = ("_assignments", "__parent__")

def __init__(self, parent: Optional[Builder] = None, **assignments) -> None:
def __init__(
self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

super().__init__(parent)

circuit = self.parse_circuit()
vars = ScanVariablesAnalogCircuit().emit(circuit)
variables = ScanVariablesAnalogCircuit().emit(circuit)

self._assignments = {}
for name, value in assignments.items():
if name not in vars.scalar_vars and name not in vars.vector_vars:
raise ValueError(
f"batch_assign parameter '{name}' is not found in analog circuit."
)
if name in vars.vector_vars:
self._assignments[name] = cast_vector_param(value, name)
else:
self._assignments[name] = cast_scalar_param(value, name)
self._static_params = CastParams(
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
).cast_params(assignments)


class BatchAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
__match_args__ = ("_assignments", "__parent__")

def __init__(self, parent: Optional[Builder] = None, **assignments) -> None:
def __init__(
self, assignments: Dict[str, List[ParamType]], parent: Optional[Builder] = None
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

super().__init__(parent)

circuit = self.parse_circuit()
vars = ScanVariablesAnalogCircuit().emit(circuit)

self._assignments = {}
for name, values in assignments.items():
if name not in vars.scalar_vars and name not in vars.vector_vars:
raise ValueError(
f"batch_assign parameter '{name}' is not found in analog circuit."
)
if name in vars.vector_vars:
self._assignments[name] = cast_batch_vector_param(values, name)
else:
self._assignments[name] = cast_batch_scalar_param(values, name)
variables = ScanVariablesAnalogCircuit().emit(circuit)

if not len(np.unique(list(map(len, assignments.values())))) == 1:
raise ValueError(
"all the assignment variables need to have same number of elements."
)

tuple_iterators = [
zip(repeat(name), values) for name, values in assignments.items()
]

caster = CastParams(
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
)

self._batch_params = list(
map(caster.cast_params, map(dict, zip(*tuple_iterators)))
)


class ListAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
def __init__(
self,
batch_params: Sequence[Dict[str, ParamType]],
parent: Optional[Builder] = None,
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

super().__init__(parent)

circuit = self.parse_circuit()
variables = ScanVariablesAnalogCircuit().emit(circuit)
caster = CastParams(
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
)

keys = set([])
for params in batch_params:
keys.update(params.keys())

for batch_num, params in enumerate(batch_params):
curr_keys = set(params.keys())
missing_keys = keys.difference(curr_keys)
if missing_keys:
raise ValueError(
f"Batch {batch_num} missing key(s): {tuple(missing_keys)}."
)

self._batch_params = list(map(caster.cast_params, batch_params))
17 changes: 6 additions & 11 deletions src/bloqade/builder/parse/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from bloqade.builder.field import Field, Detuning, RabiAmplitude, RabiPhase
from bloqade.builder.spatial import SpatialModulation, Location, Uniform, Var, Scale
from bloqade.builder.waveform import WaveformPrimitive, Slice, Record, Sample, Fn
from bloqade.builder.assign import Assign, BatchAssign
from bloqade.builder.assign import Assign, BatchAssign, ListAssign
from bloqade.builder.args import Args
from bloqade.builder.parallelize import Parallelize
from bloqade.builder.parse.stream import BuilderNode, BuilderStream

import bloqade.ir as ir
from itertools import repeat
from typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set
from beartype.typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set

if TYPE_CHECKING:
from bloqade.ir.routine.params import ParamType
Expand Down Expand Up @@ -186,6 +184,7 @@ def read_pragmas(self) -> None:
pragma_types = (
Assign,
BatchAssign,
ListAssign,
Args,
Parallelize,
)
Expand All @@ -197,13 +196,9 @@ def read_pragmas(self) -> None:
node = curr.node

if isinstance(node, Assign):
self.static_params = dict(node._assignments)
elif isinstance(node, BatchAssign):
tuple_iterators = [
zip(repeat(name), values)
for name, values in node._assignments.items()
]
self.batch_params = list(map(dict, zip(*tuple_iterators)))
self.static_params = dict(node._static_params)
elif isinstance(node, BatchAssign) or isinstance(node, ListAssign):
self.batch_params = node._batch_params
elif isinstance(node, Args):
order = node._order

Expand Down
24 changes: 17 additions & 7 deletions src/bloqade/builder/pragmas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from beartype.typing import List, Union, TYPE_CHECKING
from beartype.typing import List, Dict, Union, TYPE_CHECKING
from bloqade.builder.typing import LiteralType, ParamType
from bloqade.ir.scalar import Variable

if TYPE_CHECKING:
from bloqade.builder.assign import Assign, BatchAssign
from bloqade.builder.assign import Assign, BatchAssign, ListAssign

Check warning on line 6 in src/bloqade/builder/pragmas.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/builder/pragmas.py#L6

Added line #L6 was not covered by tests
from bloqade.builder.parallelize import Parallelize
from bloqade.builder.args import Args

Expand Down Expand Up @@ -39,14 +39,24 @@
"""
from bloqade.builder.assign import Assign

return Assign(parent=self, **assignments)
return Assign(assignments, parent=self)


class BatchAssignable:
def batch_assign(self, **assignments: ParamType) -> "BatchAssign":
from bloqade.builder.assign import BatchAssign

return BatchAssign(parent=self, **assignments)
def batch_assign(
self,
__batch_params: List[Dict[str, ParamType]] = [],
**assignments: List[ParamType],
) -> Union["BatchAssign", "ListAssign"]:
from bloqade.builder.assign import BatchAssign, ListAssign

if len(__batch_params) > 0 and assignments:
raise ValueError("batch_params and assignments cannot be used together.")

Check warning on line 54 in src/bloqade/builder/pragmas.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/builder/pragmas.py#L54

Added line #L54 was not covered by tests

if len(__batch_params) > 0:
return ListAssign(__batch_params, parent=self)
else:
return BatchAssign(assignments, parent=self)


class Parallelizable:
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/ir/location/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ def __init__(
else:
self.location_list.append(LocationInfo(ele, True))

if location_list:
if self.location_list:
self.__n_atoms = sum(
1 for loc in self.location_list if loc.filling == SiteFilling.filled
)
self.__n_sites = len(self.location_list)
self.__n_vacant = self.__n_sites - self.__n_atoms
self.__n_dims = len(self.location_list[0].position)
else:
self.__n_sites = 0
self.__n_atoms = 0
self.__n_dims = None

Expand Down
22 changes: 16 additions & 6 deletions src/bloqade/ir/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,14 @@
Printer(p).print(self, cycle)

@validator("name", allow_reuse=True)
def validate_name(cls, v):
check_variable_name(v)
return v
def validate_name(cls, name):
check_variable_name(name)
if name in ["__batch_params"]:
raise ValidationError(
"Cannot use reserved name `__batch_params` for variable name"
)

return name


@dataclass(frozen=True, repr=False)
Expand All @@ -430,9 +435,14 @@
return f"AssignedVariable: {self.name} = {self.value}"

@validator("name", allow_reuse=True)
def validate_name(cls, v):
check_variable_name(v)
return v
def validate_name(cls, name):
check_variable_name(name)
if name in ["__batch_params"]:
raise ValidationError(

Check warning on line 441 in src/bloqade/ir/scalar.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/ir/scalar.py#L441

Added line #L441 was not covered by tests
"Cannot use reserved name `__batch_params` for variable name"
)

return name


@dataclass(frozen=True, repr=False)
Expand Down
Loading
Loading