Skip to content

Commit

Permalink
Made client for calling db for real data
Browse files Browse the repository at this point in the history
  • Loading branch information
confusedmatrix committed Feb 1, 2024
1 parent 6efabec commit 8a36808
Show file tree
Hide file tree
Showing 13 changed files with 381 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ jobs:
# * Produce JUnit XML report
- name: Run unit tests
run: |
.venv/bin/python3 -m xmlrunner discover -s src/india_api -p "test_*.py" --output-file ut-report.xml
.venv/bin/python3 -m pytest src/india_api --cov -s src/india_api --cov-report=xml
# Create test summary to be visualised on the job summary screen on GitHub
# * Runs even if previous steps fail
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ and create a new environment with your favorite environment manager.
Install all the dependencies with

```
pip install -e .[all]
pip install -e ".[all]"
```

You can run the service with the command `india-api`.
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ authors = [
classifiers = ["Programming Language :: Python :: 3"]
dependencies = [
"fastapi >= 0.105.0",
"pvsite-datamodel >= 1.0.10",
"pytz >= 2023.3",
"structlog >= 23.2.0",
"uvicorn >= 0.24.0",
Expand All @@ -26,6 +27,9 @@ dependencies = [
[project.optional-dependencies]
test = [
"unittest-xml-reporting == 3.2.0",
"pytest >= 8.0.0",
"pytest-cov >= 4.1.0",
"testcontainers >= 3.7.1",
]
lint = [
"mypy >= 1.7.1",
Expand Down
7 changes: 6 additions & 1 deletion src/india_api/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@
cfg = Config()

match cfg.SOURCE:
case "dummydb":
case "indiadb":
if cfg.DB_URL == '' or cfg.DB_URL == None:
raise OSError(f"DB_URL env var is required using db source: {cfg.SOURCE}")

def get_db_client_override() -> internal.DatabaseInterface:
return internal.inputs.indiadb.Client(cfg.DB_URL)
case "dummydb":
def get_db_client_override() -> internal.DatabaseInterface:
return internal.inputs.dummydb.Client()
case _:
Expand Down
2 changes: 2 additions & 0 deletions src/india_api/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Packages internal to the service."""

from .models import (
ActualPower,
DatabaseInterface,
PredictedPower,
)
Expand All @@ -11,6 +12,7 @@
)

__all__ = [
"ActualPower",
"PredictedPower",
"DatabaseInterface",
"inputs",
Expand Down
3 changes: 2 additions & 1 deletion src/india_api/internal/config/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ def __init__(self) -> None:
class Config(EnvParser):
    """Config for the application."""

    # Which data source backs the API: "indiadb" (real database) or "dummydb" (fake data).
    SOURCE: str = "indiadb"
    # Database connection URL; required when SOURCE is "indiadb" (see cmd/main.py check).
    DB_URL: str = ""
    # Port the API server listens on.
    PORT: int = 8000
3 changes: 2 additions & 1 deletion src/india_api/internal/inputs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Input (data source) implementations available to the service."""

from . import dummydb, indiadb

__all__ = [
    "indiadb",
    "dummydb",
]
5 changes: 5 additions & 0 deletions src/india_api/internal/inputs/indiadb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Defines sources of data for the API."""

from .client import Client

__all__ = ["Client"]
152 changes: 152 additions & 0 deletions src/india_api/internal/inputs/indiadb/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""India DB client that conforms to the DatabaseInterface."""
import datetime as dt
import logging

from pvsite_datamodel import DatabaseConnection
from pvsite_datamodel.read import get_sites_by_country, get_latest_forecast_values_by_site, get_pv_generation_by_sites
from pvsite_datamodel.sqlmodels import SiteAssetType, ForecastValueSQL

from india_api import internal

log = logging.getLogger(__name__)


class Client(internal.DatabaseInterface):
    """Defines India DB client that conforms to the DatabaseInterface."""

    def __init__(self, database_url: str) -> None:
        """Initialize the client with a SQLAlchemy database connection.

        Args:
            database_url: SQLAlchemy-compatible URL of the India database.
        """
        self.connection = DatabaseConnection(url=database_url, echo=False)

    def _get_site(self, session, asset_type: SiteAssetType):
        """Return the first Indian site with the given asset type.

        Raises:
            ValueError: If no site of the requested asset type exists.
        """
        sites = get_sites_by_country(session, country="india")
        sites = [site for site in sites if site.asset_type == asset_type]
        if not sites:
            # Fail with a clear message instead of an opaque IndexError.
            raise ValueError(f"No site with asset type {asset_type} found in India")
        return sites[0]

    def get_predicted_yields_for_location(
        self,
        location: str,
        asset_type: SiteAssetType,
    ) -> list[internal.PredictedPower]:
        """Gets the predicted yields for a location.

        Args:
            location: The location to get the predicted yields for.
                NOTE(review): currently unused — the first India site of the
                given asset type is read; confirm this is intended.
            asset_type: Which kind of site (pv or wind) to read forecasts for.
        """
        # Only the window start is needed; latest forecasts run forward from it.
        start, _ = _getWindow()

        with self.connection.get_session() as session:
            site = self._get_site(session, asset_type=asset_type)

            # Read the latest forecast values for the site.
            values = get_latest_forecast_values_by_site(
                session, site_uuids=[site.site_uuid], start_utc=start
            )
            forecast_values: list[ForecastValueSQL] = values[site.site_uuid]

            # Convert ForecastValueSQL to PredictedPower.
            return [
                internal.PredictedPower(
                    PowerKW=value.forecast_power_kw,
                    Time=value.start_utc.astimezone(dt.UTC),
                )
                for value in forecast_values
            ]

    def get_generation_for_location(
        self,
        location: str,
        asset_type: SiteAssetType,
    ) -> list[internal.ActualPower]:
        """Gets the actual generation values for a location.

        Args:
            location: The location to get the generation for.
                NOTE(review): currently unused — the first India site of the
                given asset type is read; confirm this is intended.
            asset_type: Which kind of site (pv or wind) to read generation for.
        """
        # Get the window
        start, end = _getWindow()

        with self.connection.get_session() as session:
            site = self._get_site(session, asset_type=asset_type)

            # Read the actual generation values for the site.
            values = get_pv_generation_by_sites(
                session=session, site_uuids=[site.site_uuid], start_utc=start, end_utc=end
            )

            # Convert GenerationSQL to ActualPower.
            return [
                internal.ActualPower(
                    PowerKW=value.generation_power_kw,
                    Time=value.start_utc.astimezone(dt.UTC),
                )
                for value in values
            ]

    def get_predicted_solar_yields_for_location(
        self,
        location: str,
    ) -> list[internal.PredictedPower]:
        """
        Gets the predicted solar yields for a location.

        Args:
            location: The location to get the predicted solar yields for.
        """
        return self.get_predicted_yields_for_location(
            location=location, asset_type=SiteAssetType.pv
        )

    def get_predicted_wind_yields_for_location(
        self,
        location: str,
    ) -> list[internal.PredictedPower]:
        """
        Gets the predicted wind yields for a location.

        Args:
            location: The location to get the predicted wind yields for.
        """
        return self.get_predicted_yields_for_location(
            location=location, asset_type=SiteAssetType.wind
        )

    def get_actual_solar_yields_for_location(self, location: str) -> list[internal.ActualPower]:
        """Gets the actual solar yields for a location."""
        return self.get_generation_for_location(location=location, asset_type=SiteAssetType.pv)

    def get_actual_wind_yields_for_location(self, location: str) -> list[internal.ActualPower]:
        """Gets the actual wind yields for a location."""
        return self.get_generation_for_location(location=location, asset_type=SiteAssetType.wind)

    def get_wind_regions(self) -> list[str]:
        """Gets the valid wind regions."""
        return ["ruvnl"]

    def get_solar_regions(self) -> list[str]:
        """Gets the valid solar regions."""
        return ["ruvnl"]


def _getWindow() -> tuple[dt.datetime, dt.datetime]:
"""Returns the start and end of the window for timeseries data."""

# Window start is the beginning of the day two days ago
start = (dt.datetime.now(tz=dt.UTC) - dt.timedelta(days=2)).replace(
hour=0,
minute=0,
second=0,
microsecond=0,
)
# Window end is the beginning of the day two days ahead
end = (dt.datetime.now(tz=dt.UTC) + dt.timedelta(days=2)).replace(
hour=0,
minute=0,
second=0,
microsecond=0,
)
return start, end
150 changes: 150 additions & 0 deletions src/india_api/internal/inputs/indiadb/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
""" Test fixtures to set up fake database for testing. """
import os
from datetime import datetime, timedelta

import pytest
from pvsite_datamodel.sqlmodels import (
Base,
ForecastSQL,
ForecastValueSQL,
GenerationSQL, SiteSQL
)
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from testcontainers.postgres import PostgresContainer


@pytest.fixture(scope="session")
def engine():
"""Database engine fixture."""

with PostgresContainer("postgres:14.5") as postgres:
url = postgres.get_connection_url()
os.environ["DB_URL"] = url
engine = create_engine(url)
Base.metadata.create_all(engine)

yield engine


@pytest.fixture()
def db_session(engine):
    """Return a sqlalchemy session, which tears down everything properly post-test.

    Each test runs inside an outer transaction that is rolled back afterwards,
    so tests cannot leak data into each other.
    """

    connection = engine.connect()
    # Begin the outer transaction; everything the test does nests inside it.
    transaction = connection.begin()

    with Session(bind=connection) as session:
        yield session

        session.close()
    # Roll back the outer transaction, discarding anything the test committed.
    transaction.rollback()
    # Return the connection to the pool. Do NOT dispose the engine here:
    # it is session-scoped and later tests still need its connection pool.
    connection.close()


@pytest.fixture(scope="session", autouse=True)
def db_data(engine):
"""Seed some initial data into DB."""

with engine.connect() as connection:
with Session(bind=connection) as session:
# PV site
site = SiteSQL(
client_site_id=1,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=1,
asset_type="pv",
country="india"
)
session.add(site)

# Wind site
site = SiteSQL(
client_site_id=2,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=2,
asset_type="wind",
country="india"
)
session.add(site)

session.commit()


@pytest.fixture()
def generations(db_session):
    """Create some fake generations"""

    # Ten 5-minute generation slots per site, walking back one minute at a time.
    start_times = [datetime.today() - timedelta(minutes=offset) for offset in range(10)]

    rows = [
        GenerationSQL(
            site_uuid=site.site_uuid,
            generation_power_kw=idx,
            start_utc=start_times[idx],
            end_utc=start_times[idx] + timedelta(minutes=5),
        )
        for site in db_session.query(SiteSQL).all()
        for idx in range(10)
    ]

    db_session.add_all(rows)
    db_session.commit()

    return rows


@pytest.fixture()
def forecast_values(db_session):
    """Create some fake forecast values.

    Seeds `num_forecasts` forecasts per site, each with
    `num_values_per_forecast` 15-minute values, and returns the list of
    created ForecastValueSQL rows.
    """
    forecast_values = []
    forecast_version: str = "0.0.0"

    num_forecasts = 10
    num_values_per_forecast = 11

    # Forecast init times, stepping 10 minutes back from now (naive UTC).
    timestamps = [datetime.utcnow() - timedelta(minutes=10 * i) for i in range(num_forecasts)]

    # To make things trickier we make a second forecast at the same time for
    # one of the timestamps (duplicate of the oldest one).
    timestamps = timestamps + timestamps[-1:]

    sites = db_session.query(SiteSQL).all()
    for site in sites:
        for timestamp in timestamps:
            forecast: ForecastSQL = ForecastSQL(
                site_uuid=site.site_uuid, forecast_version=forecast_version, timestamp_utc=timestamp
            )

            db_session.add(forecast)
            # Commit per forecast so forecast.forecast_uuid is populated
            # before the child ForecastValueSQL rows reference it below.
            db_session.commit()

            for i in range(num_values_per_forecast):
                # Forecasts of 15 minutes.
                duration = 15
                horizon = duration * i
                forecast_value: ForecastValueSQL = ForecastValueSQL(
                    forecast_power_kw=i,
                    forecast_uuid=forecast.forecast_uuid,
                    start_utc=timestamp + timedelta(minutes=horizon),
                    end_utc=timestamp + timedelta(minutes=horizon + duration),
                    horizon_minutes=horizon,
                )

                forecast_values.append(forecast_value)

    db_session.add_all(forecast_values)
    db_session.commit()

    return forecast_values
Loading

0 comments on commit 8a36808

Please sign in to comment.