Commit 60570c2

Merge branch 'main' into rydberg-h-factory
2 parents: 4d133b6 + 541c2fe

2 files changed: +121 lines, -119 lines

src/bloqade/task/batch.py

Lines changed: 76 additions & 13 deletions
@@ -1,4 +1,4 @@
-from bloqade.serialize import Serializer, dumps
+from bloqade.serialize import Serializer
 from bloqade.task.base import Report
 from bloqade.task.quera import QuEraTask
 from bloqade.task.braket import BraketTask
@@ -15,19 +15,18 @@

 # from bloqade.submission.base import ValidationError

-from beartype.typing import Union, Optional, Dict, Any
+from beartype.typing import Union, Optional, Dict, Any, List
 from beartype import beartype
 from collections import OrderedDict
 from itertools import product
-import json
 import traceback
 import datetime
 import sys
 import os
 import warnings
 import pandas as pd
 import numpy as np
-from dataclasses import dataclass
+from dataclasses import dataclass, field


 class Serializable:
@@ -39,6 +38,8 @@ def json(self, **options) -> str:
         JSON string

         """
+        from bloqade import dumps
+
         return dumps(self, **options)


@@ -184,6 +185,63 @@ def _deserializer(d: Dict[str, Any]) -> LocalBatch:
     return LocalBatch(**d)


+@dataclass
+@Serializer.register
+class TaskError(Serializable):
+    exception_type: str
+    stack_trace: str
+
+
+@dataclass
+@Serializer.register
+class BatchErrors(Serializable):
+    task_errors: OrderedDict[int, TaskError] = field(
+        default_factory=lambda: OrderedDict([])
+    )
+
+    @beartype
+    def print_errors(self, task_indices: Union[List[int], int]) -> str:
+        return str(self.get_errors(task_indices))
+
+    @beartype
+    def get_errors(self, task_indices: Union[List[int], int]):
+        return BatchErrors(
+            task_errors=OrderedDict(
+                [
+                    (task_index, self.task_errors[task_index])
+                    for task_index in task_indices
+                    if task_index in self.task_errors
+                ]
+            )
+        )
+
+    def __str__(self) -> str:
+        output = ""
+        for task_index, task_error in self.task_errors.items():
+            output += (
+                f"Task {task_index} failed to submit with error: "
+                f"{task_error.exception_type}\n"
+                f"{task_error.stack_trace}"
+            )
+
+        return output
+
+
+@BatchErrors.set_serializer
+def _serialize(self: BatchErrors) -> Dict[str, List]:
+    return {
+        "task_errors": [
+            (task_number, task_error)
+            for task_number, task_error in self.task_errors.items()
+        ]
+    }
+
+
+@BatchErrors.set_deserializer
+def _deserialize(obj: dict) -> BatchErrors:
+    return BatchErrors(task_errors=OrderedDict(obj["task_errors"]))
+
+
 # this class get collection of tasks
 # basically behaves as a psudo queuing system
 # the user only need to store this objecet
@@ -321,6 +379,8 @@ def resubmit(self, shuffle_submit_order: bool = True) -> "RemoteBatch":
     def _submit(
         self, shuffle_submit_order: bool = True, ignore_submission_error=False, **kwargs
     ) -> "RemoteBatch":
+        from bloqade import save
+
         # online, non-blocking
         if shuffle_submit_order:
             submission_order = np.random.permutation(list(self.tasks.keys()))
@@ -333,7 +393,7 @@

         ## upon submit() should validate for Both backends
         ## and throw errors when fail.
-        errors = OrderedDict()
+        errors = BatchErrors()
         shuffled_tasks = OrderedDict()
         for task_index in submission_order:
             task = self.tasks[task_index]
@@ -342,17 +402,18 @@
                 task.submit(**kwargs)
             except BaseException as error:
                 # record the error in the error dict
-                errors[int(task_index)] = {
-                    "exception_type": error.__class__.__name__,
-                    "stack trace": traceback.format_exc(),
-                }
+                errors.task_errors[int(task_index)] = TaskError(
+                    exception_type=error.__class__.__name__,
+                    stack_trace=traceback.format_exc(),
+                )
+
                 task.task_result_ir = QuEraTaskResults(
                     task_status=QuEraTaskStatusCode.Unaccepted
                 )

         self.tasks = shuffled_tasks  # permute order using dump way

-        if errors:
+        if len(errors.task_errors) > 0:
             time_stamp = datetime.datetime.now()

             if "win" in sys.platform:
@@ -369,8 +430,8 @@
             # cloud_batch_result.save_json(future_file, indent=2)
             # saving ?

-            with open(error_file, "w") as f:
-                json.dump(errors, f, indent=2)
+            save(errors, error_file)
+            save(self, future_file)

             if ignore_submission_error:
                 warnings.warn(
@@ -382,7 +443,9 @@
                 )
             else:
                 raise RemoteBatch.SubmissionException(
-                    "One or more error(s) occured during submission, please see "
+                    str(errors)
+                    + "\n"
+                    + "One or more error(s) occured during submission, please see "
                     "the following files for more information:\n"
                     f" - {os.path.join(cwd, future_file)}\n"
                     f" - {os.path.join(cwd, error_file)}\n"

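The changes above add the TaskError and BatchErrors dataclasses and wire them into RemoteBatch._submit, which now records one TaskError per failed submission, prepends the collected errors to the SubmissionException message, and persists both the errors and the partial batch with save. A minimal usage sketch of the new classes (assuming they are importable from bloqade.task.batch alongside RemoteBatch; the task index and messages below are invented for illustration):

    from bloqade.task.batch import BatchErrors, TaskError

    # Populate a BatchErrors container the same way _submit does when a task
    # raises during submission: keyed by task index, one TaskError per failure.
    errors = BatchErrors()
    errors.task_errors[0] = TaskError(
        exception_type="ValueError",
        stack_trace="Traceback (most recent call last): ...",
    )

    # __str__ concatenates each task's exception type and stack trace; this is
    # the text _submit now prepends to the SubmissionException message.
    print(str(errors))

    # get_errors / print_errors restrict the report to a subset of task indices.
    print(errors.print_errors([0]))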
tests/test_task.py

Lines changed: 45 additions & 106 deletions
@@ -1,121 +1,60 @@
-from bloqade.ir import Sequence, rydberg, detuning, Uniform, Linear, ScaledLocations
-from bloqade.ir.location import Square
+from bloqade.atom_arrangement import Chain
+from unittest.mock import patch
+import bloqade.ir.routine.quera
+from bloqade.task.batch import RemoteBatch
+import glob
+import os

-# import bloqade.lattice as lattice
+import pytest

-n_atoms = 11
-lattice_const = 5.9

-rabi_amplitude_values = [0.0, 15.8, 15.8, 0.0]
-rabi_detuning_values = [-16.33, -16.33, "delta_end", "delta_end"]
-durations = [0.8, "sweep_time", 0.8]
+@patch("bloqade.ir.routine.quera.MockBackend")
+def test_batch_error(*args):
+    backend = bloqade.ir.routine.quera.MockBackend()

-ordered_state_2D_prog = (
-    Square(n_atoms, lattice_const)
-    .rydberg.rabi.amplitude.uniform.piecewise_linear(durations, rabi_amplitude_values)
-    .detuning.uniform.piecewise_linear(durations, rabi_detuning_values)
-)
+    backend.submit_task.side_effect = ValueError("some random error")
+    backend.dict.return_value = {"state_file": ".mock_state.txt"}

-ordered_state_2D_job = ordered_state_2D_prog.assign(delta_end=42.66, sweep_time=2.4)
+    with pytest.raises(RemoteBatch.SubmissionException):
+        (
+            Chain(5, 6.1)
+            .rydberg.detuning.uniform.linear(-10, 10, 3.0)
+            .quera.mock()
+            .run_async(100)
+        )

-pbin = ordered_state_2D_job.quera.aquila()
+    error_files = glob.glob("partial-batch-errors-*")
+    batch_files = glob.glob("partial-batch-future-*")

-pbin = pbin.parse_circuit()
+    for error_file, batch_file in zip(error_files, batch_files):
+        os.remove(error_file)
+        os.remove(batch_file)

-# pbin.circuit.sequence
+    assert len(error_files) == 1
+    assert len(batch_files) == 1


-# dict interface
-seq = Sequence(
-    {
-        rydberg: {
-            detuning: {
-                Uniform: Linear(start=1.0, stop="x", duration=3.0),
-                ScaledLocations({1: 1.0, 2: 2.0}): Linear(
-                    start=1.0, stop="x", duration=3.0
-                ),
-            },
-        }
-    }
-)
+@patch("bloqade.ir.routine.quera.MockBackend")
+def test_batch_warn(*args):
+    backend = bloqade.ir.routine.quera.MockBackend()

+    backend.submit_task.side_effect = ValueError("some random error")
+    backend.dict.return_value = {"state_file": ".mock_state.txt"}

-# job = HardwareBatchResult.load_json("example-3-2d-ordered-state-job.json")
+    with pytest.warns():
+        (
+            Chain(5, 6.1)
+            .rydberg.detuning.uniform.linear(-10, 10, 3.0)
+            .quera.mock()
+            .run_async(100, ignore_submission_error=True)
+        )

-# res = job.report()
+    error_files = glob.glob("partial-batch-errors-*")
+    batch_files = glob.glob("partial-batch-future-*")

+    for error_file, batch_file in zip(error_files, batch_files):
+        os.remove(error_file)
+        os.remove(batch_file)

-# print(lattice.Square(3).apply(seq).__lattice__)
-# print(lattice.Square(3).apply(seq).braket(nshots=1000).run_async().report().dataframe)
-# print("bitstring")
-# print(lattice.Square(3).apply(seq).braket(nshots=1000).run_async().report().bitstring)
-
-# # pipe interface
-# report = (
-#     lattice.Square(3)
-#     .rydberg.detuning.uniform.apply(Linear(start=1.0, stop="x", duration=3.0))
-#     .location(2)
-#     .scale(3.0)
-#     .apply(Linear(start=1.0, stop="x", duration=3.0))
-#     .rydberg.rabi.amplitude.uniform
-#     .apply(Linear(start=1.0, stop="x", duration=3.0))
-#     .assign(x=10)
-#     .braket(nshots=1000)
-#     .run_async()
-#     .report()
-# )
-
-# print(report)
-# print(report.bitstring)
-# print(report.dataframe)
-
-# lattice.Square(3).rydberg.detuning.location(2).location(3).apply(
-#     Linear(start=1.0, stop="x", duration=3.0)
-# ).location(3).location(4).apply(Linear(start=1.0, stop="x", duration=3.0)).braket(
-#     nshots=1000
-# ).run_async()
-
-# # start.rydberg.detuning.location(2).location(3)
-
-
-# prog = (
-#     lattice.Square(3)
-#     .rydberg.detuning.uniform.apply(Linear(start=1.0, stop="x", duration=3.0))
-#     .location(2)
-#     .scale(3.0)
-#     .apply(Linear(start=1.0, stop="x", duration=3.0))
-#     .hyperfine.rabi.amplitude.location(2)
-#     .apply(Linear(start=1.0, stop="x", duration=3.0))
-#     .assign(x=1.0)
-#     .multiplex(10.0).braket(nshots=1000)
-#     .run_async()
-#     .report()
-#     .dataframe.groupby(by=["x"])
-#     .count()
-# )
-
-# (
-#     lattice.Square(3)
-#     .rydberg.detuning.uniform.apply(Linear(start=1.0, stop="x", duration=3.0))
-#     .multiplex.quera
-# )
-
-
-# wf = (
-#     Linear(start=1.0, stop=2.0, duration=2.0)
-#     .scale(2.0)
-#     .append(Linear(start=1.0, stop=2.0, duration=2.0))
-# )
-
-
-# prog = (
-#     lattice.Square(3)
-#     .hyperfine.detuning.location(1)
-#     .scale(2.0)
-#     .piecewise_linear(coeffs=[1.0, 2.0, 3.0])
-#     .location(2)
-#     .constant(value=2.0, duration="x")
-# )
-
-# prog.seq
-# prog.lattice
+    assert len(error_files) == 1
+    assert len(batch_files) == 1

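Both new tests patch bloqade.ir.routine.quera.MockBackend so that submit_task always raises, then check that run_async either raises RemoteBatch.SubmissionException or, with ignore_submission_error=True, only warns, and that exactly one partial-batch-errors-* file and one partial-batch-future-* file are written and cleaned up. A sketch of how to run just these two tests locally (assuming pytest is installed and the call is made from the repository root):

    import pytest

    # Run only the new submission-error tests; -q keeps the output short.
    pytest.main(["tests/test_task.py", "-k", "test_batch_error or test_batch_warn", "-q"])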