Skip to content

Commit

Permalink
Merge branch 'main' into 600-beartypes-multiple-dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Sep 27, 2023
2 parents 953b703 + cb429e2 commit 691eafe
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 26 deletions.
30 changes: 15 additions & 15 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions src/bloqade/builder/pragmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ def assign(self, **assignments) -> "Assign":
class BatchAssignable:
def batch_assign(
self,
batch_params: List[Dict[str, ParamType]] = [],
__batch_params: List[Dict[str, ParamType]] = [],
**assignments: List[ParamType],
) -> Union["BatchAssign", "ListAssign"]:
from bloqade.builder.assign import BatchAssign, ListAssign

if len(batch_params) > 0 and assignments:
if len(__batch_params) > 0 and assignments:
raise ValueError("batch_params and assignments cannot be used together.")

if len(batch_params) > 0:
return ListAssign(batch_params, parent=self)
if len(__batch_params) > 0:
return ListAssign(__batch_params, parent=self)
else:
return BatchAssign(assignments, parent=self)

Expand Down
22 changes: 16 additions & 6 deletions src/bloqade/ir/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,14 @@ def _repr_pretty_(self, p, cycle):
Printer(p).print(self, cycle)

@validator("name", allow_reuse=True)
def validate_name(cls, v):
check_variable_name(v)
return v
def validate_name(cls, name):
check_variable_name(name)
if name in ["__batch_params"]:
raise ValidationError(
"Cannot use reserved name `__batch_params` for variable name"
)

return name


@dataclass(frozen=True, repr=False)
Expand All @@ -430,9 +435,14 @@ def print_node(self):
return f"AssignedVariable: {self.name} = {self.value}"

@validator("name", allow_reuse=True)
def validate_name(cls, v):
check_variable_name(v)
return v
def validate_name(cls, name):
check_variable_name(name)
if name in ["__batch_params"]:
raise ValidationError(
"Cannot use reserved name `__batch_params` for variable name"
)

return name


@dataclass(frozen=True, repr=False)
Expand Down
42 changes: 41 additions & 1 deletion tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,50 @@ def test_assign_error():
)

list_dict = [dict(c=1, t=10), dict(c=2, t=20), dict(c=np, t=30)]

with pytest.raises(TypeError):
start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict)

list_dict = [dict(c=1, t=10), dict(c=2, t=20), dict(t=30)]
with pytest.raises(ValueError):
start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict)

list_dict = [dict(c=1, t=10, f=1), dict(c=2, t=20, f=2), dict(t=30, c=3, f=3)]
with pytest.raises(ValueError):
start.rydberg.detuning.uniform.constant("c", "t").batch_assign(list_dict)

dict_list = dict(
c=[1, 2, 3], t=[10, 20, 30], mask=[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
)

with pytest.raises(ValueError):
start.rydberg.detuning.uniform.constant("c", "t").batch_assign(**dict_list)

list_dict = [
dict(c=1, t=10, mask=[1, 2, 3]),
dict(c=2, t=20, mask=[4, 5, 6]),
dict(t=30, c=3, mask=[7, 8, 9]),
]

with pytest.raises(ValueError):
start.add_position([(0, 0), (0, 6)]).rydberg.detuning.var("mask").constant(
"c", "t"
).batch_assign(list_dict)

# happy path is to have a list of dicts with the same keys
start.add_position([(0, 0), (0, 6), (0, 12)]).rydberg.detuning.var("mask").constant(
"c", "t"
).batch_assign(list_dict)

list_dict = [
dict(c=1, t=10, mask=np.array([1, 2, 3])),
dict(c=2, t=20, mask=np.array([4, 5, 6])),
dict(t=30, c=3, mask=np.array([7, 8, 9])),
]
with pytest.raises(ValueError):
start.add_position([(0, 0), (0, 6)]).rydberg.detuning.var("mask").constant(
"c", "t"
).batch_assign(list_dict)

with pytest.raises(TypeError):
(
start.add_position((0, 0))
Expand Down
4 changes: 4 additions & 0 deletions tests/test_scalar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pydantic import ValidationError
from bloqade import cast, var
import bloqade.ir.scalar as scalar
import pytest
Expand All @@ -19,6 +20,9 @@ def test_var():
with pytest.raises(ValueError):
var("a*b")

with pytest.raises(ValidationError):
var("__batch_params")

Vv = var("a")
vs = var(Vv)

Expand Down

0 comments on commit 691eafe

Please sign in to comment.