Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data sampler update #170

Closed
wants to merge 11 commits into from
224 changes: 202 additions & 22 deletions pvnet_app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,62 +16,179 @@ def save_yaml_config(config: dict, path: str) -> None:
yaml.dump(config, file, default_flow_style=False)


def populate_config_with_data_data_filepaths(config: dict, gsp_path: str = "", schema_version: str = "v0") -> dict:
    """Populate the data source filepaths in the config with schema version handling

    Args:
        config: The data config. This can be stored on HuggingFace
        gsp_path: For legacy usage only
        schema_version: Fallback configuration schema version. "v0" is the legacy
            ocf-datapipes schema; "v1" is the new ocf-data-sampler schema
            (ocf-data-sampler > version 0.0.X). A per-source
            "config_schema_version" entry in the config takes precedence.

    Returns:
        The config with production data paths filled in and keys aligned to
        the schema version in use.
    """

    production_paths = {
        "gsp": gsp_path,
        "nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path},
        "satellite": sat_path,
    }

    # Handle GSP and satellite sources
    for source in ["gsp", "satellite"]:
        if source not in config["input_data"]:
            continue
        source_config = config["input_data"][source]

        # Per-source schema version overrides the function-level fallback.
        # (Previously the fallback was hard-coded to "v0", which made the
        # schema_version parameter dead.)
        source_schema = source_config.get("config_schema_version", schema_version)
        zarr_path_key = "zarr_path" if source_schema == "v1" else f"{source}_zarr_path"

        # Specific handling for GSP to ensure data sampler compatibility
        if source == "gsp":
            if source_schema == "v1":
                # Migrate the legacy key name to the v1 schema
                if "gsp_zarr_path" in source_config:
                    source_config["zarr_path"] = source_config.pop("gsp_zarr_path")

                # Default keys expected by the OCF data sampler
                for key in [
                    "installed_capacity_mwp",
                    "generation_mw",
                    "effective_capacity_mwp",
                    "capacity_mwp",
                ]:
                    source_config.setdefault(key, True)
            else:
                # Migrate the v1 key name back to the legacy schema
                if "zarr_path" in source_config:
                    source_config["gsp_zarr_path"] = source_config.pop("zarr_path")

        # Fill in the production path when none is configured
        if not source_config.get(zarr_path_key, ""):
            assert source in production_paths, f"Missing production path: {source}"
            source_config[zarr_path_key] = production_paths[source]

    # NWP is nested so must be treated separately
    if "nwp" in config["input_data"]:
        nwp_config = config["input_data"]["nwp"]
        for nwp_source in nwp_config.keys():
            nwp_source_config = nwp_config[nwp_source]

            # Per-source schema version overrides the function-level fallback
            nwp_schema = nwp_source_config.get("config_schema_version", schema_version)

            # Align all keys with the schema version BEFORE reading them, so
            # e.g. a v1 source still carrying "nwp_provider" does not raise.
            for v0_key, v1_key in [
                ("nwp_zarr_path", "zarr_path"),
                ("nwp_channels", "channels"),
                ("nwp_image_size_pixels_height", "image_size_pixels_height"),
                ("nwp_image_size_pixels_width", "image_size_pixels_width"),
                ("nwp_provider", "provider"),
            ]:
                if nwp_schema == "v1" and v0_key in nwp_source_config:
                    nwp_source_config[v1_key] = nwp_source_config.pop(v0_key)
                elif nwp_schema == "v0" and v1_key in nwp_source_config:
                    nwp_source_config[v0_key] = nwp_source_config.pop(v1_key)

            if nwp_schema == "v1":
                path_key = "zarr_path"
                provider_key = "provider"
            else:
                path_key = "nwp_zarr_path"
                provider_key = "nwp_provider"

            # Fill in the production path (keyed by provider) when none is set
            if not nwp_source_config.get(path_key, ""):
                provider = nwp_source_config[provider_key].lower()
                assert provider in production_paths["nwp"], f"Missing NWP path: {provider}"
                nwp_source_config[path_key] = production_paths["nwp"][provider]

    return config


# def overwrite_config_dropouts(config: dict) -> dict:
# """Overwrite the config dropout parameters for production with schema version handling

# Args:
# config: The data config
# """

# # Replace data source - satellite
# for source in ["satellite"]:
# if source in config["input_data"]:
# source_config = config["input_data"][source]

# # Check schema version
# schema_version = source_config.get("config_schema_version", "v0")
# zarr_path_key = "zarr_path" if schema_version == "v1" else f"{source}_zarr_path"

# if source_config.get(zarr_path_key, ""):
# source_config["dropout_timedeltas_minutes"] = None

# # Handle NWP separately - nested structure
# if "nwp" in config["input_data"]:
# nwp_config = config["input_data"]["nwp"]
# for nwp_source in nwp_config.keys():
# # Check schema version
# schema_version = nwp_config[nwp_source].get("config_schema_version", "v0")
# path_key = "zarr_path" if schema_version == "v1" else "nwp_zarr_path"

# if nwp_config[nwp_source].get(path_key, ""):
# nwp_config[nwp_source]["dropout_timedeltas_minutes"] = None

# return config


def overwrite_config_dropouts(config: dict) -> dict:
    """Overwrite the config dropout parameters for production with enhanced schema version handling

    Dropout is a training-time augmentation, so whenever a data source has a
    zarr path configured its dropout timedeltas are cleared. Both the legacy
    key name ("<source>_zarr_path") and the data-sampler name ("zarr_path")
    are recognised.

    Args:
        config: The data config

    Returns:
        The config with dropout settings cleared for production.
    """
    # Guarantee the top-level section exists
    input_data = config.setdefault("input_data", {})

    def _clear_dropout(section: dict, legacy_path_key: str) -> None:
        # Prefer whichever zarr-path key the section actually contains,
        # falling back to the legacy name when neither is present
        candidate_keys = (legacy_path_key, "zarr_path")
        active_key = next((k for k in candidate_keys if k in section), candidate_keys[0])

        # Make sure the dropout key is present at all ...
        section.setdefault("dropout_timedeltas_minutes", None)

        # ... and clear it whenever a data path is configured
        if section.get(active_key, ""):
            section["dropout_timedeltas_minutes"] = None

    # Satellite source
    if "satellite" in input_data:
        _clear_dropout(input_data["satellite"], "satellite_zarr_path")

    # NWP is nested, so each sub-source is handled individually
    for nwp_section in input_data.get("nwp", {}).values():
        _clear_dropout(nwp_section, "nwp_zarr_path")

    return config


def modify_data_config_for_production(
input_path: str, output_path: str, gsp_path: str = ""
) -> None:
Expand Down Expand Up @@ -134,3 +251,66 @@ def get_union_of_configs(config_paths: list[str]) -> dict:
common_config["input_data"]["nwp"][nwp_key] = nwp_conf

return common_config


# PURELY FOR TESTING PURPOSE FOR NOW
def ensure_zarr_metadata(zarr_path: str) -> None:
    """
    Ensure that a Zarr store has proper metadata, creating it if necessary.

    Best-effort only: all errors are printed and swallowed, since this is a
    testing helper and must never take down the caller.

    Args:
        zarr_path (str): Path to the Zarr store
    """
    import json
    import os

    import xarray as xr
    import zarr

    try:
        # Nothing to do if the store does not exist on disk
        if not os.path.exists(zarr_path):
            return

        # Opening and re-appending with xarray regenerates per-variable metadata
        try:
            ds = xr.open_zarr(zarr_path)
            ds.to_zarr(zarr_path, mode='a')
        except Exception as e:
            print(f"Error opening Zarr with xarray: {e}")

        # Open the Zarr store
        store = zarr.DirectoryStore(zarr_path)

        # A zarr v2 store must have a root .zgroup entry
        if '.zgroup' not in store:
            store['.zgroup'] = json.dumps({"zarr_format": 2}).encode('utf-8')

        # Create consolidated metadata if missing. Use zarr's own
        # consolidate_metadata so the .zmetadata entry is written in the
        # consolidated-metadata format that zarr/xarray actually read
        # ({"metadata": ..., "zarr_consolidated_format": 1}) — the previous
        # hand-rolled structure was not in that format and would be ignored.
        if '.zmetadata' not in store:
            try:
                zarr.consolidate_metadata(store)
            except Exception as e:
                print(f"Error creating metadata: {e}")

    except Exception as e:
        print(f"Unexpected error in ensure_zarr_metadata: {e}")
1 change: 1 addition & 0 deletions pvnet_app/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_dataloader(
# Populate the data config with production data paths
modified_data_config_filename = Path(config_filename).parent / "data_config.yaml"

# TODO pass in schema_version to populate_config_with_data_data_filepaths using partial
modify_data_config_for_production(config_filename, modified_data_config_filename)

dataset = PVNetUKRegionalDataset(
Expand Down
27 changes: 24 additions & 3 deletions pvnet_app/model_configs/all_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Batches are prepared only once, so the extra models must be able to run on the batches created
# to run the pvnet_v2 model
models:

- name: pvnet_v2
pvnet:
repo: openclimatefix/pvnet_uk_region
Expand All @@ -13,13 +14,15 @@ models:
save_gsp_sum: False
verbose: True
save_gsp_to_recent: true

- name: pvnet_v2-sat0-samples-v1
pvnet:
repo: openclimatefix/pvnet_uk_region
version: 8a7cc21b64d25ce1add7a8547674be3143b2e650
summation:
repo: openclimatefix/pvnet_v2_summation
version: dcfdc17fda8e48c387122614bec8b284eaa868b9

# single source models
- name: pvnet_v2-sat0-only-samples-v1
pvnet:
Expand Down Expand Up @@ -47,6 +50,7 @@ models:
version: 4fe6b1441b6dd549292c201ed85eee156ecc220c
ecmwf_only: True
uses_satellite_data: False

# This is the old model for pvnet and pvnet_ecmwf
- name: pvnet_v2
pvnet:
Expand All @@ -60,6 +64,7 @@ models:
verbose: True
save_gsp_to_recent: True
uses_ocf_data_sampler: False

- name: pvnet_ecmwf # this name is important as it used for blending
pvnet:
repo: openclimatefix/pvnet_uk_region
Expand All @@ -70,9 +75,9 @@ models:
ecmwf_only: True
uses_satellite_data: False
uses_ocf_data_sampler: False
# The day ahead model has not yet been re-trained with data-sampler.
# It will be run with the legacy dataloader using ocf_datapipes
- name: pvnet_day_ahead

# Legacy day ahead model without data-sampler
- name: pvnet_day_ahead_legacy
pvnet:
repo: openclimatefix/pvnet_uk_region_day_ahead
version: d87565731692a6003e43caac4feaed0f69e79272
Expand All @@ -86,3 +91,19 @@ models:
day_ahead: True
uses_ocf_data_sampler: False

# Day ahead model that utilises data-sampler
- name: pvnet_day_ahead
pvnet:
repo: openclimatefix/pvnet_uk_region_day_ahead
version: 263741ebb6b71559d113d799c9a579a973cc24ba
summation:
repo: null
version: null
# repo: openclimatefix/pvnet_summation_uk_national_day_ahead
# version: ed60c5d32a020242ca4739dcc6dbc8864f783a08
use_adjuster: True
save_gsp_sum: True
verbose: True
save_gsp_to_recent: True
day_ahead: True
uses_ocf_data_sampler: True
Loading
Loading