Skip to content

Commit

Permalink
Merge pull request #55 from openclimatefix/issue/53-read-rds
Browse files Browse the repository at this point in the history
Issue/53 read rds
  • Loading branch information
peterdudfield authored Jan 19, 2022
2 parents 76e649e + f02d318 commit db21842
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 212 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Documentation can be viewed at `/docs`. This is automatically generated from the

# Setup and Run

This can be done it two differen ways: With Python or with Docker.
This can be done in two different ways: With Python or with Docker.

## Python

Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ uvicorn[standard]
pydantic
numpy
requests
nowcasting_forecast
sqlalchemy
psycopg2-binary
41 changes: 41 additions & 0 deletions src/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
""" Functions to read from the database and format """
import os

from nowcasting_forecast.database.connection import DatabaseConnection
from nowcasting_forecast.database.models import Forecast, ManyForecasts
from nowcasting_forecast.database.read import get_latest_forecast
from sqlalchemy.orm.session import Session


def get_forecasts_from_database(session: Session) -> ManyForecasts:
    """Get the latest forecasts for all GSPs from the database.

    :param session: SQLAlchemy database session
    :return: a ManyForecasts object containing one pydantic Forecast per GSP
    """
    # Each helper call already returns a pydantic Forecast (it applies
    # Forecast.from_orm itself), so no second from_orm conversion is needed
    # here — the original double conversion was redundant.
    # NOTE(review): gsp_id range 0..337 is hard-coded — confirm 338 matches
    # the number of GSPs in the database.
    forecasts = [
        get_forecasts_for_a_specific_gsp_from_database(session=session, gsp_id=gsp_id)
        for gsp_id in range(0, 338)
    ]

    # wrap the individual forecasts in the ManyForecasts container
    return ManyForecasts(forecasts=forecasts)


def get_forecasts_for_a_specific_gsp_from_database(session: Session, gsp_id) -> Forecast:
    """Get the latest forecast for one GSP from the database.

    :param session: SQLAlchemy database session
    :param gsp_id: id of the GSP to fetch the forecast for
    :return: the latest forecast, converted to a pydantic Forecast
    """
    latest_forecast = get_latest_forecast(session=session, gsp_id=gsp_id)

    # convert the SQLAlchemy object to its pydantic equivalent
    return Forecast.from_orm(latest_forecast)


def get_session():
    """Yield a database session (used as a FastAPI dependency).

    The connection URL is read from the DB_URL environment variable.
    """
    connection = DatabaseConnection(url=os.getenv("DB_URL", "not_set"))

    # generator form so FastAPI closes the session when the request ends
    with connection.get_session() as session:
        yield session


# TODO load forecast and make national
63 changes: 63 additions & 0 deletions src/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
""" Create dummy forecasts for testing """
import logging
from datetime import datetime, timedelta, timezone

from nowcasting_forecast.database.models import (
Forecast,
ForecastValue,
InputDataLastUpdated,
Location,
)

from utils import floor_30_minutes_dt

logger = logging.getLogger(__name__)

thirty_minutes = timedelta(minutes=30)


def create_dummy_forecast_for_location(location: Location) -> Forecast:
    """Create a dummy forecast for one location.

    The forecast holds four half-hourly values of 0 MW, starting at the
    current time floored to the previous half-hour boundary.
    """
    logger.debug(f"Creating dummy forecast for location {location}")

    # anchor all timestamps on "now", floored to the half hour
    start = floor_30_minutes_dt(dt=datetime.now(timezone.utc))

    # four consecutive half-hourly target times
    target_times = [start + step * thirty_minutes for step in range(4)]

    last_updated = InputDataLastUpdated(
        gsp=start,
        nwp=start,
        pv=start,
        satellite=start,
    )

    values = [
        ForecastValue(expected_power_generation_megawatts=0, target_time=target_time)
        for target_time in target_times
    ]

    # pretend the forecast was created half an hour before the first value
    return Forecast(
        location=location,
        forecast_creation_time=start - timedelta(minutes=30),
        forecast_values=values,
        input_data_last_updated=last_updated,
    )


def create_dummy_national_forecast() -> Forecast:
    """Create a dummy forecast for the national level"""
    logger.debug("Creating dummy forecast")

    # national-level location; GSP fields are filled with dummy values
    national_location = Location(
        label="GB (National)",
        region_name="national_GB",
        gsp_name="dummy_gsp_name",
        gsp_group="dummy_gsp_group",
    )
    return create_dummy_forecast_for_location(location=national_location)
187 changes: 18 additions & 169 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
""" Main FastAPI app """
import logging
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from uuid import UUID, uuid4
from datetime import timedelta

from fastapi import FastAPI
from pydantic import BaseModel, Field, validator
from fastapi import Depends, FastAPI
from nowcasting_forecast.database.models import Forecast, ManyForecasts
from sqlalchemy.orm.session import Session

from utils import convert_to_camelcase, datetime_must_have_timezone, floor_30_minutes_dt
from database import (
get_forecasts_for_a_specific_gsp_from_database,
get_forecasts_from_database,
get_session,
)
from dummy import create_dummy_national_forecast

logger = logging.getLogger(__name__)

Expand All @@ -33,163 +37,7 @@
thirty_minutes = timedelta(minutes=30)


class EnhancedBaseModel(BaseModel):
    """Base model ensuring attribute names are returned in camelCase.

    All API models below inherit from this so that JSON responses use
    camelCase keys while the Python code keeps snake_case field names.
    """

    # Automatically creates camelcase alias for field names
    # See https://pydantic-docs.helpmanual.io/usage/model_config/#alias-generator
    class Config: # noqa: D106
        alias_generator = convert_to_camelcase
        allow_population_by_field_name = True


class ForecastValue(EnhancedBaseModel):
    """One Forecast of generation at one timestamp"""

    # The half-hourly timestamp this value applies to.
    target_time: datetime = Field(
        ...,
        description="The target time that the forecast is produced for",
    )
    # Forecast PV power in megawatts; constrained non-negative (ge=0).
    expected_pv_power_generation_megawatts: float = Field(
        ..., ge=0, description="The forecasted value in MW"
    )

    # Validated by utils.datetime_must_have_timezone — presumably enforces
    # timezone-aware datetimes; confirm against utils.
    _normalize_target_time = validator("target_time", allow_reuse=True)(datetime_must_have_timezone)


class AdditionalLocationInformation(EnhancedBaseModel):
    """Used internally to better describe a Location"""

    # GSP fields default to None so the national-level location can omit them.
    gsp_id: Optional[int] = Field(None, description="The Grid Supply Point (GSP) id")
    gsp_name: Optional[str] = Field(None, description="The GSP name")
    gsp_group: Optional[str] = Field(None, description="The GSP group name")
    # NOTE(review): required (...) but typed Optional — callers must pass the
    # field, though None is accepted. Confirm this asymmetry is intentional.
    region_name: Optional[str] = Field(..., description="The GSP region name")


class Location(EnhancedBaseModel):
    """Location that the forecast is for"""

    # OCF-generated UUID (the dummy factories below create it with uuid4()).
    location_id: UUID = Field(..., description="OCF-created id for location")
    label: str = Field(..., description="User-defined name for the location")
    # GSP-specific metadata; mostly dummy values in this module's factories.
    additional_information: AdditionalLocationInformation = Field(
        ..., description="E.g. Existing GSP properties"
    )


class InputDataLastUpdated(EnhancedBaseModel):
    """Information about the input data that was used to create the forecast"""

    gsp: datetime = Field(..., description="The time when the input GSP data was last updated")
    nwp: datetime = Field(..., description="The time when the input NWP data was last updated")
    pv: datetime = Field(..., description="The time when the input PV data was last updated")
    satellite: datetime = Field(
        ..., description="The time when the input satellite data was last updated"
    )

    # All four timestamps are run through utils.datetime_must_have_timezone —
    # presumably enforcing timezone-aware values; confirm against utils.
    _normalize_gsp = validator("gsp", allow_reuse=True)(datetime_must_have_timezone)
    _normalize_nwp = validator("nwp", allow_reuse=True)(datetime_must_have_timezone)
    _normalize_pv = validator("pv", allow_reuse=True)(datetime_must_have_timezone)
    _normalize_satellite = validator("satellite", allow_reuse=True)(datetime_must_have_timezone)


class Forecast(EnhancedBaseModel):
    """A single Forecast"""

    # NOTE(review): the descriptions say "forecaster" where "forecast" seems
    # intended; left unchanged as they are part of the public API schema.
    location: Location = Field(..., description="The location object for this forecaster")
    forecast_creation_time: datetime = Field(
        ..., description="The time when the forecaster was made"
    )
    forecast_values: List[ForecastValue] = Field(
        ...,
        description="List of forecasted value objects. Each value has the datestamp and a value",
    )
    input_data_last_updated: InputDataLastUpdated = Field(
        ...,
        description="Information about the input data that was used to create the forecast",
    )

    # Validated by utils.datetime_must_have_timezone (timezone-aware check).
    _normalize_forecast_creation_time = validator("forecast_creation_time", allow_reuse=True)(
        datetime_must_have_timezone
    )


class ManyForecasts(EnhancedBaseModel):
    """Many Forecasts"""

    # One Forecast per GSP.
    forecasts: List[Forecast] = Field(
        ...,
        description="List of forecasts for different GSPs",
    )


def _create_dummy_forecast_for_location(location: Location):
    """Build a dummy forecast (four half-hourly 0 MW values) for one location."""
    logger.debug(f"Creating dummy forecast for location {location}")

    # anchor all timestamps on "now", floored to the half hour
    start = floor_30_minutes_dt(dt=datetime.now(timezone.utc))

    # four consecutive half-hourly target times
    target_times = [start + step * thirty_minutes for step in range(4)]

    last_updated = InputDataLastUpdated(
        gsp=start,
        nwp=start,
        pv=start,
        satellite=start,
    )

    values = [
        ForecastValue(expected_pv_power_generation_megawatts=0, target_time=target_time)
        for target_time in target_times
    ]

    # pretend the forecast was created half an hour before the first value
    return Forecast(
        location=location,
        forecast_creation_time=start - timedelta(minutes=30),
        forecast_values=values,
        input_data_last_updated=last_updated,
    )


def _create_dummy_national_forecast():
    """Create a dummy forecast for the national level"""

    logger.debug("Creating dummy forecast")

    # only the region name is meaningful for the national location
    national_location = Location(
        location_id=uuid4(),
        label="GB (National)",
        additional_information=AdditionalLocationInformation(
            region_name="national_GB",
        ),
    )

    return _create_dummy_forecast_for_location(location=national_location)


def _create_dummy_gsp_forecast(gsp_id):
    """Create a dummy forecast for a given GSP"""

    logger.debug(f"Creating dummy forecast for {gsp_id=}")

    # only the gsp_id varies; the rest of the metadata is dummy filler
    gsp_location = Location(
        location_id=uuid4(),
        label="dummy_label",
        additional_information=AdditionalLocationInformation(
            gsp_id=gsp_id,
            gsp_name="dummy_gsp_name",
            gsp_group="dummy_gsp_group",
            region_name="dummy_region_name",
        ),
    )

    return _create_dummy_forecast_for_location(location=gsp_location)
# Dependency


@app.get("/")
Expand All @@ -207,27 +55,28 @@ async def get_api_information():


@app.get("/v0/forecasts/GB/pv/gsp/{gsp_id}", response_model=Forecast)
async def get_forecasts_for_a_specific_gsp(gsp_id) -> Forecast:
async def get_forecasts_for_a_specific_gsp(
gsp_id, session: Session = Depends(get_session)
) -> Forecast:
"""Get one forecast for a specific GSP id"""

logger.info(f"Get forecasts for gsp id {gsp_id}")

return _create_dummy_gsp_forecast(gsp_id=gsp_id)
return get_forecasts_for_a_specific_gsp_from_database(session=session, gsp_id=gsp_id)


@app.get("/v0/forecasts/GB/pv/gsp", response_model=ManyForecasts)
async def get_all_available_forecasts() -> ManyForecasts:
async def get_all_available_forecasts(session: Session = Depends(get_session)) -> ManyForecasts:
"""Get the latest information for all available forecasts"""

logger.info("Get forecasts for all gsps")

return ManyForecasts(forecasts=[_create_dummy_gsp_forecast(gsp_id) for gsp_id in range(10)])
return get_forecasts_from_database(session=session)


@app.get("/v0/forecasts/GB/pv/national", response_model=Forecast)
async def get_nationally_aggregated_forecasts() -> Forecast:
"""Get an aggregated forecast at the national level"""

logger.debug("Get national forecasts")

return _create_dummy_national_forecast()
return create_dummy_national_forecast()
42 changes: 42 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
""" Pytest fixitures for tests """
import os
import tempfile

import pytest
from nowcasting_forecast.database.connection import DatabaseConnection
from nowcasting_forecast.database.fake import make_fake_forecasts
from nowcasting_forecast.database.models import Base


@pytest.fixture
def forecasts(db_session):
    """Pytest fixture providing 338 fake forecasts added to the test session."""
    # one fake forecast per GSP id (0..337)
    fake_forecasts = make_fake_forecasts(gsp_ids=list(range(0, 338)))
    db_session.add_all(fake_forecasts)
    return fake_forecasts


@pytest.fixture
def db_connection():
    """Pytest fixture for a database connection.

    Creates a temporary SQLite database, publishes its URL via the DB_URL
    environment variable (read by database.get_session) and yields a
    DatabaseConnection with all tables created.

    Bug fixed: the previous version appended ".db" to the name of a
    NamedTemporaryFile, so SQLite wrote a *different* file next to the
    temporary one, and that database file was never cleaned up. Using a
    TemporaryDirectory guarantees the database file is removed too.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # set url option to not check same thread, this solves an error seen in testing
        url = f"sqlite:///{os.path.join(tmpdir, 'test.db')}?check_same_thread=False"
        os.environ["DB_URL"] = url
        connection = DatabaseConnection(url=url)
        Base.metadata.create_all(connection.engine)

        yield connection


@pytest.fixture(scope="function", autouse=True)
def db_session(db_connection):
    """Yield a fresh database session; roll back all changes after the test."""
    with db_connection.get_session() as session:
        # wrap each test in a transaction so data never leaks between tests
        session.begin()
        yield session
        session.rollback()
Loading

0 comments on commit db21842

Please sign in to comment.