-
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Made client for calling db for real data
- Loading branch information
1 parent
6efabec
commit 8a36808
Showing
13 changed files
with
381 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from . import dummydb | ||
from . import dummydb, indiadb | ||
|
||
__all__ = [ | ||
"indiadb", | ||
"dummydb", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Defines sources of data for the API.""" | ||
|
||
from .client import Client | ||
|
||
__all__ = ["Client"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
"""India DB client that conforms to the DatabaseInterface.""" | ||
import datetime as dt | ||
import logging | ||
|
||
from pvsite_datamodel import DatabaseConnection | ||
from pvsite_datamodel.read import get_sites_by_country, get_latest_forecast_values_by_site, get_pv_generation_by_sites | ||
from pvsite_datamodel.sqlmodels import SiteAssetType, ForecastValueSQL | ||
|
||
from india_api import internal | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class Client(internal.DatabaseInterface): | ||
"""Defines India DB client that conforms to the DatabaseInterface.""" | ||
|
||
def __init__(self, database_url: str) -> None: | ||
"""Initialize the client with a SQLAlchemy database connection and session.""" | ||
|
||
self.connection = DatabaseConnection(url=database_url, echo=False) | ||
|
||
def get_predicted_yields_for_location( | ||
self, | ||
location: str, | ||
asset_type: SiteAssetType, | ||
) -> list[internal.PredictedPower]: | ||
"""Gets the predicted yields for a location.""" | ||
|
||
# Get the window | ||
start, end = _getWindow() | ||
|
||
# get site uuid | ||
with self.connection.get_session() as session: | ||
sites = get_sites_by_country(session, country="india") | ||
|
||
# just select wind site | ||
sites = [s for s in sites if s.asset_type == asset_type] | ||
site = sites[0] | ||
|
||
# read actual generations | ||
values = get_latest_forecast_values_by_site( | ||
session, site_uuids=[site.site_uuid], start_utc=start | ||
) | ||
forecast_values: [ForecastValueSQL] = values[site.site_uuid] | ||
|
||
# convert ForecastValueSQL to PredictedPower | ||
values = [ | ||
internal.PredictedPower(PowerKW=value.forecast_power_kw, Time=value.start_utc.astimezone(dt.UTC)) | ||
for value in forecast_values | ||
] | ||
|
||
return values | ||
|
||
def get_generation_for_location( | ||
self, | ||
location: str, | ||
asset_type: SiteAssetType, | ||
) -> [internal.PredictedPower]: | ||
"""Gets the predicted yields for a location.""" | ||
|
||
# Get the window | ||
start, end = _getWindow() | ||
|
||
# get site uuid | ||
with self.connection.get_session() as session: | ||
sites = get_sites_by_country(session, country="india") | ||
|
||
# just select wind site | ||
sites = [site for site in sites if site.asset_type == asset_type] | ||
site = sites[0] | ||
|
||
# read actual generations | ||
values = get_pv_generation_by_sites( | ||
session=session, site_uuids=[site.site_uuid], start_utc=start, end_utc=end | ||
) | ||
|
||
# convert from GenerationSQL to PredictedPower | ||
values = [ | ||
internal.ActualPower(PowerKW=value.generation_power_kw, Time=value.start_utc.astimezone(dt.UTC)) | ||
for value in values | ||
] | ||
|
||
return values | ||
|
||
def get_predicted_solar_yields_for_location( | ||
self, | ||
location: str, | ||
) -> [internal.PredictedPower]: | ||
""" | ||
Gets the predicted solar yields for a location. | ||
Args: | ||
location: The location to get the predicted solar yields for. | ||
""" | ||
|
||
return self.get_predicted_yields_for_location( | ||
location=location, asset_type=SiteAssetType.pv | ||
) | ||
|
||
def get_predicted_wind_yields_for_location( | ||
self, | ||
location: str, | ||
) -> list[internal.PredictedPower]: | ||
""" | ||
Gets the predicted wind yields for a location. | ||
Args: | ||
location: The location to get the predicted wind yields for. | ||
""" | ||
|
||
return self.get_predicted_yields_for_location( | ||
location=location, asset_type=SiteAssetType.wind | ||
) | ||
|
||
def get_actual_solar_yields_for_location(self, location: str) -> list[internal.PredictedPower]: | ||
"""Gets the actual solar yields for a location.""" | ||
|
||
return self.get_generation_for_location(location=location, asset_type=SiteAssetType.pv) | ||
|
||
def get_actual_wind_yields_for_location(self, location: str) -> list[internal.PredictedPower]: | ||
"""Gets the actual wind yields for a location.""" | ||
|
||
log.error('test') | ||
return self.get_generation_for_location(location=location, asset_type=SiteAssetType.wind) | ||
|
||
def get_wind_regions(self) -> list[str]: | ||
"""Gets the valid wind regions.""" | ||
return ["ruvnl"] | ||
|
||
def get_solar_regions(self) -> list[str]: | ||
"""Gets the valid solar regions.""" | ||
return ["ruvnl"] | ||
|
||
|
||
def _getWindow() -> tuple[dt.datetime, dt.datetime]: | ||
"""Returns the start and end of the window for timeseries data.""" | ||
|
||
# Window start is the beginning of the day two days ago | ||
start = (dt.datetime.now(tz=dt.UTC) - dt.timedelta(days=2)).replace( | ||
hour=0, | ||
minute=0, | ||
second=0, | ||
microsecond=0, | ||
) | ||
# Window end is the beginning of the day two days ahead | ||
end = (dt.datetime.now(tz=dt.UTC) + dt.timedelta(days=2)).replace( | ||
hour=0, | ||
minute=0, | ||
second=0, | ||
microsecond=0, | ||
) | ||
return start, end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
""" Test fixtures to set up fake database for testing. """ | ||
import os | ||
from datetime import datetime, timedelta | ||
|
||
import pytest | ||
from pvsite_datamodel.sqlmodels import ( | ||
Base, | ||
ForecastSQL, | ||
ForecastValueSQL, | ||
GenerationSQL, SiteSQL | ||
) | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import Session | ||
from testcontainers.postgres import PostgresContainer | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def engine(): | ||
"""Database engine fixture.""" | ||
|
||
with PostgresContainer("postgres:14.5") as postgres: | ||
url = postgres.get_connection_url() | ||
os.environ["DB_URL"] = url | ||
engine = create_engine(url) | ||
Base.metadata.create_all(engine) | ||
|
||
yield engine | ||
|
||
|
||
@pytest.fixture() | ||
def db_session(engine): | ||
"""Return a sqlalchemy session, which tears down everything properly post-test.""" | ||
|
||
connection = engine.connect() | ||
# begin the nested transaction | ||
transaction = connection.begin() | ||
# use the connection with the already started transaction | ||
|
||
with Session(bind=connection) as session: | ||
yield session | ||
|
||
session.close() | ||
# roll back the broader transaction | ||
transaction.rollback() | ||
# put back the connection to the connection pool | ||
connection.close() | ||
session.flush() | ||
|
||
engine.dispose() | ||
|
||
|
||
@pytest.fixture(scope="session", autouse=True) | ||
def db_data(engine): | ||
"""Seed some initial data into DB.""" | ||
|
||
with engine.connect() as connection: | ||
with Session(bind=connection) as session: | ||
# PV site | ||
site = SiteSQL( | ||
client_site_id=1, | ||
latitude=20.59, | ||
longitude=78.96, | ||
capacity_kw=4, | ||
ml_id=1, | ||
asset_type="pv", | ||
country="india" | ||
) | ||
session.add(site) | ||
|
||
# Wind site | ||
site = SiteSQL( | ||
client_site_id=2, | ||
latitude=20.59, | ||
longitude=78.96, | ||
capacity_kw=4, | ||
ml_id=2, | ||
asset_type="wind", | ||
country="india" | ||
) | ||
session.add(site) | ||
|
||
session.commit() | ||
|
||
|
||
@pytest.fixture() | ||
def generations(db_session): | ||
"""Create some fake generations""" | ||
start_times = [datetime.today() - timedelta(minutes=x) for x in range(10)] | ||
|
||
all_generations = [] | ||
|
||
sites = db_session.query(SiteSQL).all() | ||
for site in sites: | ||
for i in range(0, 10): | ||
generation = GenerationSQL( | ||
site_uuid=site.site_uuid, | ||
generation_power_kw=i, | ||
start_utc=start_times[i], | ||
end_utc=start_times[i] + timedelta(minutes=5), | ||
) | ||
all_generations.append(generation) | ||
|
||
db_session.add_all(all_generations) | ||
db_session.commit() | ||
|
||
return all_generations | ||
|
||
|
||
@pytest.fixture() | ||
def forecast_values(db_session): | ||
"""Create some fake forecast values""" | ||
forecast_values = [] | ||
forecast_version: str = "0.0.0" | ||
|
||
num_forecasts = 10 | ||
num_values_per_forecast = 11 | ||
|
||
timestamps = [datetime.utcnow() - timedelta(minutes=10 * i) for i in range(num_forecasts)] | ||
|
||
# To make things trickier we make a second forecast at the same for one of the timestamps. | ||
timestamps = timestamps + timestamps[-1:] | ||
|
||
sites = db_session.query(SiteSQL).all() | ||
for site in sites: | ||
for timestamp in timestamps: | ||
forecast: ForecastSQL = ForecastSQL( | ||
site_uuid=site.site_uuid, forecast_version=forecast_version, timestamp_utc=timestamp | ||
) | ||
|
||
db_session.add(forecast) | ||
db_session.commit() | ||
|
||
for i in range(num_values_per_forecast): | ||
# Forecasts of 15 minutes. | ||
duration = 15 | ||
horizon = duration * i | ||
forecast_value: ForecastValueSQL = ForecastValueSQL( | ||
forecast_power_kw=i, | ||
forecast_uuid=forecast.forecast_uuid, | ||
start_utc=timestamp + timedelta(minutes=horizon), | ||
end_utc=timestamp + timedelta(minutes=horizon + duration), | ||
horizon_minutes=horizon, | ||
) | ||
|
||
forecast_values.append(forecast_value) | ||
|
||
db_session.add_all(forecast_values) | ||
db_session.commit() | ||
|
||
return forecast_values |
Oops, something went wrong.