Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
andycasey committed Oct 31, 2024
1 parent d6e405b commit 61fdec5
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions src/astra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from inspect import isgeneratorfunction
from decorator import decorator
from peewee import chunked, IntegrityError, SqliteDatabase
from playhouse.sqlite_ext import SqliteExtDatabase
from peewee import IntegrityError
from sdsstools.configuration import get_config

from astra.utils import log, Timer
Expand All @@ -13,10 +12,10 @@
def task(
function,
*args,
batch_size=1000,
write_frequency=300,
write_to_database=True,
re_raise_exceptions=True,
batch_size: int = 1000,
write_frequency: int = 300,
write_to_database: bool = True,
re_raise_exceptions: bool = True,
**kwargs
):
"""
Expand All @@ -28,18 +27,20 @@ def task(
:param \*args:
The arguments to the task.
:param batch_size: [optional]
The number of rows to insert per batch (default: 1000).
:param write_frequency: [optional]
The number of seconds to wait before saving the results to the database (default: 300).
:param write_to_database: [optional]
If `True` (default), results will be written to the database. Otherwise, they will be ignored.
:param re_raise_exceptions: [optional]
If `True` (default), exceptions raised in the task will be raised. Otherwise, they will be logged and ignored.
:param \**kwargs:
Keyword arguments for the task and the task decorator. See below.
:Keyword Arguments:
* *frequency* (``int``) --
The number of seconds to wait before saving the results to the database (default: 300).
* *result_frequency* (``int``) --
The number of results to wait before saving the results to the database (default: 300).
* *batch_size* (``int``) --
The number of rows to insert per batch (default: 1000).
* *re_raise_exceptions* (``bool``) --
If `True` (default), exceptions raised in the task will be raised. Otherwise, they will be logged and ignored.
"""

if not isgeneratorfunction(function):
Expand Down Expand Up @@ -73,15 +74,15 @@ def task(
except:
log.exception(f"Exception raised in task {function.__name__}")
if re_raise_exceptions:
raise
raise

finally:
if write_to_database and (timer.check_point or n >= batch_size):
with timer.pause():
# Add estimated overheads to each result.
timer.add_overheads(results)
try:
yield from bulk_insert_or_replace_pipeline_results(results, re_raise_exceptions)
yield from bulk_insert_or_replace_pipeline_results(results)
except:
log.exception(f"Exception trying to insert results to database:")
if re_raise_exceptions:
Expand All @@ -99,22 +100,19 @@ def task(
timer.add_overheads(results)
try:
# Write any remaining results to the database.
yield from bulk_insert_or_replace_pipeline_results(results, re_raise_exceptions)
yield from bulk_insert_or_replace_pipeline_results(results)
except:
log.exception(f"Exception trying to insert results to database:")
if re_raise_exceptions:
raise


def bulk_insert_or_replace_pipeline_results(results, re_raise_exceptions):
def bulk_insert_or_replace_pipeline_results(results):
"""
Insert a batch of results to the database.
:param results:
A list of records to create (e.g., sub-classes of `astra.models.BaseModel`).
:param batch_size:
The batch size to use when creating results.
A list of records to create (e.g., sub-classes of `astra.models.BaseModel`).
"""
log.info(f"Bulk inserting {len(results)} into the database")

Expand Down

0 comments on commit 61fdec5

Please sign in to comment.