From 11d8b601675a42644ba6e890f445bb9016b7b876 Mon Sep 17 00:00:00 2001
From: bendnorman
Date: Wed, 20 Nov 2024 16:44:37 -0900
Subject: [PATCH] Create a pyarrow schema for parquet files to avoid timestamp issue

---
 src/dbcp/cli.py                |  5 ++--
 src/dbcp/data_mart/__init__.py | 15 ++++++++----
 src/dbcp/etl.py                | 11 +++++++--
 src/dbcp/helpers.py            | 44 +++++++++++++++++++++++++++++-----
 4 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/src/dbcp/cli.py b/src/dbcp/cli.py
index 0085f455..2cfa38e0 100644
--- a/src/dbcp/cli.py
+++ b/src/dbcp/cli.py
@@ -1,4 +1,5 @@
 """A Command line interface for the down ballot project."""
+
 import argparse
 import logging
 import sys
@@ -66,9 +67,9 @@ def main():
         SPATIAL_CACHE.clear()
 
     if args.etl:
-        dbcp.etl.etl(args)
+        dbcp.etl.etl()
     if args.data_mart:
-        dbcp.data_mart.create_data_marts(args)
+        dbcp.data_mart.create_data_marts()
 
 
 if __name__ == "__main__":
diff --git a/src/dbcp/data_mart/__init__.py b/src/dbcp/data_mart/__init__.py
index 5a4b708f..eff435d6 100644
--- a/src/dbcp/data_mart/__init__.py
+++ b/src/dbcp/data_mart/__init__.py
@@ -5,6 +5,8 @@
 import pkgutil
 
 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 
 import dbcp
 from dbcp.constants import OUTPUT_DIR
@@ -15,7 +17,7 @@
 logger = logging.getLogger(__name__)
 
 
-def create_data_marts(args):  # noqa: max-complexity=11
+def create_data_marts():  # noqa: max-complexity=11
     """Collect and load all data mart tables to data warehouse."""
     engine = dbcp.helpers.get_sql_engine()
     data_marts = {}
@@ -64,8 +66,8 @@ def create_data_marts(args):  # noqa: max-complexity=11
     with engine.connect() as con:
         for table in metadata.sorted_tables:
             logger.info(f"Load {table.name} to postgres.")
-            df = enforce_dtypes(data_marts[table.name], table.name, "data_mart")
-            df = dbcp.helpers.trim_columns_length(df)
+            df = dbcp.helpers.trim_columns_length(data_marts[table.name])
+            df = enforce_dtypes(df, table.name, "data_mart")
             df.to_sql(
                 name=table.name,
                 con=con,
@@ -74,7 +76,10 @@ def create_data_marts(args):  # noqa: max-complexity=11
                 schema="data_mart",
                 method=psql_insert_copy,
             )
-
-            df.to_parquet(parquet_dir / f"{table.name}.parquet", index=False)
+            schema = dbcp.helpers.get_pyarrow_schema_from_metadata(
+                table.name, "data_mart"
+            )
+            pa_table = pa.Table.from_pandas(df, schema=schema)
+            pq.write_table(pa_table, parquet_dir / f"{table.name}.parquet")
 
     validate_data_mart(engine=engine)
diff --git a/src/dbcp/etl.py b/src/dbcp/etl.py
index da138a9f..2ba21634 100644
--- a/src/dbcp/etl.py
+++ b/src/dbcp/etl.py
@@ -5,6 +5,8 @@
 from typing import Callable, Dict
 
 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 import sqlalchemy as sa
 
 import dbcp
@@ -225,12 +227,17 @@ def run_etl(funcs: dict[str, Callable], schema_name: str):
                 chunksize=1000,
                 method=psql_insert_copy,
             )
-            df.to_parquet(parquet_dir / f"{table.name}.parquet", index=False)
+
+            schema = dbcp.helpers.get_pyarrow_schema_from_metadata(
+                table.name, schema_name
+            )
+            pa_table = pa.Table.from_pandas(df, schema=schema)
+            pq.write_table(pa_table, parquet_dir / f"{table.name}.parquet")
 
     logger.info("Sucessfully finished ETL.")
 
 
-def etl(args):
+def etl():
     """Run dbc ETL."""
     # Reduce size of caches if necessary
     GEOCODER_CACHE.reduce_size()
diff --git a/src/dbcp/helpers.py b/src/dbcp/helpers.py
index bf19053e..cd5b7bc5 100644
--- a/src/dbcp/helpers.py
+++ b/src/dbcp/helpers.py
@@ -11,6 +11,7 @@
 import google.auth
 import pandas as pd
 import pandas_gbq
+import pyarrow as pa
 import sqlalchemy as sa
 from google.cloud import bigquery
 from tqdm import tqdm
@@ -35,6 +36,14 @@
     "BOOLEAN": "boolean",
     "DATETIME": "datetime64[ns]",
 }
+SA_TO_PA_TYPES = {
+    "VARCHAR": pa.string(),
+    "INTEGER": pa.int64(),
+    "BIGINT": pa.int64(),
+    "FLOAT": pa.float64(),
+    "BOOLEAN": pa.bool_(),
+    "DATETIME": pa.timestamp("ms"),
+}
 SA_TO_BQ_MODES = {True: "NULLABLE", False: "REQUIRED"}
 
 
@@ -81,6 +90,25 @@ def get_bq_schema_from_metadata(
     return bq_schema
 
 
+def get_pyarrow_schema_from_metadata(table_name: str, schema: str) -> pa.Schema:
+    """
+    Create a PyArrow schema from SQL Alchemy metadata.
+
+    Args:
+        table_name: the name of the table.
+        schema: the name of the database schema.
+    Returns:
+        pyarrow_schema: a PyArrow schema description.
+    """
+    table_name = f"{schema}.{table_name}"
+    metadata = get_schema_sql_alchemy_metadata(schema)
+    table_sa = metadata.tables[table_name]
+    pyarrow_schema = []
+    for column in table_sa.columns:
+        pyarrow_schema.append((column.name, SA_TO_PA_TYPES[str(column.type)]))
+    return pa.schema(pyarrow_schema)
+
+
 def enforce_dtypes(df: pd.DataFrame, table_name: str, schema: str):
     """Apply dtypes to a dataframe using the sqlalchemy metadata."""
     table_name = f"{schema}.{table_name}"
@@ -90,12 +118,16 @@ def enforce_dtypes(df: pd.DataFrame, table_name: str, schema: str):
     except KeyError:
         raise KeyError(f"{table_name} does not exist in metadata.")
 
-    dtypes = {
-        col.name: SA_TO_PD_TYPES[str(col.type)]
-        for col in table.columns
-        if col.name in df.columns
-    }
-    return df.astype(dtypes)
+    for col in table.columns:
+        # Add the column if it doesn't exist
+        if col.name not in df.columns:
+            df[col.name] = None
+        df[col.name] = df[col.name].astype(SA_TO_PD_TYPES[str(col.type)])
+
+    # convert datetime[ns] columns to milliseconds
+    for col in df.select_dtypes(include=["datetime64[ns]"]).columns:
+        df[col] = df[col].dt.floor("ms")
+    return df
 
 
 def get_sql_engine() -> sa.engine.Engine:
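
A standalone sketch of the idea behind this patch, for context: map SQLAlchemy column types to PyArrow types and write Parquet through an explicit schema so DATETIME columns land as millisecond-precision timestamps rather than pandas' default nanoseconds, which appears to be the timestamp issue the subject refers to. The table, column, and file names below are hypothetical; only the SA_TO_PA_TYPES mapping and the floor-to-milliseconds step mirror the patch.

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import sqlalchemy as sa

# Same SQLAlchemy-type -> PyArrow-type mapping the patch adds to helpers.py.
SA_TO_PA_TYPES = {
    "VARCHAR": pa.string(),
    "INTEGER": pa.int64(),
    "BIGINT": pa.int64(),
    "FLOAT": pa.float64(),
    "BOOLEAN": pa.bool_(),
    "DATETIME": pa.timestamp("ms"),
}

# Hypothetical table definition standing in for the real warehouse metadata.
metadata = sa.MetaData()
projects = sa.Table(
    "projects",
    metadata,
    sa.Column("project_id", sa.Integer),
    sa.Column("name", sa.VARCHAR),
    sa.Column("date_entered_queue", sa.DateTime),
)

# Build the PyArrow schema from the SQLAlchemy column types, as
# get_pyarrow_schema_from_metadata does for tables in the metadata.
schema = pa.schema(
    [(col.name, SA_TO_PA_TYPES[str(col.type)]) for col in projects.columns]
)

df = pd.DataFrame(
    {
        "project_id": [1, 2],
        "name": ["Solar Farm A", "Wind Farm B"],
        "date_entered_queue": pd.to_datetime(["2020-01-01", "2021-06-15"]),
    }
)

# Floor nanosecond timestamps to milliseconds first, mirroring the new loop at
# the end of enforce_dtypes, so the cast to timestamp("ms") cannot raise a
# "would lose data" error on values with sub-millisecond components.
for col in df.select_dtypes(include=["datetime64[ns]"]).columns:
    df[col] = df[col].dt.floor("ms")

# Writing through the explicit schema stores DATETIME columns as timestamp[ms]
# in the Parquet file instead of pandas' in-memory nanosecond precision.
pa_table = pa.Table.from_pandas(df, schema=schema)
pq.write_table(pa_table, "projects.parquet")

Passing the schema to pa.Table.from_pandas also casts the integer and string columns, so the Parquet dtypes stay in lockstep with the SQLAlchemy metadata rather than with whatever pandas happened to infer.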