Skip to content

Commit

Permalink
Merge pull request #91 from openclimatefix/issue/90/gsp/systems
Browse files Browse the repository at this point in the history
Issue/90/gsp/systems
  • Loading branch information
peterdudfield authored May 9, 2022
2 parents 62e1d2a + 96458e4 commit 97d4edc
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 16 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uvicorn[standard]
pydantic
numpy
requests
nowcasting_datamodel==0.0.43
nowcasting_datamodel==0.0.49
nowcasting_dataset==3.1.59
sqlalchemy
psycopg2-binary
Expand Down
22 changes: 21 additions & 1 deletion src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from typing import List, Optional

from nowcasting_datamodel.connection import DatabaseConnection
from nowcasting_datamodel.models import Forecast, ForecastValue, GSPYield, ManyForecasts
from nowcasting_datamodel.models import Forecast, ForecastValue, GSPYield, Location, ManyForecasts
from nowcasting_datamodel.read.read import (
get_all_gsp_ids_latest_forecast,
get_all_locations,
get_forecast_values,
get_latest_forecast,
get_latest_national_forecast,
get_location,
)
from nowcasting_datamodel.read.read_gsp import get_gsp_yield
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -111,3 +113,21 @@ def get_truth_values_for_a_specific_gsp_from_database(
start_datetime_utc=yesterday_start_datetime,
regime=regime,
)


def get_gsp_system(session: Session, gsp_id: Optional[int] = None) -> List[Location]:
"""
Get gsp system details
:param session:
:param gsp_id: optional input. If None, get all systems
:return:
"""

if gsp_id is not None:
gsp_systems = [get_location(session=session, gsp_id=gsp_id)]

else:
gsp_systems = get_all_locations(session=session)

return [Location.from_orm(gsp_system) for gsp_system in gsp_systems]
42 changes: 29 additions & 13 deletions src/gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

import geopandas as gpd
from fastapi import APIRouter, Depends
from nowcasting_datamodel.models import Forecast, ForecastValue, GSPYield, ManyForecasts
from nowcasting_datamodel.models import Forecast, ForecastValue, GSPYield, Location, ManyForecasts
from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso
from sqlalchemy.orm.session import Session

from database import (
get_forecasts_for_a_specific_gsp_from_database,
get_forecasts_from_database,
get_gsp_system,
get_latest_forecast_values_for_a_specific_gsp_from_database,
get_latest_national_forecast_from_database,
get_session,
Expand Down Expand Up @@ -81,6 +82,23 @@ async def get_truths_for_a_specific_gsp(
)


@router.get("/forecast/all", response_model=ManyForecasts)
async def get_all_available_forecasts(session: Session = Depends(get_session)) -> ManyForecasts:
"""Get the latest information for all available forecasts"""

logger.info("Get forecasts for all gsps")

return get_forecasts_from_database(session=session)


@router.get("/forecast/national", response_model=Forecast)
async def get_nationally_aggregated_forecasts(session: Session = Depends(get_session)) -> Forecast:
"""Get an aggregated forecast at the national level"""

logger.debug("Get national forecasts")
return get_latest_national_forecast_from_database(session=session)


@router.get("/gsp_boundaries")
async def get_gsp_boundaries() -> dict:
"""Get one gsp boundary for a specific GSP id
Expand All @@ -100,18 +118,16 @@ async def get_gsp_boundaries() -> dict:
return json.loads(json_string)


@router.get("/forecast/all", response_model=ManyForecasts)
async def get_all_available_forecasts(session: Session = Depends(get_session)) -> ManyForecasts:
"""Get the latest information for all available forecasts"""

logger.info("Get forecasts for all gsps")

return get_forecasts_from_database(session=session)
@router.get("/gsp_systems", response_model=List[Location])
async def get_systems(
session: Session = Depends(get_session), gsp_id: Optional[int] = None
) -> List[Location]:
"""
Get gsp system details.
Provide gsp_id to just return one gsp system, otherwise all are returned
"""

@router.get("/forecast/national", response_model=Forecast)
async def get_nationally_aggregated_forecasts(session: Session = Depends(get_session)) -> Forecast:
"""Get an aggregated forecast at the national level"""
logger.info(f"Get GSP systems for {gsp_id=}")

logger.debug("Get national forecasts")
return get_latest_national_forecast_from_database(session=session)
return get_gsp_system(session=session, gsp_id=gsp_id)
20 changes: 19 additions & 1 deletion src/tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" Test for main app """

from database import get_forecasts_for_a_specific_gsp_from_database, get_session
from database import get_forecasts_for_a_specific_gsp_from_database, get_gsp_system, get_session


def test_get_session():
Expand All @@ -14,3 +14,21 @@ def test_get_forecasts_for_a_specific_gsp_from_database(db_session, forecasts):
gsp_id = 1

_ = get_forecasts_for_a_specific_gsp_from_database(gsp_id=gsp_id, session=db_session)


def test_get_gsp_system_none(db_session):
"""Check get gsp system works with no systems"""
a = get_gsp_system(session=db_session)
assert len(a) == 0


def test_get_gsp_system_all(db_session, forecasts):
"""Check get gsp system works for all systems"""
a = get_gsp_system(session=db_session)
assert len(a) == 338


def test_get_gsp_system_one(db_session, forecasts):
"""Check get gsp system works for one system"""
a = get_gsp_system(session=db_session, gsp_id=1)
assert len(a) == 1
15 changes: 15 additions & 0 deletions src/tests/test_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,18 @@ def test_read_forecast_one_gsp(db_session):
print(i)
assert len(r_json) == 3
_ = [ForecastValue(**forecast_value) for forecast_value in r_json]


def test_get_gsp_systems(db_session):
"""Check main GB/pv/gsp route works"""

forecasts = make_fake_forecasts(gsp_ids=list(range(0, 10)), session=db_session)
db_session.add_all(forecasts)

app.dependency_overrides[get_session] = lambda: db_session

response = client.get("v0/GB/solar/gsp/gsp_systems")
assert response.status_code == 200

locations = [Location(**location) for location in response.json()]
assert len(locations) == 10

0 comments on commit 97d4edc

Please sign in to comment.