From 855db76959bfdd3ca0028b329eb7b527cb092324 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Tue, 15 Oct 2024 11:28:01 +0100 Subject: [PATCH] add back in pvnet ecmwf old model --- pvnet_app/app.py | 7 +++++-- pvnet_app/model_configs/all_models.yaml | 12 +++++++++++ pvnet_app/model_configs/pydantic_models.py | 23 +++++++++++++++++++-- tests/model_configs/test_pydantic_models.py | 14 ++++++++++++- 4 files changed, 51 insertions(+), 5 deletions(-) diff --git a/pvnet_app/app.py b/pvnet_app/app.py index 802c0c8..960aced 100644 --- a/pvnet_app/app.py +++ b/pvnet_app/app.py @@ -107,6 +107,7 @@ def app( - SENTRY_DSN, optional link to sentry - ENVIRONMENT, the environment this is running in, defaults to local - USE_ECMWF_ONLY, option to use ecmwf only model, defaults to false + - USE_OCF_DATA_SAMPLER, option to use ocf_data_sampler, defaults to true Args: t0 (datetime): Datetime at which forecast is made @@ -127,18 +128,20 @@ def app( use_day_ahead_model = os.getenv("DAY_AHEAD_MODEL", "false").lower() == "true" use_ecmwf_only = os.getenv("USE_ECMWF_ONLY", "false").lower() == "true" run_extra_models = os.getenv("RUN_EXTRA_MODELS", "false").lower() == "true" + use_ocf_data_sampler = os.getenv("USE_OCF_DATA_SAMPLER", "true").lower() == "true" logger.info(f"Using `pvnet` library version: {pvnet.__version__}") logger.info(f"Using `pvnet_app` library version: {pvnet_app.__version__}") logger.info(f"Using {num_workers} workers") logger.info(f"Using day ahead model: {use_day_ahead_model}") - logger.info(f"Using ecwmwf only: {use_ecmwf_only}") + logger.info(f"Using ecmwf only: {use_ecmwf_only}") logger.info(f"Running extra models: {run_extra_models}") # load models model_configs = get_all_models(get_ecmwf_only=use_ecmwf_only, get_day_ahead_only=use_day_ahead_model, - run_extra_models=run_extra_models) + run_extra_models=run_extra_models, + use_ocf_data_sampler=use_ocf_data_sampler) logger.info(f"Using adjuster: {model_configs[0].use_adjuster}") logger.info(f"Saving GSP sum: 
{model_configs[0].save_gsp_sum}") diff --git a/pvnet_app/model_configs/all_models.yaml b/pvnet_app/model_configs/all_models.yaml index 6d81057..988b4e5 100644 --- a/pvnet_app/model_configs/all_models.yaml +++ b/pvnet_app/model_configs/all_models.yaml @@ -47,6 +47,17 @@ models: version: 4fe6b1441b6dd549292c201ed85eee156ecc220c ecmwf_only: True uses_satellite_data: False +# This is the old model for pvnet_ecmwf + - name: pvnet_ecmwf # this name is important as it used for blending + pvnet: + repo: openclimatefix/pvnet_uk_region + version: 35d55181a82440bdd087f380d650bfd0b64bd322 + summation: + repo: openclimatefix/pvnet_v2_summation + version: 9002baf1e9dc1ec141f3c4a1fa8447b6316a4558 + ecmwf_only: True + uses_satellite_data: False + uses_ocf_data_sampler: False # The day ahead model has not yet been re-trained with data-sampler. # It will be run with the legacy dataloader using ocf_datapipes - name: pvnet_day_ahead @@ -61,4 +72,5 @@ models: verbose: True save_gsp_to_recent: True day_ahead: True + uses_ocf_data_sampler: False diff --git a/pvnet_app/model_configs/pydantic_models.py b/pvnet_app/model_configs/pydantic_models.py index 3f3635b..e5e0a08 100644 --- a/pvnet_app/model_configs/pydantic_models.py +++ b/pvnet_app/model_configs/pydantic_models.py @@ -49,6 +49,12 @@ class Model(BaseModel): True, title="Uses Satellite Data", description="If this model uses satellite data" ) + uses_ocf_data_sampler: Optional[bool] = Field( + True, title="Uses OCF Data Sampler", description="If this model uses data sampler, old one uses ocf_datapipes" + ) + + + class Models(BaseModel): """A group of ml models""" @@ -60,8 +66,8 @@ class Models(BaseModel): @field_validator("models") @classmethod def name_must_be_unique(cls, v: List[Model]) -> List[Model]: - """Ensure that all model names are unique""" - names = [model.name for model in v] + """Ensure that all model names are unique, respect to using ocf_data_sampler or not""" + names = [(model.name,model.uses_ocf_data_sampler) for model 
in v] unique_names = set(names) if len(names) != len(unique_names): @@ -73,6 +79,7 @@ def get_all_models( get_ecmwf_only: Optional[bool] = False, get_day_ahead_only: Optional[bool] = False, run_extra_models: Optional[bool] = False, + use_ocf_data_sampler: Optional[bool] = True, ) -> List[Model]: """ Returns all the models for a given client @@ -81,6 +88,7 @@ def get_all_models( get_ecmwf_only: If only the ECMWF model should be returned get_day_ahead_only: If only the day ahead model should be returned run_extra_models: If extra models should be run + use_ocf_data_sampler: If the OCF Data Sampler should be used """ # load models from yaml file @@ -92,10 +100,12 @@ def get_all_models( models = config_pvnet_v2_model(models) + print(len(models.models)) if get_ecmwf_only: log.info("Using ECMWF model only") models.models = [model for model in models.models if model.ecmwf_only] + print(len(models.models)) if get_day_ahead_only: log.info("Using Day Ahead model only") models.models = [model for model in models.models if model.day_ahead] @@ -103,10 +113,19 @@ def get_all_models( log.info("Not using Day Ahead model") models.models = [model for model in models.models if not model.day_ahead] + print(len(models.models)) if not run_extra_models and not get_day_ahead_only and not get_ecmwf_only: log.info("Not running extra models") models.models = [model for model in models.models if model.name == "pvnet_v2"] + print(len(models.models)) + if use_ocf_data_sampler: + log.info("Using OCF Data Sampler") + models.models = [model for model in models.models if model.uses_ocf_data_sampler] + else: + log.info("Not using OCF Data Sampler, using ocf_datapipes") + models.models = [model for model in models.models if not model.uses_ocf_data_sampler] + return models.models diff --git a/tests/model_configs/test_pydantic_models.py b/tests/model_configs/test_pydantic_models.py index 585a55e..e21f23a 100644 --- a/tests/model_configs/test_pydantic_models.py +++ 
b/tests/model_configs/test_pydantic_models.py @@ -18,7 +18,7 @@ def test_get_all_models_get_ecmwf_only(): def test_get_all_models_get_day_ahead_only(): """Test for getting all models with ecmwf_only""" - models = get_all_models(get_day_ahead_only=True) + models = get_all_models(get_day_ahead_only=True, use_ocf_data_sampler=False) assert len(models) == 1 assert models[0].day_ahead @@ -28,3 +28,15 @@ def test_get_all_models_run_extra_models(): models = get_all_models(run_extra_models=True) assert len(models) == 5 + +def test_get_all_models_ocf_data_sampler(): + """Test for getting all models with and without use_ocf_data_sampler""" + models = get_all_models(use_ocf_data_sampler=True, run_extra_models=True) + assert len(models) == 5 + + models = get_all_models(use_ocf_data_sampler=False, run_extra_models=True) + assert len(models) == 1 + + models = get_all_models(use_ocf_data_sampler=False, run_extra_models=True, get_day_ahead_only=True) + assert len(models) == 1 +