Skip to content

Commit

Permalink
Merge pull request #50 from Dewberry/ruff-docs
Browse files Browse the repository at this point in the history
consolidate/centralize utils; update docstrings.
  • Loading branch information
slawler authored Jun 28, 2024
2 parents 317b53d + 610b318 commit daf3f01
Show file tree
Hide file tree
Showing 30 changed files with 1,348 additions and 1,322 deletions.
91 changes: 4 additions & 87 deletions production/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os
import sqlite3

import psycopg2
from dotenv import load_dotenv
from psycopg2 import sql

from ripple.stacio.s3_utils import list_keys
from ripple.utils import get_sessioned_s3_client

load_dotenv()


class PGFim:
"""Class to interact with the FIM database."""

def __init__(self):
self.dbuser = os.getenv("DBUSER")
self.dbpass = os.getenv("DBPASS")
Expand All @@ -23,15 +21,8 @@ def __conn_string(self):
conn_string = f"dbname='{self.dbname}' user='{self.dbuser}' password='{self.dbpass}' host='{self.dbhost}' port='{self.dbport}'"
return conn_string

def create_table(self, table: str, fields: list[str]):
with psycopg2.connect(self.__conn_string()) as connection:
fields_string = ""
for field, data_type in fields:
fields_string += f"{field} {data_type}"
cursor = connection.cursor()
cursor.execute(sql.SQL(f"CREATE TABLE cases.{table} ({fields_string})"))

def read_cases(self, table: str, fields: list[str], mip_group: str, optional_condition=""):
"""Read cases from the cases schema."""
with psycopg2.connect(self.__conn_string()) as connection:
cursor = connection.cursor()
fields_str = ""
Expand All @@ -43,21 +34,10 @@ def read_cases(self, table: str, fields: list[str], mip_group: str, optional_con
cursor.execute(sql_query)
return cursor.fetchall()

def update_table(self, table: str, fields: tuple, values: tuple):
with psycopg2.connect(self.__conn_string()) as connection:
cursor = connection.cursor()
insert_query = sql.SQL(
f"""
INSERT INTO cases.{table}{fields}
VALUES ({tuple("%s" for i in range(len(fields)))})
"""
)
cursor.execute(insert_query, values)
connection.commit()

def update_case_status(
self, mip_group: str, mip_case: str, key: str, status: bool, exc: str, traceback: str, process: str
):
"""Update the status of a table in the cases schema."""
with psycopg2.connect(self.__conn_string()) as connection:
cursor = connection.cursor()
insert_query = sql.SQL(
Expand All @@ -73,66 +53,3 @@ def update_case_status(
)
cursor.execute(insert_query, (key, mip_group, mip_case, status, exc, traceback))
connection.commit()


def read_case_db(cases_db_path: str, table_name: str):
    """Return all (key, crs) rows from *table_name* in the given SQLite database."""
    query = f"SELECT key, crs FROM {table_name} "
    with sqlite3.connect(cases_db_path) as connection:
        result = connection.cursor().execute(query).fetchall()
    return result


def add_columns(cases_db_path: str, table_name: str, columns: list[str]):
    """Add TEXT columns to a table, replacing any that already exist.

    Each name in *columns* is dropped first if the table already has it, then
    re-added as an empty TEXT column, so existing data in those columns is
    discarded.

    Parameters
    ----------
    cases_db_path : str
        Path to the SQLite database file.
    table_name : str
        Name of the table to alter.
    columns : list[str]
        Column names to (re-)create as TEXT columns.
    """
    with sqlite3.connect(cases_db_path) as connection:
        cursor = connection.cursor()
        # Snapshot the existing column names BEFORE altering the table.
        # The original re-read cursor.description inside the loop, but after
        # the first ALTER statement description becomes None, crashing on the
        # second column. LIMIT 0 still populates description without scanning.
        cursor.execute(f"SELECT * FROM {table_name} LIMIT 0")
        existing = {c[0] for c in cursor.description}
        for column in columns:
            if column in existing:
                cursor.execute(f"ALTER TABLE {table_name} DROP {column}")
                connection.commit()
            cursor.execute(f"ALTER TABLE {table_name} ADD {column} TEXT")
            connection.commit()


def insert_data(cases_db_path: str, table_name: str, data):
    """Upsert per-key records into *table_name*.

    *data* maps each key to a dict providing "exc", "tb", "gpkg", and "crs"
    entries; each entry becomes one row keyed by the dict key.
    """
    statement = f"""INSERT OR REPLACE INTO {table_name} (exc, tb, gpkg, crs, key) VALUES (?, ?, ?, ?, ?)"""
    with sqlite3.connect(cases_db_path) as connection:
        cursor = connection.cursor()
        for key, record in data.items():
            row = (record["exc"], record["tb"], record["gpkg"], record["crs"], key)
            cursor.execute(statement, row)
            connection.commit()


def create_table(cases_db_path: str, table_name: str):
    """Create (or recreate) *table_name* with key/crs/gpkg/exc/tb TEXT columns.

    Any existing table with the same name is dropped first, so calling this is
    destructive and idempotent.
    """
    with sqlite3.connect(cases_db_path) as connection:
        cursor = connection.cursor()
        # Drop a pre-existing table of the same name before recreating it.
        found = cursor.execute(f"SELECT name FROM sqlite_master WHERE name='{table_name}'").fetchone()
        if found:
            cursor.execute(f"DROP TABLE {table_name}")
            connection.commit()

        cursor.execute(f"""Create Table {table_name} (key Text, crs Text, gpkg Text, exc Text, tb Text)""")
        connection.commit()


def create_tx_ble_db(s3_prefix: str, crs: int, db_path: str, bucket: str = "fim"):
    """Build a SQLite db mapping TX BLE .prj keys under *s3_prefix* to an EPSG code.

    Creates (replacing any existing file at *db_path*) a ``tx_ble_crs_A`` table
    and fills it with one row per ``.prj`` key found in the bucket, all tagged
    with ``EPSG:{crs}``.
    """
    # Always start from a fresh database file.
    if os.path.exists(db_path):
        os.remove(db_path)

    with sqlite3.connect(db_path) as connection:
        connection.cursor().execute("""create table tx_ble_crs_A (key Text, crs Text)""")

    s3_client = get_sessioned_s3_client()
    prj_keys = list_keys(s3_client, bucket, s3_prefix, ".prj")

    epsg_code = f"EPSG:{crs}"
    with sqlite3.connect(db_path) as connection:
        cursor = connection.cursor()
        for prj_key in prj_keys:
            cursor.execute(
                """Insert or replace into tx_ble_crs_A (key,crs) values (?, ?)""",
                (prj_key, epsg_code),
            )
            connection.commit()
8 changes: 4 additions & 4 deletions production/step_1_extract_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def process_one_geom(
crs: str,
bucket: str = None,
):
"""Process one geometry file and convert it to geopackage."""
# create path name for gpkg
if key.endswith(".prj"):
gpkg_path = key.replace("prj", "gpkg")
Expand All @@ -31,9 +32,7 @@ def process_one_geom(


def main(table_name: str, mip_group: str, bucket: str = None):
"""
Reads from database a list of ras files to convert to geopackage
"""
"""Read from database a list of ras files to convert to geopackage."""
db = PGFim()

data = db.read_cases(table_name, ["mip_case", "s3_key", "crs"], mip_group)
Expand All @@ -60,5 +59,6 @@ def main(table_name: str, mip_group: str, bucket: str = None):

table_name = "inferred_crs"
bucket = "fim"
mip_group = "tx_ble"

main(table_name, "tx_ble", bucket)
main(table_name, mip_group, bucket)
2 changes: 1 addition & 1 deletion production/step_2_create_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def main(mip_group: str, table_name: str, bucket: str, ripple_version: str):

"""Read from database a list of geopackages to create stac items."""
db = PGFim()
optional_condition = "AND gpkg_complete=true AND stac_complete IS NULL"
data = db.read_cases(table_name, ["case_id", "s3_key"], mip_group, optional_condition)
Expand Down
7 changes: 3 additions & 4 deletions production/step_3_conflate_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import traceback
from pathlib import Path
from time import sleep

import pystac
Expand All @@ -9,14 +8,14 @@
from ripple.conflate.rasfim import RasFimConflater
from ripple.ops.conflate_ras_model import conflate_s3_model, href_to_vsis
from ripple.ripple_logger import configure_logging
from ripple.stacio.utils.s3_utils import init_s3_resources, s3_key_public_url_converter
from ripple.utils.s3_utils import init_s3_resources


def main(table_name: str, mip_group: str, bucket: str, nwm_pq_path: str):

"""Read from database a list of geopackages to create stac items."""
db = PGFim()

session, client, s3_resource = init_s3_resources()
_, client, _ = init_s3_resources()
rfc = None
optional_condition = "AND stac_complete=true AND conflation_complete IS NULL"
data = db.read_cases(
Expand Down
26 changes: 18 additions & 8 deletions ripple/conflate/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def plot_conflation_results(
s3_client: Session.client = None,
limit_plot_to_nearby_reaches: bool = True,
):
"""
Create/write png to s3. The png contains RAS centerline and cross sections and nearby NWM branches
"""
"""Create/write png to s3. The png contains RAS centerline and cross sections and nearby NWM branches."""
_, ax = plt.subplots(figsize=(10, 10))

# Plot the centerline and cross-sections first
Expand All @@ -33,14 +31,26 @@ def plot_conflation_results(
ylim = ax.get_ylim()
bounds = shapely.geometry.box(xlim[0], ylim[0], xlim[1], ylim[1])

zoom_factor = 3 # Adjust this value to change the zoom level

# Calculate the range of x and y
x_range = bounds.bounds[2] - bounds.bounds[0]
y_range = bounds.bounds[3] - bounds.bounds[1]

min_x = bounds.bounds[0] - x_range * (zoom_factor - 1) / 2
max_x = bounds.bounds[2] + x_range * (zoom_factor - 1) / 2
min_y = bounds.bounds[1] - y_range * (zoom_factor - 1) / 2
max_y = bounds.bounds[3] + y_range * (zoom_factor - 1) / 2

adjusted_bounds = shapely.geometry.box(min_x, min_y, max_x, max_y)
# Add a patch for the ras_centerline
patches = [mpatches.Patch(color="black", label="RAS Centerline", linestyle="dashed")]

# Add a patch for nearby reaches
patches.append(mpatches.Patch(color="blue", label="Nearby NWM reaches", alpha=0.3))

# Plot the reaches that fall within the axis limits
rfc.nwm_reaches.plot(ax=ax, color="blue", linewidth=1, alpha=0.3)
rfc.nwm_reaches.clip(adjusted_bounds).plot(ax=ax, color="blue", linewidth=1, alpha=0.3)

if limit_plot_to_nearby_reaches:
# Create a colormap that maps each reach_id to a color
Expand All @@ -64,12 +74,12 @@ def plot_conflation_results(

# Set the axis limits to the bounds, expanded by the zoom factor
ax.set_xlim(
bounds.bounds[0] - x_range * (zoom_factor - 1) / 2,
bounds.bounds[2] + x_range * (zoom_factor - 1) / 2,
min_x,
max_x,
)
ax.set_ylim(
bounds.bounds[1] - y_range * (zoom_factor - 1) / 2,
bounds.bounds[3] + y_range * (zoom_factor - 1) / 2,
min_y,
max_y,
)

ax.legend(handles=patches, handleheight=0.005)
Expand Down
Loading

0 comments on commit daf3f01

Please sign in to comment.