From cac7b60d1cee65778459510a75e6e19c60f44788 Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Mon, 29 Jan 2024 21:35:59 -0800
Subject: [PATCH] Different strategy, avoiding threads

---
 datasette_upload_csvs/__init__.py | 100 ++++++++++--------------------
 1 file changed, 33 insertions(+), 67 deletions(-)

diff --git a/datasette_upload_csvs/__init__.py b/datasette_upload_csvs/__init__.py
index 2a5bb4e..34799ec 100644
--- a/datasette_upload_csvs/__init__.py
+++ b/datasette_upload_csvs/__init__.py
@@ -126,60 +126,40 @@ def insert_initial_record(conn):
 
     await db.execute_write_fn(insert_initial_record)
 
-    # We run the CSV parser in a thread, sending 100 rows at a time to the DB
-    def parse_csv_in_thread(event_loop, csv_file, db, table_name, task_id):
+    def make_insert_batch(batch):
+        def inner(conn):
+            db = sqlite_utils.Database(conn)
+            db[table_name].insert_all(batch, alter=True)
+
+        return inner
+
+    # We run a parser in a separate async task, writing and yielding every 100 rows
+    async def parse_csv():
+        i = 0
+        tracker = TypeTracker()
         try:
-            reader = csv_std.reader(codecs.iterdecode(csv_file, encoding))
+            reader = csv_std.reader(codecs.iterdecode(csv.file, encoding))
             headers = next(reader)
 
-            tracker = TypeTracker()
             docs = tracker.wrap(dict(zip(headers, row)) for row in reader)
 
-            i = 0
-
-            def docs_with_progress():
-                nonlocal i
-                for doc in docs:
-                    i += 1
-                    yield doc
-                    if i % 10 == 0:
-
-                        def update_progress(conn):
-                            database = sqlite_utils.Database(conn)
-                            database["_csv_progress_"].update(
-                                task_id,
-                                {
-                                    "rows_done": i,
-                                    "bytes_done": csv_file.tell(),
-                                },
-                            )
-
-                        asyncio.run_coroutine_threadsafe(
-                            db.execute_write_fn(update_progress), event_loop
-                        ).result()
-
-            def write_batch(batch):
-                def insert_batch(conn):
-                    database = sqlite_utils.Database(conn)
-                    database[table_name].insert_all(batch, alter=True)
-
-                asyncio.run_coroutine_threadsafe(
-                    db.execute_write_fn(insert_batch), event_loop
-                ).result()
-
             batch = []
-            batch_size = 0
-            for doc in docs_with_progress():
+            for doc in docs:
                 batch.append(doc)
-                batch_size += 1
-                if batch_size > 100:
-                    write_batch(batch)
+                i += 1
+                if i % 10 == 0:
+                    await db.execute_write(
+                        "update _csv_progress_ set rows_done = ?, bytes_done = ? where id = ?",
+                        (i, csv.file.tell(), task_id),
+                    )
+                if i % 100 == 0:
+                    await db.execute_write_fn(make_insert_batch(batch))
                     batch = []
-                    batch_size = 0
+                    # And yield to the event loop
+                    await asyncio.sleep(0)
 
             if batch:
-                write_batch(batch)
+                await db.execute_write_fn(make_insert_batch(batch))
 
             # Mark progress as complete
             def mark_complete(conn):
@@ -194,37 +174,23 @@ def mark_complete(conn):
                     },
                 )
 
-            asyncio.run_coroutine_threadsafe(
-                db.execute_write_fn(mark_complete), event_loop
-            ).result()
+            await db.execute_write_fn(mark_complete)
 
             # Transform columns to detected types
             def transform_columns(conn):
                 database = sqlite_utils.Database(conn)
                 database[table_name].transform(types=tracker.types)
 
-            asyncio.run_coroutine_threadsafe(
-                db.execute_write_fn(transform_columns), event_loop
-            ).result()
-        except Exception as error:
-
-            def insert_error(conn):
-                database = sqlite_utils.Database(conn)
-                database["_csv_progress_"].update(
-                    task_id,
-                    {"error": str(error)},
-                )
+            await db.execute_write_fn(transform_columns)
 
-            asyncio.run_coroutine_threadsafe(
-                db.execute_write_fn(insert_error), event_loop
-            ).result()
-
-    loop = asyncio.get_running_loop()
+        except Exception as error:
+            await db.execute_write(
+                "update _csv_progress_ set error = ? where id = ?",
+                (str(error), task_id),
+            )
 
-    # Start that thread running in the default executor in the background
-    loop.run_in_executor(
-        None, parse_csv_in_thread, loop, csv.file, db, table_name, task_id
-    )
+    # Run that as a task
+    asyncio.create_task(parse_csv())
 
     if formdata.get("xhr"):
         return Response.json(