From a6e72f3661613317fcb30dbfa067f14bbddd0d5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 13:39:11 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocf_datapipes/config/model.py | 8 +++++--- ocf_datapipes/training/pvnet_site.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/ocf_datapipes/config/model.py b/ocf_datapipes/config/model.py index 5fdd32e9..9fd89c8b 100644 --- a/ocf_datapipes/config/model.py +++ b/ocf_datapipes/config/model.py @@ -502,9 +502,11 @@ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin): description="The temporal resolution (in minutes) of the data." "Note that this needs to be divisible by 5.", ) - satellite_scaling_methods: Optional[List[str]] = Field(['mean_std'], - description='There are few ways to scale the satellite data. ' - '1. None, 2. mean_std, 3. min_max') + satellite_scaling_methods: Optional[List[str]] = Field( + ["mean_std"], + description="There are few ways to scale the satellite data. " + "1. None, 2. mean_std, 3. min_max", + ) class HRVSatellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin): diff --git a/ocf_datapipes/training/pvnet_site.py b/ocf_datapipes/training/pvnet_site.py index bf61e766..066567b0 100644 --- a/ocf_datapipes/training/pvnet_site.py +++ b/ocf_datapipes/training/pvnet_site.py @@ -24,9 +24,9 @@ NWP_MEANS, NWP_STDS, RSS_MEAN, - RSS_STD, RSS_RAW_MAX, RSS_RAW_MIN, + RSS_STD, ) from ocf_datapipes.utils.utils import ( combine_to_single_dataset, @@ -276,9 +276,9 @@ def construct_sliced_data_pipeline( roi_width_pixels=conf_sat.satellite_image_size_pixels_width, ) scaling_methods = conf_sat.satellite_scaling_methods - if 'min_max' in scaling_methods: + if "min_max" in scaling_methods: sat_datapipe = sat_datapipe.normalize(min_values=RSS_RAW_MIN, max_values=RSS_RAW_MAX) - if 'mean_std': + if "mean_std": sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD) if "pv" in datapipes_dict: