add waves to supa pipeline dag and update models
peterbull committed Jun 12, 2024
1 parent f394a20 commit c884f8f
Showing 4 changed files with 217 additions and 2 deletions.
148 changes: 148 additions & 0 deletions airflow/dags/bodhi_waves_to_sl_db.py
@@ -0,0 +1,148 @@
import logging
import os

import pandas as pd
import pendulum
from airflow.decorators import dag, task
from extensions.models.models import BodhiWaves, bodhi_engine, engine, get_session
from extensions.schemas.schemas import BodhiWavesModel, SlOffshoreIdx
from sqlalchemy import insert, text

# db_uri = LOCAL_PG_URI
# Declared via os.environ directly for now: the config parser raises an
# error on initial load, and reading the environment variable directly
# seems to avoid it.
db_uri = os.environ.get("SUPABASE_PG_URI")

start_date = pendulum.datetime(2024, 6, 9)

default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": start_date,
    "email": ["[email protected]"],
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 1,
    "retry_delay": pendulum.duration(minutes=5),
}


@dag(
    dag_id="bodhi_waves_to_sl_db",
    start_date=start_date,
    schedule="30 10 * * *",
    catchup=False,
    is_paused_upon_creation=False,
    default_args=default_args,
)
def taskflow():
    def fetch_wave_data(lat_lon_str):
        try:
            with get_session(bodhi_engine) as db:
                stmt = text(
                    f"""select * from wave_forecast where time = CURRENT_DATE AND (latitude, longitude) in ({lat_lon_str})"""
                )
                results = db.execute(stmt).fetchall()
                data = [BodhiWavesModel.model_validate(entry) for entry in results]
                data_dict = [entry.model_dump() for entry in data]
                # The target table has no PostGIS location column, so drop it
                for d in data_dict:
                    d.pop("location", None)
                return data_dict
        except Exception as e:
            logging.error(f"Error fetching wave data: {str(e)}")
            return []

    def wave_data_to_db(data):
        # fetch_wave_data returns [] on error; insert().values([]) would raise
        if not data:
            logging.warning("No wave data to insert; skipping batch.")
            return
        try:
            with get_session(engine) as db:
                stmt = insert(BodhiWaves).values(data)
                db.execute(stmt)
                db.commit()
        except Exception as e:
            logging.error(f"Error inserting wave data to database: {str(e)}")

    def batch(iterable, n=1):
        length = len(iterable)
        for idx in range(0, length, n):
            yield iterable[idx : min(idx + n, length)]

    def get_all_batches(lat_lon_list, bs=10):
        try:
            processed = 0
            for batch_lat_lon_list in batch(lat_lon_list, bs):
                lat_lon_str = ", ".join(map(str, batch_lat_lon_list))
                data = fetch_wave_data(lat_lon_str)
                wave_data_to_db(data)
                processed += len(batch_lat_lon_list)
                logging.info(f"Processed {processed} out of {len(lat_lon_list)}.")
        except Exception as e:
            logging.error(f"Error processing batches: {str(e)}")

    @task()
    def handle_enable_extension():
        with get_session(engine) as db:
            stmt = text("""CREATE EXTENSION IF NOT EXISTS postgis""")
            db.execute(stmt)
            db.commit()

    @task()
    def handle_db_idxs():
        try:
            with get_session(bodhi_engine) as db:
                stmt = text(
                    """CREATE INDEX IF NOT EXISTS idx_wave_forecast_lat_lon ON wave_forecast (latitude, longitude)"""
                )
                db.execute(stmt)
                db.commit()
        except Exception as e:
            logging.error(f"Error handling database indexes: {str(e)}")

    @task()
    def get_spot_offshore_locations():
        try:
            with get_session(engine) as db:
                stmt = text(
                    """select distinct on ("associated_spotId") "associated_spotId", "associated_offshoreLocation_lat", "associated_offshoreLocation_lon" from sl_ratings"""
                )
                results = db.execute(stmt).fetchall()

                data = [SlOffshoreIdx.model_validate(entry) for entry in results]
                data_dicts = [entry.model_dump() for entry in data]

                df = pd.DataFrame(data_dicts)
                # Keep only lat/lon pairs on the 0.25-degree forecast grid
                # (fractional part .0, .25, .5, or .75): quarter-degree values
                # are exact in floating point, so round(x, 2) == x survives
                # the modulo only for them.
                df["lat_mod"] = df["associated_offshoreLocation_lat"] % 4
                df["lon_mod"] = df["associated_offshoreLocation_lon"] % 4

                mask = df["lat_mod"].apply(lambda x: round(x, 2) == x) & df[
                    "lon_mod"
                ].apply(lambda x: round(x, 2) == x)
                df = df[mask]
                df = df.drop(columns=["lat_mod", "lon_mod"])

                # Deduplicate into a list of (lat, lon) tuples
                lat_lon_list = list(
                    set(
                        zip(
                            df["associated_offshoreLocation_lat"].values,
                            df["associated_offshoreLocation_lon"].values,
                        )
                    )
                )
                return lat_lon_list
        except Exception as e:
            logging.error(f"Error getting spot offshore locations: {str(e)}")
            return []

    @task()
    def bodhi_waves_to_db(lat_lon_list):
        # XCom serializes tuples to lists of lists, so deserialize back to List[Tuple]
        lat_lon_list = [tuple(pair) for pair in lat_lon_list]
        get_all_batches(lat_lon_list=lat_lon_list, bs=50)

    handle_enable_extension() >> handle_db_idxs()
    lat_lon_list = get_spot_offshore_locations()
    bodhi_waves_to_db(lat_lon_list)


dag_run = taskflow()
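
For reference, the quarter-degree mask in get_spot_offshore_locations is the least obvious step in this DAG, so here is a minimal standalone sketch of its behavior; the coordinates below are hypothetical and only pandas is assumed:

import pandas as pd

# Hypothetical offshore points; only those on the 0.25-degree grid should survive.
df = pd.DataFrame(
    {
        "associated_offshoreLocation_lat": [36.25, 36.30, 37.00],
        "associated_offshoreLocation_lon": [-75.50, -75.42, -76.00],
    }
)
df["lat_mod"] = df["associated_offshoreLocation_lat"] % 4
df["lon_mod"] = df["associated_offshoreLocation_lon"] % 4
# Quarter-degree values survive the modulo exactly, so round(x, 2) == x holds;
# anything off-grid picks up floating-point error and fails the check.
mask = df["lat_mod"].apply(lambda x: round(x, 2) == x) & df["lon_mod"].apply(
    lambda x: round(x, 2) == x
)
print(df[mask])  # keeps (36.25, -75.5) and (37.0, -76.0); drops (36.3, -75.42)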
33 changes: 32 additions & 1 deletion airflow/extensions/models/models.py
@@ -8,20 +8,30 @@
    DateTime,
    Float,
    Integer,
    Interval,
    String,
    Text,
    create_engine,
    func,
)
from sqlalchemy.engine import Engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm.session import Session

Base = declarative_base()
engine = create_engine(os.environ.get("SUPABASE_PG_URI"))
bodhi_engine = create_engine(os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"))


def create_tables():
    Base.metadata.create_all(bind=engine)


def get_session(engine: Engine) -> Session:
    SessionLocal = sessionmaker(bind=engine)
    return SessionLocal()


class SlSpots(Base):
    __tablename__ = "sl_spots"
    spot_id = Column(Text, primary_key=True)
@@ -72,3 +82,24 @@ class SlRatings(Base):
    data_wave_timestamp = Column(String)
    swells_idx = Column(Integer)
    timestamp_utc = Column(DateTime)


class BodhiWaves(Base):
    __tablename__ = "bodhi_waves"
    id = Column(BigInteger, primary_key=True)
    latitude = Column(Float)
    longitude = Column(Float)
    time = Column(DateTime(timezone=True))
    step = Column(Interval)  # an Interval represents a Python timedelta
    valid_time = Column(DateTime(timezone=True))
    swh = Column(Float)  # Significant height of combined wind waves and swell
    perpw = Column(Float)  # Primary wave mean period
    dirpw = Column(Float)  # Primary wave direction
    shww = Column(Float)  # Significant height of wind waves
    mpww = Column(Float)  # Mean period of wind waves
    wvdir = Column(Float)  # Direction of wind waves
    ws = Column(Float)  # Wind speed
    wdir = Column(Float)  # Wind direction
    swell = Column(Float)  # Significant height of swell waves
    swper = Column(Float)  # Mean period of swell waves
    entry_updated = Column(DateTime(timezone=True), onupdate=func.now())
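
As a usage note, get_session returns a plain Session rather than managing one for you; Session supports the context-manager protocol in SQLAlchemy 1.4+, which is how the DAG uses it. A minimal sketch of inserting a row, with made-up sample values:

from extensions.models.models import BodhiWaves, create_tables, engine, get_session

create_tables()  # no-op if bodhi_waves already exists

# The session closes itself on exit; the commit is still explicit.
with get_session(engine) as db:
    db.add(BodhiWaves(latitude=36.25, longitude=-75.5, swh=1.2))
    db.commit()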
35 changes: 35 additions & 0 deletions airflow/extensions/schemas/schemas.py
@@ -1,5 +1,9 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict


@dataclass
@@ -19,3 +23,34 @@ class SlApiEndpoints(Enum):
    WIND = "wind"
    TIDES = "tides"
    WEATHER = "weather"


class SlOffshoreIdx(BaseModel):
    associated_spotId: str
    associated_offshoreLocation_lat: float
    associated_offshoreLocation_lon: float

    model_config = ConfigDict(from_attributes=True)


class BodhiWavesModel(BaseModel):
    id: int
    location: str
    latitude: float
    longitude: float
    time: datetime
    step: timedelta
    valid_time: datetime
    swh: Optional[float]
    perpw: Optional[float]
    dirpw: Optional[float]
    shww: Optional[float]
    mpww: Optional[float]
    wvdir: Optional[float]
    ws: Optional[float]
    wdir: Optional[float]
    swell: Optional[float]
    swper: Optional[float]
    entry_updated: datetime

    model_config = ConfigDict(from_attributes=True)
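
Because both schemas set from_attributes=True, model_validate accepts attribute-bearing objects such as SQLAlchemy rows, which is what fetch_wave_data relies on. A small sketch with a made-up row (a plain dict validates just as well):

from datetime import datetime, timedelta, timezone

from extensions.schemas.schemas import BodhiWavesModel

row = {
    "id": 1,
    "location": "POINT(-75.5 36.25)",  # hypothetical PostGIS value
    "latitude": 36.25,
    "longitude": -75.5,
    "time": datetime(2024, 6, 9, tzinfo=timezone.utc),
    "step": timedelta(hours=3),
    "valid_time": datetime(2024, 6, 9, 3, tzinfo=timezone.utc),
    "swh": 1.2,
    "perpw": 8.0,
    "dirpw": 90.0,
    "shww": None,  # Optional fields still require a value, but None is allowed
    "mpww": None,
    "wvdir": None,
    "ws": 5.0,
    "wdir": 180.0,
    "swell": 1.0,
    "swper": 9.0,
    "entry_updated": datetime(2024, 6, 12, tzinfo=timezone.utc),
}
wave = BodhiWavesModel.model_validate(row)
print(wave.swh)  # 1.2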
3 changes: 2 additions & 1 deletion airflow/requirements.txt
@@ -8,4 +8,5 @@ debugpy
noaa-coops==0.3.2
tqdm==4.66.2
memray==1.11.0
geopandas
pydantic==2.7
