From 30fc63a30f0e6912c0049c778348c2b360d7ac1c Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Fri, 29 Sep 2023 11:09:17 -0400 Subject: [PATCH] making Routine's `pydantic` dataclasses. --- src/bloqade/ir/routine/base.py | 11 +++++++---- src/bloqade/ir/routine/bloqade.py | 6 +++--- src/bloqade/ir/routine/braket.py | 8 ++++---- src/bloqade/ir/routine/quera.py | 6 +++--- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/bloqade/ir/routine/base.py b/src/bloqade/ir/routine/base.py index 56338a889..ff0bde798 100644 --- a/src/bloqade/ir/routine/base.py +++ b/src/bloqade/ir/routine/base.py @@ -5,8 +5,8 @@ from bloqade.builder.base import Builder from bloqade.ir.routine.params import Params - -from dataclasses import dataclass +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass from typing import TYPE_CHECKING, Union if TYPE_CHECKING: @@ -29,7 +29,10 @@ def parse(self: "RoutineBase") -> "Routine": return self -@dataclass(frozen=True) +__pydantic_dataclass_config__ = ConfigDict(arbitrary_types_allowed=True) + + +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class RoutineBase(RoutineParse): source: Builder circuit: AnalogCircuit @@ -43,7 +46,7 @@ def __str__(self): return out -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class Routine(RoutineBase): """Result of parsing a completed Builder string.""" diff --git a/src/bloqade/ir/routine/bloqade.py b/src/bloqade/ir/routine/bloqade.py index 51a09a258..eccc7afaa 100644 --- a/src/bloqade/ir/routine/bloqade.py +++ b/src/bloqade/ir/routine/bloqade.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from bloqade.ir.routine.base import RoutineBase +from bloqade.ir.routine.base import RoutineBase, __pydantic_dataclass_config__ from bloqade.builder.typing import LiteralType from bloqade.task.batch import LocalBatch from beartype import beartype @@ -7,13 +7,13 @@ from dataclasses import dataclass -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class BloqadeServiceOptions(RoutineBase): def python(self): return BloqadePythonRoutine(self.source, self.circuit, self.params) -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class BloqadePythonRoutine(RoutineBase): def _compile( self, diff --git a/src/bloqade/ir/routine/braket.py b/src/bloqade/ir/routine/braket.py index c2c718b4b..c0b5f52ae 100644 --- a/src/bloqade/ir/routine/braket.py +++ b/src/bloqade/ir/routine/braket.py @@ -4,14 +4,14 @@ from beartype.typing import Optional, Tuple from bloqade.builder.typing import LiteralType -from bloqade.ir.routine.base import RoutineBase +from bloqade.ir.routine.base import RoutineBase, __pydantic_dataclass_config__ from bloqade.submission.braket import BraketBackend from bloqade.task.batch import LocalBatch, RemoteBatch from bloqade.task.braket_simulator import BraketEmulatorTask from bloqade.task.braket import BraketTask -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class BraketServiceOptions(RoutineBase): def aquila(self) -> "BraketHardwareRoutine": backend = BraketBackend( @@ -23,7 +23,7 @@ def local_emulator(self): return BraketLocalEmulatorRoutine(self.source, self.circuit, self.params) -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class BraketHardwareRoutine(RoutineBase): backend: BraketBackend @@ -163,7 +163,7 @@ def __call__( return self.run(shots, args, name, shuffle, **kwargs) -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class BraketLocalEmulatorRoutine(RoutineBase): def _compile( self, shots: int, args: Tuple[LiteralType, ...] = (), name: Optional[str] = None diff --git a/src/bloqade/ir/routine/quera.py b/src/bloqade/ir/routine/quera.py index 5f6c9fa23..717b90b33 100644 --- a/src/bloqade/ir/routine/quera.py +++ b/src/bloqade/ir/routine/quera.py @@ -3,7 +3,7 @@ import json from bloqade.builder.typing import LiteralType -from bloqade.ir.routine.base import RoutineBase +from bloqade.ir.routine.base import RoutineBase, __pydantic_dataclass_config__ from bloqade.submission.quera import QuEraBackend from bloqade.submission.mock import MockBackend from bloqade.submission.quera_api_client.load_config import load_config @@ -14,7 +14,7 @@ from beartype import beartype -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class QuEraServiceOptions(RoutineBase): @beartype def device(self, config_file: Optional[str], **api_config): @@ -40,7 +40,7 @@ def mock(self, state_file: str = ".mock_state.txt") -> "QuEraHardwareRoutine": return QuEraHardwareRoutine(self.source, self.circuit, self.params, backend) -@dataclass(frozen=True) +@dataclass(frozen=True, config=__pydantic_dataclass_config__) class QuEraHardwareRoutine(RoutineBase): backend: Union[QuEraBackend, MockBackend]