-
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #55 from openclimatefix/issue/53-read-rds
Issue/53 read rds
- Loading branch information
Showing
10 changed files
with
205 additions
and
212 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,6 @@ uvicorn[standard] | |
pydantic | ||
numpy | ||
requests | ||
nowcasting_forecast | ||
sqlalchemy | ||
psycopg2-binary |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""" Functions to read from the database and format """ | ||
import os | ||
|
||
from nowcasting_forecast.database.connection import DatabaseConnection | ||
from nowcasting_forecast.database.models import Forecast, ManyForecasts | ||
from nowcasting_forecast.database.read import get_latest_forecast | ||
from sqlalchemy.orm.session import Session | ||
|
||
|
||
def get_forecasts_from_database(session: Session) -> ManyForecasts: | ||
"""Get forecasts from database for all GSPs""" | ||
# sql almacy objects | ||
forecasts = [ | ||
get_forecasts_for_a_specific_gsp_from_database(session=session, gsp_id=gsp_id) | ||
for gsp_id in range(0, 338) | ||
] | ||
|
||
# change to pydantic objects | ||
forecasts = [Forecast.from_orm(forecast) for forecast in forecasts] | ||
|
||
# return as many forecasts | ||
return ManyForecasts(forecasts=forecasts) | ||
|
||
|
||
def get_forecasts_for_a_specific_gsp_from_database(session: Session, gsp_id) -> Forecast: | ||
"""Get forecasts for on GSP from database""" | ||
# get forecast from database | ||
forecast = get_latest_forecast(session=session, gsp_id=gsp_id) | ||
|
||
return Forecast.from_orm(forecast) | ||
|
||
|
||
def get_session(): | ||
"""Get database settion""" | ||
connection = DatabaseConnection(url=os.getenv("DB_URL", "not_set")) | ||
|
||
with connection.get_session() as s: | ||
yield s | ||
|
||
|
||
# TODO load fprecast and make national |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" Create dummy forecasts for testing """ | ||
import logging | ||
from datetime import datetime, timedelta, timezone | ||
|
||
from nowcasting_forecast.database.models import ( | ||
Forecast, | ||
ForecastValue, | ||
InputDataLastUpdated, | ||
Location, | ||
) | ||
|
||
from utils import floor_30_minutes_dt | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
thirty_minutes = timedelta(minutes=30) | ||
|
||
|
||
def create_dummy_forecast_for_location(location: Location) -> Forecast: | ||
"""Create dummy forecast for one location""" | ||
logger.debug(f"Creating dummy forecast for location {location}") | ||
|
||
# get datetime right now | ||
now = datetime.now(timezone.utc) | ||
now_floor_30 = floor_30_minutes_dt(dt=now) | ||
|
||
# make list of datetimes that the forecast is for | ||
datetimes_utc = [now_floor_30 + i * thirty_minutes for i in range(4)] | ||
|
||
input_data_last_updated = InputDataLastUpdated( | ||
gsp=now_floor_30, | ||
nwp=now_floor_30, | ||
pv=now_floor_30, | ||
satellite=now_floor_30, | ||
) | ||
|
||
forecast_values = [ | ||
ForecastValue(expected_power_generation_megawatts=0, target_time=datetime_utc) | ||
for datetime_utc in datetimes_utc | ||
] | ||
|
||
forecast_creation_time = now_floor_30 - timedelta(minutes=30) | ||
return Forecast( | ||
location=location, | ||
forecast_creation_time=forecast_creation_time, | ||
forecast_values=forecast_values, | ||
input_data_last_updated=input_data_last_updated, | ||
) | ||
|
||
|
||
def create_dummy_national_forecast() -> Forecast: | ||
"""Create a dummy forecast for the national level""" | ||
|
||
logger.debug("Creating dummy forecast") | ||
|
||
location = Location( | ||
label="GB (National)", | ||
region_name="national_GB", | ||
gsp_name="dummy_gsp_name", | ||
gsp_group="dummy_gsp_group", | ||
) | ||
|
||
return create_dummy_forecast_for_location(location=location) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" Pytest fixitures for tests """ | ||
import os | ||
import tempfile | ||
|
||
import pytest | ||
from nowcasting_forecast.database.connection import DatabaseConnection | ||
from nowcasting_forecast.database.fake import make_fake_forecasts | ||
from nowcasting_forecast.database.models import Base | ||
|
||
|
||
@pytest.fixture | ||
def forecasts(db_session): | ||
"""Pytest fixture of 338 fake forecasts""" | ||
# create | ||
f = make_fake_forecasts(gsp_ids=list(range(0, 338))) | ||
db_session.add_all(f) | ||
|
||
return f | ||
|
||
|
||
@pytest.fixture | ||
def db_connection(): | ||
"""Pytest fixture for a database connection""" | ||
with tempfile.NamedTemporaryFile(suffix="db") as tmp: | ||
# set url option to not check same thread, this solves an error seen in testing | ||
url = f"sqlite:///{tmp.name}.db?check_same_thread=False" | ||
os.environ["DB_URL"] = url | ||
connection = DatabaseConnection(url=url) | ||
Base.metadata.create_all(connection.engine) | ||
|
||
yield connection | ||
|
||
|
||
@pytest.fixture(scope="function", autouse=True) | ||
def db_session(db_connection): | ||
"""Creates a new database session for a test.""" | ||
|
||
with db_connection.get_session() as s: | ||
s.begin() | ||
yield s | ||
|
||
s.rollback() |
Oops, something went wrong.