Skip to content

Commit

Permalink
debugging new casting.
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Sep 27, 2023
1 parent 96f1b08 commit 6df9f35
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 82 deletions.
148 changes: 84 additions & 64 deletions src/bloqade/builder/assign.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import repeat, starmap
from beartype.typing import Optional, List
from itertools import repeat
from beartype.typing import Optional, List, Dict, Set
from bloqade.builder.typing import ParamType
from bloqade.builder.base import Builder
from bloqade.builder.pragmas import Parallelizable, AddArgs, BatchAssignable
Expand All @@ -9,50 +9,55 @@
import numpy as np


def cast_scalar_param(value: ParamType, name: str) -> Decimal:
if isinstance(value, (Real, Decimal)):
return Decimal(str(value))
class CastParams:
def __init__(self, n_sites: int, scalar_vars: Set[str], vector_vars: Set[str]):
self.n_sites = n_sites
self.scalar_vars = scalar_vars
self.vector_vars = vector_vars

raise TypeError(
f"assign parameter '{name}' must be a real number, "
f"found type: {type(value)}"
)
def cast_scalar_param(self, value: ParamType, name: str) -> Decimal:
if isinstance(value, (Real, Decimal)):
return Decimal(str(value))

raise TypeError(
f"assign parameter '{name}' must be a real number, "
f"found type: {type(value)}"
)

def cast_vector_param(value: List[ParamType], name: str) -> List[Decimal]:
if isinstance(value, np.ndarray):
value = value.tolist()
def cast_vector_param(
self,
value: List[ParamType],
name: str,
) -> List[Decimal]:
if isinstance(value, np.ndarray):
value = value.tolist()

if isinstance(value, (list, tuple)):
return list(starmap(cast_scalar_param, zip(value, repeat(name))))

raise TypeError(
f"assign parameter '{name}' must be a list of real numbers, "
f"found type: {type(value)}"
)


def cast_batch_scalar_param(value: List[ParamType], name: str) -> List[Decimal]:
if isinstance(value, np.ndarray):
value = value.tolist()

if isinstance(value, (list, tuple)):
return list(starmap(cast_scalar_param, zip(value, repeat(name))))

raise TypeError(
f"batch_assign parameter '{name}' must be a list of real numbers, "
f"found type: {type(value)}"
)
if isinstance(value, (list, tuple)):
if len(value) != self.n_sites:
raise ValueError(
f"assign parameter '{name}' must be a list of length "
f"{self.n_sites}, found length: {len(value)}"
)
return list(map(self.cast_scalar_param, value, repeat(name, len(value))))

raise TypeError(
f"assign parameter '{name}' must be a list of real numbers, "
f"found type: {type(value)}"
)

def cast_batch_vector_param(value: List[ParamType], name: str) -> List[List[Decimal]]:
if isinstance(value, (list, tuple)):
return list(starmap(cast_vector_param, zip(value, repeat(name))))
def cast_params(self, params: Dict[str, ParamType]) -> Dict[str, ParamType]:
checked_params = {}
for name, value in params.items():
if name not in self.scalar_vars and name not in self.vector_vars:
raise ValueError(
f"assign parameter '{name}' is not found in analog circuit."
)
if name in self.vector_vars:
checked_params[name] = self.cast_vector_param(value, name)
else:
checked_params[name] = self.cast_scalar_param(value, name)

raise TypeError(
f"batch_assign parameter '{name}' must be a list of lists of real numbers, "
f"found type: {type(value)}"
)
return checked_params


class AssignBase(Builder):
Expand All @@ -62,49 +67,64 @@ class AssignBase(Builder):
class Assign(BatchAssignable, AddArgs, Parallelizable, BackendRoute, AssignBase):
__match_args__ = ("_assignments", "__parent__")

def __init__(self, parent: Optional[Builder] = None, **assignments) -> None:
def __init__(
self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

super().__init__(parent)

circuit = self.parse_circuit()
vars = ScanVariablesAnalogCircuit().emit(circuit)
variables = ScanVariablesAnalogCircuit().emit(circuit)

self._assignments = {}
for name, value in assignments.items():
if name not in vars.scalar_vars and name not in vars.vector_vars:
raise ValueError(
f"batch_assign parameter '{name}' is not found in analog circuit."
)
if name in vars.vector_vars:
self._assignments[name] = cast_vector_param(value, name)
else:
self._assignments[name] = cast_scalar_param(value, name)
self._static_params = CastParams(
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
).cast_params(assignments)


class BatchAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
__match_args__ = ("_assignments", "__parent__")

def __init__(self, parent: Optional[Builder] = None, **assignments) -> None:
def __init__(
self, assignments: Dict[str, ParamType], parent: Optional[Builder] = None
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

super().__init__(parent)

circuit = self.parse_circuit()
vars = ScanVariablesAnalogCircuit().emit(circuit)

self._assignments = {}
for name, values in assignments.items():
if name not in vars.scalar_vars and name not in vars.vector_vars:
raise ValueError(
f"batch_assign parameter '{name}' is not found in analog circuit."
)
if name in vars.vector_vars:
self._assignments[name] = cast_batch_vector_param(values, name)
else:
self._assignments[name] = cast_batch_scalar_param(values, name)
variables = ScanVariablesAnalogCircuit().emit(circuit)

if not len(np.unique(list(map(len, assignments.values())))) == 1:
raise ValueError(
"all the assignment variables need to have same number of elements."
)

tuple_iterators = [
zip(repeat(name), values) for name, values in assignments.items()
]

caster = CastParams(
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
)

self._batch_params = list(
map(caster.cast_params, map(dict, zip(*tuple_iterators)))
)


class ListAssign(AddArgs, Parallelizable, BackendRoute, AssignBase):
def __init__(
self, batch_params: List[Dict[str, ParamType]], parent: Optional[Builder] = None
) -> None:
from bloqade.ir.analysis.scan_variables import ScanVariablesAnalogCircuit

super().__init__(parent)

circuit = self.parse_circuit()
variables = ScanVariablesAnalogCircuit().emit(circuit)
caster = CastParams(
circuit.register.n_sites, variables.scalar_vars, variables.vector_vars
)

self._batch_params = list(map(caster.cast_params, batch_params))
17 changes: 6 additions & 11 deletions src/bloqade/builder/parse/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from bloqade.builder.field import Field, Detuning, RabiAmplitude, RabiPhase
from bloqade.builder.spatial import SpatialModulation, Location, Uniform, Var, Scale
from bloqade.builder.waveform import WaveformPrimitive, Slice, Record, Sample, Fn
from bloqade.builder.assign import Assign, BatchAssign
from bloqade.builder.assign import Assign, BatchAssign, ListAssign
from bloqade.builder.args import Args
from bloqade.builder.parallelize import Parallelize
from bloqade.builder.parse.stream import BuilderNode, BuilderStream

import bloqade.ir as ir
from itertools import repeat
from typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set
from beartype.typing import TYPE_CHECKING, Tuple, Union, Dict, List, Optional, Set

if TYPE_CHECKING:
from bloqade.ir.routine.params import ParamType
Expand Down Expand Up @@ -186,6 +184,7 @@ def read_pragmas(self) -> None:
pragma_types = (
Assign,
BatchAssign,
ListAssign,
Args,
Parallelize,
)
Expand All @@ -197,13 +196,9 @@ def read_pragmas(self) -> None:
node = curr.node

if isinstance(node, Assign):
self.static_params = dict(node._assignments)
elif isinstance(node, BatchAssign):
tuple_iterators = [
zip(repeat(name), values)
for name, values in node._assignments.items()
]
self.batch_params = list(map(dict, zip(*tuple_iterators)))
self.static_params = dict(node._static_params)
elif isinstance(node, BatchAssign) or isinstance(node, ListAssign):
self.batch_params = node._batch_params
elif isinstance(node, Args):
order = node._order

Expand Down
24 changes: 17 additions & 7 deletions src/bloqade/builder/pragmas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from beartype.typing import List, Union, TYPE_CHECKING
from beartype.typing import List, Dict, Union, TYPE_CHECKING
from bloqade.builder.typing import LiteralType, ParamType
from bloqade.ir.scalar import Variable

if TYPE_CHECKING:
from bloqade.builder.assign import Assign, BatchAssign
from bloqade.builder.assign import Assign, BatchAssign, ListAssign
from bloqade.builder.parallelize import Parallelize
from bloqade.builder.args import Args

Expand Down Expand Up @@ -39,14 +39,24 @@ def assign(self, **assignments) -> "Assign":
"""
from bloqade.builder.assign import Assign

return Assign(parent=self, **assignments)
return Assign(assignments, parent=self)


class BatchAssignable:
def batch_assign(self, **assignments: ParamType) -> "BatchAssign":
from bloqade.builder.assign import BatchAssign

return BatchAssign(parent=self, **assignments)
def batch_assign(
self,
batch_params: List[Dict[str, ParamType]] = [],
**assignments: List[ParamType],
) -> Union["BatchAssign", "ListAssign"]:
from bloqade.builder.assign import BatchAssign, ListAssign

if len(batch_params) > 0 and assignments:
raise ValueError("batch_params and assignments cannot be used together.")

if len(batch_params) > 0:
return ListAssign(batch_params, parent=self)
else:
return BatchAssign(assignments, parent=self)


class Parallelizable:
Expand Down

0 comments on commit 6df9f35

Please sign in to comment.