Add rollback logic if update fails and more tests
zacdezgeo committed Nov 6, 2024
1 parent 958262e commit 36c1843
Showing 2 changed files with 191 additions and 53 deletions.
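
Illustrative caller, a sketch only and not part of this commit: the import path and placeholder arguments are assumptions, but the signature matches load_parquet_to_db as exercised in the tests below.

from space2stats_ingest.main import load_parquet_to_db

try:
    load_parquet_to_db(
        "update.parquet",  # Parquet file carrying the new columns
        "postgresql://user:pass@localhost:5432/space2stats",  # placeholder DSN
        "update_item.json",  # STAC item describing the table columns
    )
except Exception:
    # After this commit, a failed update is rolled back before the exception
    # propagates, so the main table is never left half-updated.
    raise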
52 changes: 32 additions & 20 deletions space2stats_api/src/space2stats_ingest/main.py
@@ -129,6 +129,7 @@ def load_parquet_to_db(
             conn.commit()
             return
 
+    # Load Parquet file into a temporary table
     parquet_table = read_parquet_file(parquet_file)
     temp_table = f"{TABLE_NAME}_temp"
     with pg.connect(connection_string) as conn, tqdm(
@@ -143,7 +144,7 @@ def load_parquet_to_db(
 
         conn.commit()
 
-    # Fetch columns to add to dataset
+    # Fetch columns to add to the main table
    with pg.connect(connection_string) as conn:
        with conn.cursor() as cur:
            cur.execute(f"""
@@ -156,33 +157,44 @@ def load_parquet_to_db(
             """)
             new_columns = cur.fetchall()
 
-            for column, column_type in new_columns:
-                cur.execute(
-                    f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}"
-                )
-            conn.commit()
-
-    print(f"Adding new columns: {[c[0] for c in new_columns]}...")
-
-    # Update TABLE_NAME with data from temp_table based on matching hex_id
-    print("Adding columns to dataset... All or nothing operation may take some time.")
-    with pg.connect(connection_string) as conn:
-        with conn.cursor() as cur:
-            update_columns = [f"{column} = temp.{column}" for column, _ in new_columns]
-
-            set_clause = ", ".join(update_columns)
-
-            cur.execute(f"""
+    # Add new columns and attempt to update in a transaction
+    try:
+        with pg.connect(connection_string) as conn:
+            with conn.cursor() as cur:
+                # Add new columns to the main table
+                for column, column_type in new_columns:
+                    cur.execute(
+                        f"ALTER TABLE {TABLE_NAME} ADD COLUMN IF NOT EXISTS {column} {column_type}"
+                    )
+
+                print(f"Adding new columns: {[c[0] for c in new_columns]}...")
+
+                # Construct the SET clause for the update query
+                update_columns = [
+                    f"{column} = temp.{column}" for column, _ in new_columns
+                ]
+                set_clause = ", ".join(update_columns)
+
+                # Update TABLE_NAME with data from temp_table based on matching hex_id
+                print(
+                    "Adding columns to dataset... All or nothing operation may take some time."
+                )
+                cur.execute(f"""
                 UPDATE {TABLE_NAME} AS main
                 SET {set_clause}
                 FROM {temp_table} AS temp
                 WHERE main.hex_id = temp.hex_id
             """)
 
-            conn.commit()
+            conn.commit()  # Commit transaction if all operations succeed
+    except Exception as e:
+        # Rollback if any error occurs during the update
+        print("An error occurred during update. Rolling back changes.")
+        conn.rollback()
+        raise e  # Re-raise the exception to alert calling code
 
     # Drop the temporary table
     with pg.connect(connection_string) as conn:
         with conn.cursor() as cur:
-            cur.execute(f"DROP TABLE {temp_table}")
+            cur.execute(f"DROP TABLE IF EXISTS {temp_table}")
             conn.commit()
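
The net effect of the change above is that the ALTER TABLE statements and the backfilling UPDATE now run on a single connection, inside one transaction. A minimal sketch of that all-or-nothing pattern, assuming psycopg 3 and hypothetical names (demo, demo_temp, new_col):

import psycopg

def add_and_backfill(connection_string: str) -> None:
    try:
        with psycopg.connect(connection_string) as conn:
            with conn.cursor() as cur:
                # PostgreSQL DDL is transactional, so the ALTER and the UPDATE
                # either both take effect or both roll back.
                cur.execute("ALTER TABLE demo ADD COLUMN IF NOT EXISTS new_col bigint")
                cur.execute(
                    "UPDATE demo AS main"
                    " SET new_col = temp.new_col"
                    " FROM demo_temp AS temp"
                    " WHERE main.hex_id = temp.hex_id"
                )
            conn.commit()  # commit only once every statement has succeeded
    except Exception:
        print("Update failed; transaction rolled back.")
        raise  # surface the failure to the caller

Note that if the pg alias is psycopg 3, exiting the "with psycopg.connect(...)" block on an exception already rolls the transaction back (and closes the connection), so the explicit conn.rollback() in the except branch of the diff is redundant at best; the re-raise is what callers depend on.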
192 changes: 159 additions & 33 deletions space2stats_api/src/tests/test_ingest.py
@@ -10,9 +10,6 @@ def test_load_parquet_to_db(clean_database, tmpdir):
     connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}"
 
     parquet_file = tmpdir.join("local.parquet")
-
-    catalog_file = tmpdir.join("catalog.json")
-    collection_file = tmpdir.join("collection.json")
     item_file = tmpdir.join("space2stats_population_2020.json")
 
     data = {
@@ -45,36 +42,6 @@ def test_load_parquet_to_db(clean_database, tmpdir):
     with open(item_file, "w") as f:
         json.dump(stac_item, f)
 
-    stac_collection = {
-        "type": "Collection",
-        "stac_version": "1.0.0",
-        "id": "space2stats-collection",
-        "description": "Test collection for Space2Stats.",
-        "license": "CC-BY-4.0",
-        "extent": {
-            "spatial": {"bbox": [[-180, -90, 180, 90]]},
-            "temporal": {"interval": [["2020-01-01T00:00:00Z", None]]},
-        },
-        "links": [{"rel": "item", "href": str(item_file), "type": "application/json"}],
-    }
-
-    with open(collection_file, "w") as f:
-        json.dump(stac_collection, f)
-
-    stac_catalog = {
-        "type": "Catalog",
-        "stac_version": "1.0.0",
-        "id": "space2stats-catalog",
-        "description": "Test catalog for Space2Stats.",
-        "license": "CC-BY-4.0",
-        "links": [
-            {"rel": "child", "href": str(collection_file), "type": "application/json"}
-        ],
-    }
-
-    with open(catalog_file, "w") as f:
-        json.dump(stac_catalog, f)
-
     load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
 
     with psycopg.connect(connection_string) as conn:
@@ -168,3 +135,162 @@ def test_updating_table(clean_database, tmpdir):
         cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_2'")
         result = cur.fetchone()
         assert result == ("hex_2", 200, 250, 20_000)
+
+
+def test_columns_already_exist_in_db(clean_database, tmpdir):
+    connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}"
+
+    parquet_file = tmpdir.join("local.parquet")
+    data = {
+        "hex_id": ["hex_1", "hex_2"],
+        "existing_column": [123, 456],  # Simulates an existing column in DB
+        "new_column": [789, 1011],
+    }
+    table = pa.table(data)
+    pq.write_table(table, parquet_file)
+
+    stac_item = {
+        "type": "Feature",
+        "stac_version": "1.0.0",
+        "id": "space2stats_population_2020",
+        "properties": {
+            "table:columns": [
+                {"name": "hex_id", "type": "string"},
+                {"name": "existing_column", "type": "int64"},
+                {"name": "new_column", "type": "int64"},
+            ],
+            "datetime": "2024-10-07T11:21:25.944150Z",
+        },
+        "geometry": None,
+        "bbox": [-180, -90, 180, 90],
+        "links": [],
+        "assets": {},
+    }
+
+    item_file = tmpdir.join("space2stats_population_2020.json")
+    with open(item_file, "w") as f:
+        json.dump(stac_item, f)
+
+    load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
+
+    with psycopg.connect(connection_string) as conn:
+        with conn.cursor() as cur:
+            cur.execute("SELECT * FROM space2stats WHERE hex_id = 'hex_1'")
+            result = cur.fetchone()
+            assert result == ("hex_1", 123, 789)  # Verify no duplicates
+
+
+def test_rollback_on_update_failure(clean_database, tmpdir):
+    connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}"
+
+    parquet_file = tmpdir.join("local.parquet")
+    data = {
+        "hex_id": ["hex_1", "hex_2"],
+        "sum_pop_f_10_2020": [100, 200],
+        "sum_pop_m_10_2020": [150, 250],
+    }
+    table = pa.table(data)
+    pq.write_table(table, parquet_file)
+
+    stac_item = {
+        "type": "Feature",
+        "stac_version": "1.0.0",
+        "id": "space2stats_population_2020",
+        "properties": {
+            "table:columns": [
+                {"name": "hex_id", "type": "string"},
+                {"name": "sum_pop_f_10_2020", "type": "int64"},
+                {"name": "sum_pop_m_10_2020", "type": "int64"},
+            ],
+            "datetime": "2024-10-07T11:21:25.944150Z",
+        },
+        "geometry": None,
+        "bbox": [-180, -90, 180, 90],
+        "links": [],
+        "assets": {},
+    }
+
+    item_file = tmpdir.join("space2stats_population_2020.json")
+    with open(item_file, "w") as f:
+        json.dump(stac_item, f)
+
+    load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
+
+    # Invalid Parquet without `hex_id`
+    update_parquet_file = tmpdir.join("update_local.parquet")
+    update_data = {
+        "new_column": [1000, 2000],
+    }
+    update_table = pa.table(update_data)
+    pq.write_table(update_table, update_parquet_file)
+
+    update_item_file = tmpdir.join("update_item.json")
+    update_stac_item = {
+        "type": "Feature",
+        "stac_version": "1.0.0",
+        "id": "space2stats_population_2021",
+        "properties": {
+            "table:columns": [{"name": "new_column", "type": "int64"}],
+            "datetime": "2024-10-07T11:21:25.944150Z",
+        },
+        "geometry": None,
+        "bbox": [-180, -90, 180, 90],
+        "links": [],
+        "assets": {},
+    }
+
+    with open(update_item_file, "w") as f:
+        json.dump(update_stac_item, f)
+
+    try:
+        load_parquet_to_db(
+            str(update_parquet_file), connection_string, str(update_item_file)
+        )
+    except ValueError:
+        pass
+
+    with psycopg.connect(connection_string) as conn:
+        with conn.cursor() as cur:
+            cur.execute(
+                "SELECT column_name FROM information_schema.columns WHERE table_name = 'space2stats'"
+            )
+            columns = [row[0] for row in cur.fetchall()]
+            assert "new_column" not in columns  # Verify no unwanted columns were added
+
+
+def test_hex_id_column_mandatory(clean_database, tmpdir):
+    connection_string = f"postgresql://{clean_database.user}:{clean_database.password}@{clean_database.host}:{clean_database.port}/{clean_database.dbname}"
+
+    parquet_file = tmpdir.join("missing_hex_id.parquet")
+    data = {
+        "sum_pop_f_10_2020": [100, 200],
+        "sum_pop_m_10_2020": [150, 250],
+    }
+    table = pa.table(data)
+    pq.write_table(table, parquet_file)
+
+    stac_item = {
+        "type": "Feature",
+        "stac_version": "1.0.0",
+        "id": "space2stats_population_2020",
+        "properties": {
+            "table:columns": [
+                {"name": "sum_pop_f_10_2020", "type": "int64"},
+                {"name": "sum_pop_m_10_2020", "type": "int64"},
+            ],
+            "datetime": "2024-10-07T11:21:25.944150Z",
+        },
+        "geometry": None,
+        "bbox": [-180, -90, 180, 90],
+        "links": [],
+        "assets": {},
+    }
+
+    item_file = tmpdir.join("space2stats_population_2020.json")
+    with open(item_file, "w") as f:
+        json.dump(stac_item, f)
+
+    try:
+        load_parquet_to_db(str(parquet_file), connection_string, str(item_file))
+    except ValueError as e:
+        assert "The 'hex_id' column is missing from the Parquet file." in str(e)
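A caveat on the two failure-path tests above: a "try/except ValueError: pass" block also passes when no exception is raised at all, and in test_hex_id_column_mandatory the assertion only runs inside the except branch. A self-contained sketch of the stricter pytest.raises idiom, not part of this commit (the helper _ingest_without_hex_id is a stand-in for the real load_parquet_to_db call):

import pytest

def _ingest_without_hex_id():
    # Stand-in for load_parquet_to_db called with a Parquet file lacking hex_id
    raise ValueError("The 'hex_id' column is missing from the Parquet file.")

def test_hex_id_column_mandatory_strict():
    # pytest.raises fails the test if no ValueError is raised,
    # unlike a try/except-pass block.
    with pytest.raises(ValueError, match="hex_id"):
        _ingest_without_hex_id()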