Skip to content

Commit

Permalink
Merge branch 'main' into rydberg-h-factory
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 authored Sep 29, 2023
2 parents 4d133b6 + 541c2fe commit 60570c2
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 119 deletions.
89 changes: 76 additions & 13 deletions src/bloqade/task/batch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bloqade.serialize import Serializer, dumps
from bloqade.serialize import Serializer
from bloqade.task.base import Report
from bloqade.task.quera import QuEraTask
from bloqade.task.braket import BraketTask
Expand All @@ -15,19 +15,18 @@

# from bloqade.submission.base import ValidationError

from beartype.typing import Union, Optional, Dict, Any
from beartype.typing import Union, Optional, Dict, Any, List
from beartype import beartype
from collections import OrderedDict
from itertools import product
import json
import traceback
import datetime
import sys
import os
import warnings
import pandas as pd
import numpy as np
from dataclasses import dataclass
from dataclasses import dataclass, field


class Serializable:
Expand All @@ -39,6 +38,8 @@ def json(self, **options) -> str:
JSON string
"""
from bloqade import dumps

return dumps(self, **options)


Expand Down Expand Up @@ -184,6 +185,63 @@ def _deserializer(d: Dict[str, Any]) -> LocalBatch:
return LocalBatch(**d)


@dataclass
@Serializer.register
class TaskError(Serializable):
    """Record of a single task's failed submission.

    Created when ``task.submit`` raises, so the batch can report which
    tasks failed and why.
    """

    # Class name of the exception raised during submission.
    exception_type: str
    # Full ``traceback.format_exc()`` output captured at the time of failure.
    stack_trace: str


@dataclass
@Serializer.register
class BatchErrors(Serializable):
    """Collection of per-task submission errors, keyed by task index."""

    # Maps task index -> TaskError recorded when that task failed to submit.
    task_errors: OrderedDict[int, TaskError] = field(
        default_factory=lambda: OrderedDict([])
    )

    @beartype
    def print_errors(self, task_indices: Union[List[int], int]) -> str:
        """Return a human-readable report of the errors for ``task_indices``."""
        return str(self.get_errors(task_indices))

    @beartype
    def get_errors(self, task_indices: Union[List[int], int]) -> "BatchErrors":
        """Return a new ``BatchErrors`` restricted to ``task_indices``.

        Accepts either a single task index or a list of indices; indices
        with no recorded error are silently skipped.
        """
        # Normalize the scalar case: the signature accepts a bare int, but
        # iterating over an int would raise TypeError, so wrap it in a list.
        if isinstance(task_indices, int):
            task_indices = [task_indices]

        return BatchErrors(
            task_errors=OrderedDict(
                [
                    (task_index, self.task_errors[task_index])
                    for task_index in task_indices
                    if task_index in self.task_errors
                ]
            )
        )

    def __str__(self) -> str:
        """Format every recorded error as a multi-line report string."""
        output = ""
        for task_index, task_error in self.task_errors.items():
            output += (
                f"Task {task_index} failed to submit with error: "
                f"{task_error.exception_type}\n"
                f"{task_error.stack_trace}"
            )

        return output


@BatchErrors.set_serializer
def _serialize(self: BatchErrors) -> Dict[str, List]:
    """Serialize a ``BatchErrors`` into a JSON-friendly dict.

    The ordered mapping is flattened into a list of ``(index, error)``
    pairs so that insertion order survives the round trip.
    """
    return {"task_errors": list(self.task_errors.items())}


@BatchErrors.set_deserializer
def _deserialize(obj: dict) -> BatchErrors:
    """Rebuild a ``BatchErrors`` from its serialized dict form."""
    pairs = obj["task_errors"]
    return BatchErrors(task_errors=OrderedDict(pairs))


# this class gets a collection of tasks
# basically behaves as a pseudo queuing system
# the user only needs to store this object
Expand Down Expand Up @@ -321,6 +379,8 @@ def resubmit(self, shuffle_submit_order: bool = True) -> "RemoteBatch":
def _submit(
self, shuffle_submit_order: bool = True, ignore_submission_error=False, **kwargs
) -> "RemoteBatch":
from bloqade import save

# online, non-blocking
if shuffle_submit_order:
submission_order = np.random.permutation(list(self.tasks.keys()))
Expand All @@ -333,7 +393,7 @@ def _submit(

## upon submit() should validate for Both backends
## and throw errors when fail.
errors = OrderedDict()
errors = BatchErrors()
shuffled_tasks = OrderedDict()
for task_index in submission_order:
task = self.tasks[task_index]
Expand All @@ -342,17 +402,18 @@ def _submit(
task.submit(**kwargs)
except BaseException as error:
# record the error in the error dict
errors[int(task_index)] = {
"exception_type": error.__class__.__name__,
"stack trace": traceback.format_exc(),
}
errors.task_errors[int(task_index)] = TaskError(
exception_type=error.__class__.__name__,
stack_trace=traceback.format_exc(),
)

task.task_result_ir = QuEraTaskResults(
task_status=QuEraTaskStatusCode.Unaccepted
)

self.tasks = shuffled_tasks # permute order using dump way

if errors:
if len(errors.task_errors) > 0:
time_stamp = datetime.datetime.now()

if "win" in sys.platform:
Expand All @@ -369,8 +430,8 @@ def _submit(
# cloud_batch_result.save_json(future_file, indent=2)
# saving ?

with open(error_file, "w") as f:
json.dump(errors, f, indent=2)
save(errors, error_file)
save(self, future_file)

if ignore_submission_error:
warnings.warn(
Expand All @@ -382,7 +443,9 @@ def _submit(
)
else:
raise RemoteBatch.SubmissionException(
"One or more error(s) occured during submission, please see "
str(errors)
+ "\n"
+ "One or more error(s) occured during submission, please see "
"the following files for more information:\n"
f" - {os.path.join(cwd, future_file)}\n"
f" - {os.path.join(cwd, error_file)}\n"
Expand Down
151 changes: 45 additions & 106 deletions tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,121 +1,60 @@
from bloqade.ir import Sequence, rydberg, detuning, Uniform, Linear, ScaledLocations
from bloqade.ir.location import Square
from bloqade.atom_arrangement import Chain
from unittest.mock import patch
import bloqade.ir.routine.quera
from bloqade.task.batch import RemoteBatch
import glob
import os

# import bloqade.lattice as lattice
import pytest

n_atoms = 11
lattice_const = 5.9

rabi_amplitude_values = [0.0, 15.8, 15.8, 0.0]
rabi_detuning_values = [-16.33, -16.33, "delta_end", "delta_end"]
durations = [0.8, "sweep_time", 0.8]
@patch("bloqade.ir.routine.quera.MockBackend")
def test_batch_error(*args):
backend = bloqade.ir.routine.quera.MockBackend()

ordered_state_2D_prog = (
Square(n_atoms, lattice_const)
.rydberg.rabi.amplitude.uniform.piecewise_linear(durations, rabi_amplitude_values)
.detuning.uniform.piecewise_linear(durations, rabi_detuning_values)
)
backend.submit_task.side_effect = ValueError("some random error")
backend.dict.return_value = {"state_file": ".mock_state.txt"}

ordered_state_2D_job = ordered_state_2D_prog.assign(delta_end=42.66, sweep_time=2.4)
with pytest.raises(RemoteBatch.SubmissionException):
(
Chain(5, 6.1)
.rydberg.detuning.uniform.linear(-10, 10, 3.0)
.quera.mock()
.run_async(100)
)

pbin = ordered_state_2D_job.quera.aquila()
error_files = glob.glob("partial-batch-errors-*")
batch_files = glob.glob("partial-batch-future-*")

pbin = pbin.parse_circuit()
for error_file, batch_file in zip(error_files, batch_files):
os.remove(error_file)
os.remove(batch_file)

# pbin.circuit.sequence
assert len(error_files) == 1
assert len(batch_files) == 1


# dict interface
seq = Sequence(
{
rydberg: {
detuning: {
Uniform: Linear(start=1.0, stop="x", duration=3.0),
ScaledLocations({1: 1.0, 2: 2.0}): Linear(
start=1.0, stop="x", duration=3.0
),
},
}
}
)
@patch("bloqade.ir.routine.quera.MockBackend")
def test_batch_warn(*args):
backend = bloqade.ir.routine.quera.MockBackend()

backend.submit_task.side_effect = ValueError("some random error")
backend.dict.return_value = {"state_file": ".mock_state.txt"}

# job = HardwareBatchResult.load_json("example-3-2d-ordered-state-job.json")
with pytest.warns():
(
Chain(5, 6.1)
.rydberg.detuning.uniform.linear(-10, 10, 3.0)
.quera.mock()
.run_async(100, ignore_submission_error=True)
)

# res = job.report()
error_files = glob.glob("partial-batch-errors-*")
batch_files = glob.glob("partial-batch-future-*")

for error_file, batch_file in zip(error_files, batch_files):
os.remove(error_file)
os.remove(batch_file)

# print(lattice.Square(3).apply(seq).__lattice__)
# print(lattice.Square(3).apply(seq).braket(nshots=1000).run_async().report().dataframe)
# print("bitstring")
# print(lattice.Square(3).apply(seq).braket(nshots=1000).run_async().report().bitstring)

# # pipe interface
# report = (
# lattice.Square(3)
# .rydberg.detuning.uniform.apply(Linear(start=1.0, stop="x", duration=3.0))
# .location(2)
# .scale(3.0)
# .apply(Linear(start=1.0, stop="x", duration=3.0))
# .rydberg.rabi.amplitude.uniform
# .apply(Linear(start=1.0, stop="x", duration=3.0))
# .assign(x=10)
# .braket(nshots=1000)
# .run_async()
# .report()
# )

# print(report)
# print(report.bitstring)
# print(report.dataframe)

# lattice.Square(3).rydberg.detuning.location(2).location(3).apply(
# Linear(start=1.0, stop="x", duration=3.0)
# ).location(3).location(4).apply(Linear(start=1.0, stop="x", duration=3.0)).braket(
# nshots=1000
# ).run_async()

# # start.rydberg.detuning.location(2).location(3)


# prog = (
# lattice.Square(3)
# .rydberg.detuning.uniform.apply(Linear(start=1.0, stop="x", duration=3.0))
# .location(2)
# .scale(3.0)
# .apply(Linear(start=1.0, stop="x", duration=3.0))
# .hyperfine.rabi.amplitude.location(2)
# .apply(Linear(start=1.0, stop="x", duration=3.0))
# .assign(x=1.0)
# .multiplex(10.0).braket(nshots=1000)
# .run_async()
# .report()
# .dataframe.groupby(by=["x"])
# .count()
# )

# (
# lattice.Square(3)
# .rydberg.detuning.uniform.apply(Linear(start=1.0, stop="x", duration=3.0))
# .multiplex.quera
# )


# wf = (
# Linear(start=1.0, stop=2.0, duration=2.0)
# .scale(2.0)
# .append(Linear(start=1.0, stop=2.0, duration=2.0))
# )


# prog = (
# lattice.Square(3)
# .hyperfine.detuning.location(1)
# .scale(2.0)
# .piecewise_linear(coeffs=[1.0, 2.0, 3.0])
# .location(2)
# .constant(value=2.0, duration="x")
# )

# prog.seq
# prog.lattice
assert len(error_files) == 1
assert len(batch_files) == 1

0 comments on commit 60570c2

Please sign in to comment.