diff --git a/containers/sat/Containerfile b/containers/sat/Containerfile
index 03c792e..6e76b21 100644
--- a/containers/sat/Containerfile
+++ b/containers/sat/Containerfile
@@ -5,7 +5,7 @@ FROM quay.io/condaforge/miniforge3:latest AS build-venv
 RUN apt -qq update && apt -qq install -y build-essential
 RUN conda create -p /venv python=3.12
 RUN /venv/bin/pip install --upgrade -q pip wheel setuptools
-RUN conda install -p /venv -c conda-forge -y cartopy satpy[all] numpy
+RUN conda install -p /venv -c conda-forge -y cartopy satpy[all]=0.50.0 numpy
 ENV GDAL_CONFIG=/venv/bin/gdal-config
 
 # Build the virtualenv
diff --git a/containers/sat/download_process_sat.py b/containers/sat/download_process_sat.py
index d5bea9a..396419b 100644
--- a/containers/sat/download_process_sat.py
+++ b/containers/sat/download_process_sat.py
@@ -2,6 +2,7 @@
 
 Consolidates the old cli_downloader, backfill_hrv and backfill_nonhrv scripts.
 """
+
 import argparse
 import dataclasses
 import datetime as dt
@@ -12,20 +13,25 @@
 import pathlib
 import shutil
 import sys
+import traceback
 from multiprocessing import Pool, cpu_count
 from typing import Literal
 
-import diskcache as dc
+import dask.delayed
+import dask.distributed
+import dask.diagnostics
 import eumdac
 import eumdac.cli
 import numpy as np
 import pandas as pd
 import pyproj
 import pyresample
+import satpy.dataset.dataid
 import xarray as xr
 import yaml
 import zarr
 from ocf_blosc2 import Blosc2
+from satpy import Scene
 
 
 handler = logging.StreamHandler(sys.stdout)
@@ -230,15 +236,12 @@ def process_scans(
     """
     # Check zarr file exists for the year
     zarr_path: pathlib.Path = folder.parent / start.strftime(sat_config.zarr_fmtstr[dstype])
+    zarr_times: list[dt.datetime] = []
     if zarr_path.exists():
-        zarr_times: list[dt.datetime] = xr.open_zarr(zarr_path).sortby("time").time.values.tolist()
-        last_zarr_time: dt.datetime = zarr_times[-1]
+        zarr_times = xr.open_zarr(zarr_path, consolidated=True).sortby("time").time.values.tolist()
         log.debug(f"Zarr store already exists at {zarr_path} for {zarr_times[0]}-{zarr_times[-1]}")
     else:
-        # Set dummy values for times already in zarr
-        last_zarr_time = dt.datetime(1970, 1, 1, tzinfo=dt.UTC)
-        zarr_times = [last_zarr_time, last_zarr_time]
-        log.debug(f"Zarr store does not exist at {zarr_path}, setting dummy times")
+        log.debug(f"Zarr store does not exist at {zarr_path}")
 
     # Get native files in order
     native_files: list[pathlib.Path] = list(folder.glob("*.nat"))
@@ -259,38 +262,54 @@ def process_scans(
         if dataset is not None:
             dataset = _preprocess_function(dataset)
             datasets.append(dataset)
-        # Append to zarrs in hourly chunks (12 sets of 5 minute datasets)
+        # Append to zarrs in hourly chunks
         # * This is so zarr doesn't complain about mismatching chunk sizes
-        if len(datasets) == 12:
+        if len(datasets) == int(pd.Timedelta("1h") / pd.Timedelta(sat_config.cadence)):
             if pathlib.Path(zarr_path).exists():
                 log.debug(f"Appending to existing zarr store at {zarr_path}")
                 mode = "a"
             else:
                 log.debug(f"Creating new zarr store at {zarr_path}")
                 mode = "w"
+            concat_ds: xr.Dataset = xr.concat(datasets, dim="time")
             _write_to_zarr(
-                xr.concat(datasets, dim="time"),
+                concat_ds,
                 zarr_path.as_posix(),
                 mode,
-                chunks={"time": 12},
+                chunks={
+                    "time": len(datasets),
+                    "x_geostationary": -1,
+                    "y_geostationary": -1,
+                    "variable": 1,
+                },
             )
             datasets = []
 
         log.info(f"Process loop [{dstype}]: {i+1}/{len(native_files)}")
 
     # Consolidate zarr metadata
-    _rewrite_zarr_times(zarr_path.as_posix())
+    if pathlib.Path(zarr_path).exists():
+        _rewrite_zarr_times(zarr_path.as_posix())
 
     return dstype
 
 
+def _gen_token() -> eumdac.AccessToken:
+    """Generates an access token from environment variables."""
+    consumer_key: str = os.environ["EUMETSAT_CONSUMER_KEY"]
+    consumer_secret: str = os.environ["EUMETSAT_CONSUMER_SECRET"]
+    token = eumdac.AccessToken(credentials=(consumer_key, consumer_secret))
+
+    return token
+
+
 def _convert_scene_to_dataarray(
     scene: Scene,
     band: str,
     area: str,
     calculate_osgb: bool = True,
 ) -> xr.DataArray:
-    """Convertes a Scene with satellite data into a data array.
+    """Converts a Scene with satellite data into a data array.
 
     Args:
         scene: The satpy.Scene containing the satellite data
@@ -314,17 +333,22 @@ def _convert_scene_to_dataarray(
         log.debug("Finished resample")
     scene = scene.crop(ll_bbox=GEOGRAPHIC_BOUNDS[area])
     log.debug("Finished crop")
-    # Remove acq time from all bands because it is not useful, and can actually
-    # get in the way of combining multiple Zarr datasets.
+
+    # Update the dataarray attributes based on the satpy scene attributes
     data_attrs = {}
     for channel in scene.wishlist:
+        # Remove acq time from all bands because it is not useful, and can actually
+        # get in the way of combining multiple Zarr datasets.
        scene[channel] = scene[channel].drop_vars("acq_time", errors="ignore")
         for attr in scene[channel].attrs:
             new_name = channel["name"] + "_" + attr
-            data_attrs[new_name] = scene[channel].attrs[attr]
+            # Ignore the "area" and "_satpy_id" scene attributes as they are not serializable
+            # and their data is already present in other scene attrs anyway.
+            if attr not in ["area", "_satpy_id"]:
+                data_attrs[new_name] = scene[channel].attrs[attr].__repr__()
+
     dataset: xr.Dataset = scene.to_xarray_dataset()
     dataarray = dataset.to_array()
-    log.debug("Converted to dataarray")
 
     # Lat and Lon are the same for all the channels now
     if calculate_osgb:
@@ -348,8 +372,8 @@ def _convert_scene_to_dataarray(
             dataarray[name].attrs["coordinate_reference_system"] = "geostationary"
         log.debug("Calculated OSGB")
     # Round to the nearest 5 minutes
-    dataarray.attrs.update(data_attrs)
-    dataarray.attrs["end_time"] = pd.Timestamp(dataarray.attrs["end_time"]).round("5 min")
+    data_attrs["end_time"] = pd.Timestamp(dataarray.attrs["end_time"]).round("5 min").__str__()
+    dataarray.attrs = data_attrs
 
     # Rename x and y to make clear the coordinate system they are in
     dataarray = dataarray.rename({"x": "x_geostationary", "y": "y_geostationary"})
@@ -363,41 +387,7 @@ def _convert_scene_to_dataarray(
     return dataarray
 
 
-def _serialize_attrs(attrs: dict) -> dict:
-    """Ensure each value of dict can be serialized.
-
-    This is required before saving to Zarr because Zarr represents attrs values in a
-    JSON file (.zmetadata).
-
-    The `area` field (which is a `pyresample.geometry.AreaDefinition` object gets turned
-    into a YAML string, which can be loaded again using
-    `area_definition = pyresample.area_config.load_area_from_string(data_array.attrs['area'])`
-
-    Returns attrs dict where every value has been made serializable.
- """ - for key, value in attrs.items(): - # Convert Dicts - if isinstance(value, dict): - # Convert np.float32 to Python floats (otherwise yaml.dump complains) - for inner_key in value: - inner_value = value[inner_key] - if isinstance(inner_value, np.floating): - value[inner_key] = float(inner_value) - attrs[key] = yaml.dump(value) - # Convert Numpy bools - if isinstance(value, bool | np.bool_): - attrs[key] = str(value) - # Convert area - if isinstance(value, pyresample.geometry.AreaDefinition): - attrs[key] = value.dump() - # Convert datetimes - if isinstance(value, dt.datetime): - attrs[key] = value.isoformat() - - return attrs - - -def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray | None: +def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray: """Rescale Xarray DataArray so all values lie in the range [0, 1]. Warning: The original `dataarray` will be modified in-place. @@ -420,11 +410,12 @@ def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray | "variable", ) + # For each channel, subtract the minimum and divide by the range dataarray -= [c.minimum for c in channels] dataarray /= [c.maximum - c.minimum for c in channels] + # Since the mins and maxes are approximations, clip the values to [0, 1] dataarray = dataarray.clip(min=0, max=1) dataarray = dataarray.astype(np.float32) - dataarray.attrs = _serialize_attrs(dataarray.attrs) # Must be serializable return dataarray @@ -433,30 +424,45 @@ def _open_and_scale_data( f: str, dstype: Literal["hrv", "nonhrv"], ) -> xr.Dataset | None: - """Opens a raw file and converts it to a normalised xarray dataset.""" + """Opens a raw file and converts it to a normalised xarray dataset. + + Args: + zarr_times: List of times already in the zarr store. + f: Path to the file to open. + dstype: Type of data to process (hrv or nonhrv). 
+ """ # The reader is the same for each satellite as the sensor is the same - # * Hence "severi" in all cases + # * Hence "seviri" in all cases scene = Scene(filenames={"seviri_l1b_native": [f]}) scene.load([c.variable for c in CHANNELS[dstype]]) - da: xr.DataArray = _convert_scene_to_dataarray( - scene, - band=CHANNELS[dstype][0].variable, - area="RSS", - calculate_osgb=False, - ) - # Rescale the data, update the attributes, save as dataset - attrs = _serialize_attrs(da.attrs) - da = _rescale(da, CHANNELS[dstype]) - da.attrs.update(attrs) - da = da.transpose("time", "y_geostationary", "x_geostationary", "variable") - ds: xr.Dataset = da.to_dataset(name="data") - ds["data"] = ds.data.astype(np.float16) + try: + da: xr.DataArray = _convert_scene_to_dataarray( + scene, + band=CHANNELS[dstype][0].variable, + area="RSS", + calculate_osgb=False, + ) + except Exception as e: + log.error(f"Error converting scene to dataarray: {e}") + return None + + # Don't proceed if the dataarray time is already present in the zarr store + if da.time.values[0] in zarr_times: + log.debug(f"Skipping: {da.time.values[0]}") + return None - if ds.time.values[0] in zarr_times: - log.debug(f"Skipping: {ds.time.values[0]}") + # Rescale the data, save as dataset + try: + da = _rescale(da, CHANNELS[dstype]) + except Exception as e: + log.error(f"Error rescaling dataarray: {e}") return None + da = da.transpose("time", "y_geostationary", "x_geostationary", "variable") + ds: xr.Dataset = da.to_dataset(name="data", promote_attrs=True) + ds["data"] = ds["data"].astype(np.float16) + return ds @@ -465,12 +471,12 @@ def _preprocess_function(xr_data: xr.Dataset) -> xr.Dataset: attrs = xr_data.attrs y_coords = xr_data.coords["y_geostationary"].values x_coords = xr_data.coords["x_geostationary"].values - x_dataarray = xr.DataArray( + x_dataarray: xr.DataArray = xr.DataArray( data=np.expand_dims(xr_data.coords["x_geostationary"].values, axis=0), dims=["time", "x_geostationary"], coords={"time": xr_data.coords["time"].values, "x_geostationary": x_coords}, ) - y_dataarray = xr.DataArray( + y_dataarray: xr.DataArray = xr.DataArray( data=np.expand_dims(xr_data.coords["y_geostationary"].values, axis=0), dims=["time", "y_geostationary"], coords={"time": xr_data.coords["time"].values, "y_geostationary": y_coords}, @@ -498,21 +504,25 @@ def _write_to_zarr(dataset: xr.Dataset, zarr_name: str, mode: str, chunks: dict) extra_kwargs = mode_extra_kwargs[mode] sliced_ds: xr.Dataset = dataset.isel(x_geostationary=slice(0, 5548)).chunk(chunks) try: - sliced_ds.to_zarr( + write_job = sliced_ds.to_zarr( store=zarr_name, - compute=True, + compute=False, consolidated=True, mode=mode, **extra_kwargs, ) + with dask.diagnostics.ProgressBar(): + write_job.compute() except Exception as e: - log.error(f"Error writing to zarr: {e}") + log.error(f"Error writing dataset to zarr store {zarr_name} with mode {mode}: {e}") + traceback.print_tb(e.__traceback__) + return None def _rewrite_zarr_times(output_name: str) -> None: """Rewrites the time coordinates in the given zarr store.""" # Combine time coords - ds = xr.open_zarr(output_name) + ds = xr.open_zarr(output_name, consolidated=True) # Prevent numcodecs string error # See https://github.com/pydata/xarray/issues/3476#issuecomment-1205346130 @@ -524,6 +534,10 @@ def _rewrite_zarr_times(output_name: str) -> None: ds[v].encoding.clear() del ds["data"] + if "x_geostationary_coordinates" in ds: + del ds["x_geostationary_coordinates"] + if "y_geostationary_coordinates" in ds: + del ds["y_geostationary_coordinates"] # 
     del ds.time.encoding["chunks"]
     del ds.time.encoding["preferred_chunks"]
@@ -556,37 +570,39 @@ def _rewrite_zarr_times(output_name: str) -> None:
     choices=list(CONFIGS.keys()),
 )
 parser.add_argument(
-    "--path",
+    "--path", "-p",
     help="Path to store the downloaded data",
     default="/mnt/disks/sat",
     type=pathlib.Path,
 )
 parser.add_argument(
-    "--start_date",
+    "--start_date", "-s",
     help="Date to download from (YYYY-MM-DD)",
     type=dt.date.fromisoformat,
     required=False,
     default=str(dt.datetime.now(tz=dt.UTC).date()),
 )
 parser.add_argument(
-    "--end_date",
+    "--end_date", "-e",
     help="Date to download to (YYYY-MM-DD)",
     type=dt.date.fromisoformat,
     required=False,
     default=str(dt.datetime.now(tz=dt.UTC).date()),
 )
+parser.add_argument(
+    "--delete_raw", "--rm",
+    help="Delete raw files after processing",
+    action="store_true",
+    default=False,
+)
 
-if __name__ == "__main__":
+
+def run(args: argparse.Namespace) -> None:
     prog_start = dt.datetime.now(tz=dt.UTC)
+    log.info(f"{prog_start!s}: Running with args: {args}")
 
-    # Parse running args
-    args = parser.parse_args()
+    # Get running folder from args
     folder: pathlib.Path = args.path / args.sat
 
-    # Create a reusable cache
-    cache = dc.Cache(folder / ".cache/{args.sat}")
-
-    log.info(f"{prog_start!s}: Running with args: {args}")
 
     # Get config for desired satellite
     sat_config = CONFIGS[args.sat]
@@ -599,52 +615,41 @@ def _rewrite_zarr_times(output_name: str) -> None:
         freq=sat_config.cadence,
     ).tolist()
 
-    # Get average runtime from cache
-    secs_per_scan = cache.get("secs_per_scan", default=90)
+    # Estimate average runtime
+    secs_per_scan: int = 90
     expected_runtime = pd.Timedelta(secs_per_scan * len(scan_times), "seconds")
     log.info(f"Downloading {len(scan_times)} scans. Expected runtime: {expected_runtime!s}")
 
     # Download data
     # We only parallelize if we have a number of files larger than the cpu count
-    consumer_key: str = os.environ["EUMETSAT_CONSUMER_KEY"]
-    consumer_secret: str = os.environ["EUMETSAT_CONSUMER_SECRET"]
-    token = eumdac.AccessToken(credentials=(consumer_key, consumer_secret))
-
-    results: list[pathlib.Path] = []
+    token = _gen_token()
+    raw_paths: list[pathlib.Path] = []
     if len(scan_times) > cpu_count():
         log.debug(f"Concurrency: {cpu_count()}")
         pool = Pool(max(cpu_count(), 10))  # EUMDAC only allows for 10 concurrent requests
-        results = pool.starmap(
+        raw_paths = pool.starmap(
             download_scans,
             [(sat_config, folder, scan_time, token) for scan_time in scan_times],
         )
         pool.close()
         pool.join()
-        results = list(itertools.chain(results))
+        raw_paths = list(itertools.chain.from_iterable(raw_paths))
     else:
-        results = []
+        raw_paths = []
         for scan_time in scan_times:
             result: list[pathlib.Path] = download_scans(sat_config, folder, scan_time, token)
             if len(result) > 0:
-                results.extend(result)
+                raw_paths.extend(result)
 
-    log.info(f"Downloaded {len(results)} files.")
+    log.info(f"Downloaded {len(raw_paths)} files.")
     log.info("Converting raw data to HRV and non-HRV Zarr Stores.")
 
-    # Process the HRV and non-HRV data concurrently if possible
+    # Process the HRV and non-HRV data sequentially
     completed_types: list[str] = []
-    if cpu_count() > 1:
-        pool = Pool(cpu_count())
-        completed_types = pool.starmap(
-            process_scans,
-            [(sat_config, folder, start, end, t) for t in ["hrv", "nonhrv"]],
-        )
-        pool.close()
-        pool.join()
-    else:
-        for t in ["hrv", "nonhrv"]:
-            completed_type = process_scans(sat_config, folder, start, end, t)
-            completed_types.append(completed_type)
+    for t in ["hrv", "nonhrv"]:
+        log.info(f"Processing {t} data.")
+        completed_type = process_scans(sat_config, folder, start, end, t)
+        completed_types.append(completed_type)
 
     for completed_type in completed_types:
         log.info(f"Processed {completed_type} data.")
@@ -653,5 +658,16 @@ def _rewrite_zarr_times(output_name: str) -> None:
     new_average_secs_per_scan: int = int(
         (secs_per_scan + (runtime.total_seconds() / len(scan_times))) / 2,
     )
-    cache.set("secs_per_scan", new_average_secs_per_scan)
     log.info(f"Completed archive for args: {args}. ({new_average_secs_per_scan} seconds per scan).")
+
+    # Delete raw files, if desired
+    if args.delete_raw:
+        log.info(f"Deleting {len(raw_paths)} raw files in {folder.as_posix()}.")
+        for f in raw_paths:
+            f.unlink()
+
+
+if __name__ == "__main__":
+    # Parse running args
+    args = parser.parse_args()
+    run(args)
diff --git a/containers/sat/test_download_process_sat.py b/containers/sat/test_download_process_sat.py
new file mode 100644
index 0000000..9ac16ad
--- /dev/null
+++ b/containers/sat/test_download_process_sat.py
@@ -0,0 +1,127 @@
+"""Tests for the satellite processing pipeline.
+
+Note that, since the files from EUMETSAT are so large,
+they must be downloaded prior to running the tests - they
+are too big to include in the repository. As such, environment
+variables must be set to authenticate with EUMETSAT.
+"""
+
+import datetime as dt
+import pathlib
+import unittest
+
+import download_process_sat as dps
+import numpy as np
+import pandas as pd
+import xarray as xr
+from satpy import Scene
+
+
+class TestDownloadProcessSat(unittest.TestCase):
+    paths: list[pathlib.Path]
+    test_dataarrays: dict[str, xr.DataArray]
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        TIMESTAMP = pd.Timestamp("2024-01-01T00:00:00Z")
+
+        token = dps._gen_token()
+
+        cls.paths = []
+        for t in [TIMESTAMP + pd.Timedelta(t) for t in ["0m", "15m", "30m", "45m"]]:
+            paths = dps.download_scans(
+                sat_config=dps.CONFIGS["iodc"],
+                folder=pathlib.Path("/tmp/test_sat_data"),
+                scan_time=t,
+                token=token,
+            )
+            cls.paths.extend(paths)
+
+        attrs: dict = {
+            "end_time": TIMESTAMP + pd.Timedelta("15m"),
+            "modifiers": (),
+            "orbital_parameters": {"projection_longitude": 45.5, "projection_latitude": 0.0,
+                "projection_altitude": 35785831.0, "satellite_nominal_longitude": 45.5,
+                "satellite_nominal_latitude": 0.0,
+                "satellite_actual_longitude": 45.703605543834364,
+                "satellite_actual_latitude": 7.281469039541501,
+                "satellite_actual_altitude": 35788121.627292305},
+            "reader": "seviri_l1b_native",
+            "sensor": "seviri",
+            "resolution": 3000.403165817,
+            "start_time": dt.datetime(2024, 1, 1, 0, 0, tzinfo=dt.UTC),
+            "platform_name": "Meteosat-9",
+            "area": "Area ID: msg_seviri_iodc_3km",
+        }
+
+        cls.test_dataarrays = {
+            "hrv": xr.DataArray(
+                data=np.random.random((1, 1, 3712, 3712)),
+                dims=["time", "variable", "x_geostationary", "y_geostationary"],
+                coords={
+                    "time": [pd.Timestamp("2024-01-01T00:00:00Z")],
+                    "variable": ["HRV"],
+                    "x_geostationary": np.arange(3712),
+                    "y_geostationary": np.arange(3712),
+                },
+                attrs=attrs,
+            ),
+            "nonhrv": xr.DataArray(
+                data=np.random.random((1, 11, 3712, 3712)),
+                dims=["time", "variable", "x_geostationary", "y_geostationary"],
+                coords={
+                    "time": [pd.Timestamp("2024-01-01T00:00:00Z")],
+                    "variable": [c.variable for c in dps.CHANNELS["nonhrv"]],
+                    "x_geostationary": np.arange(3712),
+                    "y_geostationary": np.arange(3712),
+                },
+                attrs=attrs,
+            ),
+        }
+
+    def test_download_scans(self) -> None:
+        self.assertGreater(len(self.paths), 0)
+
+    def test_convert_scene_to_dataarray(self) -> None:
+        scene = Scene(filenames={"seviri_l1b_native": [self.paths[0].as_posix()]})
+        scene.load([c.variable for c in dps.CHANNELS["nonhrv"]])
+        da = dps._convert_scene_to_dataarray(
+            scene,
+            band=dps.CHANNELS["nonhrv"][0].variable,
+            area="RSS",
+            calculate_osgb=False,
+        )
+
+        with self.subTest("Returned dataarray is correct shape"):
+            self.assertDictEqual(
+                dict(da.sizes),
+                {"time": 1, "variable": 11, "x_geostationary": 3712, "y_geostationary": 3712},
+            )
+        self.assertIn("end_time", da.attrs)
+
+    def test_rescale(self) -> None:
+        da: xr.DataArray = dps._rescale(self.test_dataarrays["nonhrv"], channels=dps.CHANNELS["nonhrv"])
+
+        self.assertGreater(da.values.max(), 0)
+        self.assertLess(da.values.min(), 1)
+        self.assertEqual(da.attrs, self.test_dataarrays["nonhrv"].attrs)
+
+    def test_open_and_scale_data(self) -> None:
+        ds: xr.Dataset | None = dps._open_and_scale_data([], self.paths[0].as_posix(), "nonhrv")
+
+        if ds is None:
+            self.fail("Dataset is None")
+
+        ds.to_zarr("/tmp/test_sat_data/test.zarr", mode="w", consolidated=True)
+        ds2 = xr.open_zarr("/tmp/test_sat_data/test.zarr")
+        self.assertDictEqual(dict(ds.sizes), dict(ds2.sizes))
+        self.assertNotEqual(dict(ds.attrs), {})
+
+    def test_process_scans(self) -> None:
+        out: str = dps.process_scans(
+            dps.CONFIGS["iodc"],
+            pathlib.Path("/tmp/test_sat_data"),
+            pd.Timestamp("2024-01-01"),
+            pd.Timestamp("2024-01-02"),
+            "nonhrv",
+        )
+
+        self.assertEqual(out, "nonhrv")
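
Note on the hourly batching in process_scans: the append size is now derived from the
configured scan cadence rather than the old hard-coded 12. A minimal sketch of the
arithmetic, assuming the 5-minute cadence implied by the removed "(12 sets of 5 minute
datasets)" comment:

    import pandas as pd

    # One zarr append per hour's worth of scans: a 5-minute cadence
    # yields 12 datasets per append, matching the previous constant.
    scans_per_append = int(pd.Timedelta("1h") / pd.Timedelta("5min"))
    assert scans_per_append == 12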
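
For reference, a hypothetical invocation of the reworked entrypoint, assuming the
positional satellite argument accepts the "iodc" key used in the tests, and that
EUMETSAT_CONSUMER_KEY and EUMETSAT_CONSUMER_SECRET are set in the environment so
that _gen_token succeeds:

    import download_process_sat as dps

    # Parse CLI-style args and run the pipeline; --rm removes the raw
    # .nat files once the zarr stores have been written.
    args = dps.parser.parse_args([
        "iodc",
        "--path", "/mnt/disks/sat",
        "--start_date", "2024-01-01",
        "--end_date", "2024-01-02",
        "--rm",
    ])
    dps.run(args)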