diff --git a/airflow/dags/bodhi_waves_to_sl_db.py b/airflow/dags/bodhi_waves_to_sl_db.py
new file mode 100644
index 0000000..a1ea121
--- /dev/null
+++ b/airflow/dags/bodhi_waves_to_sl_db.py
@@ -0,0 +1,148 @@
+import logging
+import os
+
+import pandas as pd
+import pendulum
+from airflow.decorators import dag, task
+from extensions.models.models import BodhiWaves, bodhi_engine, engine, get_session
+from extensions.schemas.schemas import BodhiWavesModel, SlOffshoreIdx
+from sqlalchemy import insert, text
+
+# db_uri = LOCAL_PG_URI
+# Have to declare it this way for now: the parser gives an error on
+# initial load, and using os.environ directly seems to fix it.
+db_uri = os.environ.get("SUPABASE_PG_URI")
+
+start_date = pendulum.datetime(2024, 6, 9)
+
+default_args = {
+    "owner": "airflow",
+    "depends_on_past": False,
+    "start_date": start_date,
+    "email": ["your-email@example.com"],
+    "email_on_failure": False,
+    "email_on_retry": False,
+    "retries": 1,
+    "retry_delay": pendulum.duration(minutes=5),
+}
+
+
+@dag(
+    dag_id="bodhi_waves_to_sl_db",
+    start_date=start_date,
+    schedule="30 10 * * *",
+    catchup=False,
+    is_paused_upon_creation=False,
+    default_args=default_args,
+)
+def taskflow():
+    def fetch_wave_data(lat_lon_str):
+        try:
+            with get_session(bodhi_engine) as db:
+                stmt = text(
+                    f"""select * from wave_forecast where time = CURRENT_DATE AND (latitude, longitude) in ({lat_lon_str})"""
+                )
+                results = db.execute(stmt).fetchall()
+                data = [BodhiWavesModel.model_validate(entry) for entry in results]
+                data_dict = [entry.model_dump() for entry in data]
+                for d in data_dict:
+                    d.pop("location", None)
+                return data_dict
+        except Exception as e:
+            logging.error(f"Error fetching wave data: {e}")
+            return []
+
+    def wave_data_to_db(data):
+        # Guard against empty batches: insert().values([]) raises.
+        if not data:
+            return
+        try:
+            with get_session(engine) as db:
+                stmt = insert(BodhiWaves).values(data)
+                db.execute(stmt)
+                db.commit()
+        except Exception as e:
+            logging.error(f"Error inserting wave data to database: {e}")
+
+    def batch(iterable, n=1):
+        length = len(iterable)
+        for idx in range(0, length, n):
+            yield iterable[idx : min(idx + n, length)]
+
+    def get_all_batches(lat_lon_list, bs=10):
+        try:
+            processed = 0
+            for batch_lat_lon_list in batch(lat_lon_list, bs):
+                lat_lon_str = ", ".join(map(str, batch_lat_lon_list))
+                data = fetch_wave_data(lat_lon_str)
+                wave_data_to_db(data)
+                processed += len(batch_lat_lon_list)
+                logging.info(f"Processed {processed} out of {len(lat_lon_list)}.")
+        except Exception as e:
+            logging.error(f"Error processing batches: {e}")
+
+    @task()
+    def handle_enable_extension():
+        with get_session(engine) as db:
+            stmt = text("""CREATE EXTENSION IF NOT EXISTS postgis""")
+            db.execute(stmt)
+            db.commit()
+
+    @task()
+    def handle_db_idxs():
+        try:
+            with get_session(bodhi_engine) as db:
+                stmt = text(
+                    """CREATE INDEX IF NOT EXISTS idx_wave_forecast_lat_lon ON wave_forecast (latitude, longitude)"""
+                )
+                db.execute(stmt)
+                db.commit()
+        except Exception as e:
+            logging.error(f"Error handling database indexes: {e}")
+
+    @task()
+    def get_spot_offshore_locations():
+        try:
+            with get_session(engine) as db:
+                stmt = text(
+                    """select distinct on ("associated_spotId") "associated_spotId", "associated_offshoreLocation_lat", "associated_offshoreLocation_lon" from sl_ratings"""
+                )
+                results = db.execute(stmt).fetchall()
+
+            data = [SlOffshoreIdx.model_validate(entry) for entry in results]
+            data_dicts = [entry.model_dump() for entry in data]
+
+            df = pd.DataFrame(data_dicts)
+            # Keep only lat/lon that fall on the 0.25-degree forecast grid
+            # (fractional parts .0, .25, .5, .75). Quarter-degree values are
+            # exactly representable as floats, so the modulo check is exact.
+            mask = (df["associated_offshoreLocation_lat"] % 0.25 == 0) & (
+                df["associated_offshoreLocation_lon"] % 0.25 == 0
+            )
+            df = df[mask]
+
+            lat_lon_list = list(
+                set(
+                    zip(
+                        df["associated_offshoreLocation_lat"].values,
+                        df["associated_offshoreLocation_lon"].values,
+                    )
+                )
+            )
+            return lat_lon_list
+        except Exception as e:
+            logging.error(f"Error getting spot offshore locations: {e}")
+            return []
+
+    @task()
+    def bodhi_waves_to_db(lat_lon_list):
+        # XCom serializes tuples to lists of lists, so deserialize back to List[Tuple]
+        lat_lon_list = [tuple(pair) for pair in lat_lon_list]
+        get_all_batches(lat_lon_list=lat_lon_list, bs=50)
+
+    handle_enable_extension() >> handle_db_idxs()
+    lat_lon_list = get_spot_offshore_locations()
+    bodhi_waves_to_db(lat_lon_list)
+
+
+dag_run = taskflow()
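For context, a quick standalone sketch (not part of the commit; the coordinates are invented) of how the `batch` helper and `get_all_batches` compose the row-value `IN` predicate that `fetch_wave_data` interpolates into its query:

```python
def batch(iterable, n=1):
    length = len(iterable)
    for idx in range(0, length, n):
        yield iterable[idx : min(idx + n, length)]

# Invented sample coordinates for illustration
lat_lon_list = [(36.25, -122.5), (36.5, -122.75), (40.75, -124.5)]

for chunk in batch(lat_lon_list, n=2):
    # str((36.25, -122.5)) -> "(36.25, -122.5)", so joining the tuples
    # yields a ready-made row-value list for the IN clause.
    lat_lon_str = ", ".join(map(str, chunk))
    print(f"WHERE (latitude, longitude) IN ({lat_lon_str})")

# WHERE (latitude, longitude) IN ((36.25, -122.5), (36.5, -122.75))
# WHERE (latitude, longitude) IN ((40.75, -124.5))
```

Because the fragment is built by string formatting rather than bound parameters, it is only safe while the tuples come from the trusted `sl_ratings` query in this DAG.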
diff --git a/airflow/extensions/models/models.py b/airflow/extensions/models/models.py
index 610ad53..f52f96b 100644
--- a/airflow/extensions/models/models.py
+++ b/airflow/extensions/models/models.py
@@ -8,20 +8,30 @@
     DateTime,
     Float,
     Integer,
+    Interval,
     String,
     Text,
     create_engine,
+    func,
 )
-from sqlalchemy.orm import declarative_base
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import declarative_base, sessionmaker
+from sqlalchemy.orm.session import Session
 
 Base = declarative_base()
 engine = create_engine(os.environ.get("SUPABASE_PG_URI"))
+bodhi_engine = create_engine(os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"))
 
 
 def create_tables():
     Base.metadata.create_all(bind=engine)
 
 
+def get_session(engine: Engine) -> Session:
+    SessionLocal = sessionmaker(bind=engine)
+    return SessionLocal()
+
+
 class SlSpots(Base):
     __tablename__ = "sl_spots"
     spot_id = Column(Text, primary_key=True)
@@ -72,3 +82,24 @@ class SlRatings(Base):
     data_wave_timestamp = Column(String)
     swells_idx = Column(Integer)
     timestamp_utc = Column(DateTime)
+
+
+class BodhiWaves(Base):
+    __tablename__ = "bodhi_waves"
+    id = Column(BigInteger, primary_key=True)
+    latitude = Column(Float)
+    longitude = Column(Float)
+    time = Column(DateTime(timezone=True))
+    step = Column(Interval)  # using an Interval to represent a timedelta
+    valid_time = Column(DateTime(timezone=True))
+    swh = Column(Float)  # Significant height of combined wind waves and swell
+    perpw = Column(Float)  # Primary wave mean period
+    dirpw = Column(Float)  # Primary wave direction
+    shww = Column(Float)  # Significant height of wind waves
+    mpww = Column(Float)  # Mean period of wind waves
+    wvdir = Column(Float)  # Direction of wind waves
+    ws = Column(Float)  # Wind speed
+    wdir = Column(Float)  # Wind direction
+    swell = Column(Float)  # Significant height of swell waves
+    swper = Column(Float)  # Mean period of swell waves
+    entry_updated = Column(DateTime(timezone=True), onupdate=func.now())
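A usage note on the new `get_session` helper: SQLAlchemy `Session` objects are context managers (1.4+), which is what lets the DAG write `with get_session(engine) as db:` and have the session closed on exit. A minimal sketch, assuming the models above are importable and using invented query values:

```python
from sqlalchemy import select

from extensions.models.models import BodhiWaves, engine, get_session

# Session is a context manager, so it is closed automatically on exit.
with get_session(engine) as db:
    stmt = (
        select(BodhiWaves)
        .where(BodhiWaves.latitude == 36.25, BodhiWaves.longitude == -122.5)
        .limit(5)
    )
    rows = db.execute(stmt).scalars().all()  # list of BodhiWaves ORM objects
```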
diff --git a/airflow/extensions/schemas/schemas.py b/airflow/extensions/schemas/schemas.py
index 8b150a1..e92e13e 100644
--- a/airflow/extensions/schemas/schemas.py
+++ b/airflow/extensions/schemas/schemas.py
@@ -1,5 +1,9 @@
 from dataclasses import dataclass
+from datetime import datetime, timedelta
 from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, ConfigDict
 
 
 @dataclass
@@ -19,3 +23,34 @@ class SlApiEndpoints(Enum):
     WIND = "wind"
     TIDES = "tides"
     WEATHER = "weather"
+
+
+class SlOffshoreIdx(BaseModel):
+    associated_spotId: str
+    associated_offshoreLocation_lat: float
+    associated_offshoreLocation_lon: float
+
+    model_config = ConfigDict(from_attributes=True)
+
+
+class BodhiWavesModel(BaseModel):
+    id: int
+    location: str
+    latitude: float
+    longitude: float
+    time: datetime
+    step: timedelta
+    valid_time: datetime
+    swh: Optional[float]
+    perpw: Optional[float]
+    dirpw: Optional[float]
+    shww: Optional[float]
+    mpww: Optional[float]
+    wvdir: Optional[float]
+    ws: Optional[float]
+    wdir: Optional[float]
+    swell: Optional[float]
+    swper: Optional[float]
+    entry_updated: datetime
+
+    model_config = ConfigDict(from_attributes=True)
diff --git a/airflow/requirements.txt b/airflow/requirements.txt
index 5108332..2f7dca7 100644
--- a/airflow/requirements.txt
+++ b/airflow/requirements.txt
@@ -8,4 +8,5 @@ debugpy
 noaa-coops==0.3.2
 tqdm==4.66.2
 memray==1.11.0
-geopandas
\ No newline at end of file
+geopandas
+pydantic==2.7
\ No newline at end of file
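Finally, a small sketch of the validation round trip the DAG performs with `BodhiWavesModel` (Pydantic v2; the sample row values below are invented): `from_attributes=True` lets `model_validate` read attributes off an ORM row, `model_dump` turns it back into a plain dict, and the DAG pops `location` because `bodhi_waves` has no such column.

```python
from datetime import datetime, timedelta, timezone

from extensions.schemas.schemas import BodhiWavesModel

# Invented sample row; every field of BodhiWavesModel must be present,
# since Optional[...] without a default is still required in Pydantic v2.
row = {
    "id": 1,
    "location": "POINT(-122.5 36.25)",  # hypothetical value, dropped below
    "latitude": 36.25,
    "longitude": -122.5,
    "time": datetime(2024, 6, 9, tzinfo=timezone.utc),
    "step": timedelta(hours=3),
    "valid_time": datetime(2024, 6, 9, 3, tzinfo=timezone.utc),
    "swh": 1.8, "perpw": 12.0, "dirpw": 290.0, "shww": 0.4, "mpww": 4.0,
    "wvdir": 310.0, "ws": 6.5, "wdir": 300.0, "swell": 1.6, "swper": 13.0,
    "entry_updated": datetime(2024, 6, 9, tzinfo=timezone.utc),
}

validated = BodhiWavesModel.model_validate(row)  # dicts validate too
payload = validated.model_dump()
payload.pop("location", None)  # bodhi_waves has no location column
```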