1
- from bloqade .serialize import Serializer , dumps
1
+ from bloqade .serialize import Serializer
2
2
from bloqade .task .base import Report
3
3
from bloqade .task .quera import QuEraTask
4
4
from bloqade .task .braket import BraketTask
15
15
16
16
# from bloqade.submission.base import ValidationError
17
17
18
- from beartype .typing import Union , Optional , Dict , Any
18
+ from beartype .typing import Union , Optional , Dict , Any , List
19
19
from beartype import beartype
20
20
from collections import OrderedDict
21
21
from itertools import product
22
- import json
23
22
import traceback
24
23
import datetime
25
24
import sys
26
25
import os
27
26
import warnings
28
27
import pandas as pd
29
28
import numpy as np
30
- from dataclasses import dataclass
29
+ from dataclasses import dataclass , field
31
30
32
31
33
32
class Serializable :
@@ -39,6 +38,8 @@ def json(self, **options) -> str:
39
38
JSON string
40
39
41
40
"""
41
+ from bloqade import dumps
42
+
42
43
return dumps (self , ** options )
43
44
44
45
@@ -184,6 +185,63 @@ def _deserializer(d: Dict[str, Any]) -> LocalBatch:
184
185
return LocalBatch (** d )
185
186
186
187
188
+ @dataclass
189
+ @Serializer .register
190
+ class TaskError (Serializable ):
191
+ exception_type : str
192
+ stack_trace : str
193
+
194
+
195
+ @dataclass
196
+ @Serializer .register
197
+ class BatchErrors (Serializable ):
198
+ task_errors : OrderedDict [int , TaskError ] = field (
199
+ default_factory = lambda : OrderedDict ([])
200
+ )
201
+
202
+ @beartype
203
+ def print_errors (self , task_indices : Union [List [int ], int ]) -> str :
204
+ return str (self .get_errors (task_indices ))
205
+
206
+ @beartype
207
+ def get_errors (self , task_indices : Union [List [int ], int ]):
208
+ return BatchErrors (
209
+ task_errors = OrderedDict (
210
+ [
211
+ (task_index , self .task_errors [task_index ])
212
+ for task_index in task_indices
213
+ if task_index in self .task_errors
214
+ ]
215
+ )
216
+ )
217
+
218
+ def __str__ (self ) -> str :
219
+ output = ""
220
+ for task_index , task_error in self .task_errors .items ():
221
+ output += (
222
+ f"Task { task_index } failed to submit with error: "
223
+ f"{ task_error .exception_type } \n "
224
+ f"{ task_error .stack_trace } "
225
+ )
226
+
227
+ return output
228
+
229
+
230
+ @BatchErrors .set_serializer
231
+ def _serialize (self : BatchErrors ) -> Dict [str , List ]:
232
+ return {
233
+ "task_errors" : [
234
+ (task_number , task_error )
235
+ for task_number , task_error in self .task_errors .items ()
236
+ ]
237
+ }
238
+
239
+
240
+ @BatchErrors .set_deserializer
241
+ def _deserialize (obj : dict ) -> BatchErrors :
242
+ return BatchErrors (task_errors = OrderedDict (obj ["task_errors" ]))
243
+
244
+
187
245
# this class get collection of tasks
188
246
# basically behaves as a psudo queuing system
189
247
# the user only need to store this objecet
@@ -321,6 +379,8 @@ def resubmit(self, shuffle_submit_order: bool = True) -> "RemoteBatch":
321
379
def _submit (
322
380
self , shuffle_submit_order : bool = True , ignore_submission_error = False , ** kwargs
323
381
) -> "RemoteBatch" :
382
+ from bloqade import save
383
+
324
384
# online, non-blocking
325
385
if shuffle_submit_order :
326
386
submission_order = np .random .permutation (list (self .tasks .keys ()))
@@ -333,7 +393,7 @@ def _submit(
333
393
334
394
## upon submit() should validate for Both backends
335
395
## and throw errors when fail.
336
- errors = OrderedDict ()
396
+ errors = BatchErrors ()
337
397
shuffled_tasks = OrderedDict ()
338
398
for task_index in submission_order :
339
399
task = self .tasks [task_index ]
@@ -342,17 +402,18 @@ def _submit(
342
402
task .submit (** kwargs )
343
403
except BaseException as error :
344
404
# record the error in the error dict
345
- errors [int (task_index )] = {
346
- "exception_type" : error .__class__ .__name__ ,
347
- "stack trace" : traceback .format_exc (),
348
- }
405
+ errors .task_errors [int (task_index )] = TaskError (
406
+ exception_type = error .__class__ .__name__ ,
407
+ stack_trace = traceback .format_exc (),
408
+ )
409
+
349
410
task .task_result_ir = QuEraTaskResults (
350
411
task_status = QuEraTaskStatusCode .Unaccepted
351
412
)
352
413
353
414
self .tasks = shuffled_tasks # permute order using dump way
354
415
355
- if errors :
416
+ if len ( errors . task_errors ) > 0 :
356
417
time_stamp = datetime .datetime .now ()
357
418
358
419
if "win" in sys .platform :
@@ -369,8 +430,8 @@ def _submit(
369
430
# cloud_batch_result.save_json(future_file, indent=2)
370
431
# saving ?
371
432
372
- with open ( error_file , "w" ) as f :
373
- json . dump ( errors , f , indent = 2 )
433
+ save ( errors , error_file )
434
+ save ( self , future_file )
374
435
375
436
if ignore_submission_error :
376
437
warnings .warn (
@@ -382,7 +443,9 @@ def _submit(
382
443
)
383
444
else :
384
445
raise RemoteBatch .SubmissionException (
385
- "One or more error(s) occured during submission, please see "
446
+ str (errors )
447
+ + "\n "
448
+ + "One or more error(s) occured during submission, please see "
386
449
"the following files for more information:\n "
387
450
f" - { os .path .join (cwd , future_file )} \n "
388
451
f" - { os .path .join (cwd , error_file )} \n "
0 commit comments