diff --git a/src/anemoi/datasets/data/grids.py b/src/anemoi/datasets/data/grids.py index e8859cdb..15915f13 100644 --- a/src/anemoi/datasets/data/grids.py +++ b/src/anemoi/datasets/data/grids.py @@ -12,6 +12,7 @@ from functools import cached_property import numpy as np +from scipy.spatial import cKDTree from .debug import Node from .debug import debug_indexing @@ -142,95 +143,250 @@ def tree(self): class Cutout(GridsBase): - def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False): - from anemoi.datasets.grids import cutout_mask - + def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None): + """Initializes a Cutout object for hierarchical management of Limited Area + Models (LAMs) and a global dataset, handling overlapping regions. + + Args: + datasets (list): List of LAM and global datasets. + axis (int): Concatenation axis, must be set to 3. + cropping_distance (float): Distance threshold in degrees for + cropping cutouts. + neighbours (int): Number of neighboring points to consider when + constructing masks. + min_distance_km (float, optional): Minimum distance threshold in km + between grid points. + plot (bool, optional): Flag to enable or disable visualization + plots. + """ super().__init__(datasets, axis) - assert len(datasets) == 2, "CutoutGrids requires two datasets" + assert len(datasets) >= 2, "CutoutGrids requires at least two datasets" assert axis == 3, "CutoutGrids requires axis=3" + assert cropping_distance >= 0, "cropping_distance must be a non-negative number" + if min_distance_km is not None: + assert min_distance_km >= 0, "min_distance_km must be a non-negative number" + + self.lams = datasets[:-1] # Assume the last dataset is the global one + self.globe = datasets[-1] + self.axis = axis + self.cropping_distance = cropping_distance + self.neighbours = neighbours + self.min_distance_km = min_distance_km + self.plot = plot + self.masks = [] # To store the masks for each LAM dataset + self.global_mask = np.ones(self.globe.shape[-1], dtype=bool) + + # Initialize cumulative masks + self._initialize_masks() + + def _initialize_masks(self): + """Generates hierarchical masks for each LAM dataset by excluding + overlapping regions with previous LAMs and creating a global mask for + the global dataset. + + Raises: + ValueError: If the global mask dimension does not match the global + dataset grid points. + """ + from anemoi.datasets.grids import cutout_mask - # We assume that the LAM is the first dataset, and the global is the second - # Note: the second fields does not really need to be global - - self.lam, self.globe = datasets - self.mask = cutout_mask( - self.lam.latitudes, - self.lam.longitudes, - self.globe.latitudes, - self.globe.longitudes, - plot=plot, - min_distance_km=min_distance_km, - cropping_distance=cropping_distance, - neighbours=neighbours, - ) - assert len(self.mask) == self.globe.shape[3], ( - len(self.mask), - self.globe.shape[3], - ) + for i, lam in enumerate(self.lams): + assert len(lam.shape) == len( + self.globe.shape + ), "LAMs and global dataset must have the same number of dimensions" + lam_lats = lam.latitudes + lam_lons = lam.longitudes + # Create a mask for the global dataset excluding all LAM points + global_overlap_mask = cutout_mask( + lam.latitudes, + lam.longitudes, + self.globe.latitudes, + self.globe.longitudes, + plot=False, + min_distance_km=self.min_distance_km, + cropping_distance=self.cropping_distance, + neighbours=self.neighbours, + ) + + # Ensure the mask dimensions match the global grid points + if global_overlap_mask.shape[0] != self.globe.shape[-1]: + raise ValueError("Global mask dimension does not match global dataset grid " "points.") + self.global_mask[~global_overlap_mask] = False + + # Create a mask for the LAM datasets hierarchically, excluding + # points from previous LAMs + lam_current_mask = np.ones(lam.shape[-1], dtype=bool) + if i > 0: + for j in range(i): + prev_lam = self.lams[j] + prev_lam_lats = prev_lam.latitudes + prev_lam_lons = prev_lam.longitudes + # Check for overlap by computing distances + if self.has_overlap(prev_lam_lats, prev_lam_lons, lam_lats, lam_lons): + lam_overlap_mask = cutout_mask( + prev_lam_lats, + prev_lam_lons, + lam_lats, + lam_lons, + plot=False, + min_distance_km=self.min_distance_km, + cropping_distance=self.cropping_distance, + neighbours=self.neighbours, + ) + lam_current_mask[~lam_overlap_mask] = False + self.masks.append(lam_current_mask) + + def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0): + """Checks for overlapping points between two sets of latitudes and + longitudes within a specified distance threshold. + + Args: + lats1, lons1 (np.ndarray): Latitude and longitude arrays for the + first dataset. + lats2, lons2 (np.ndarray): Latitude and longitude arrays for the + second dataset. + distance_threshold (float): Distance in degrees to consider as + overlapping. + + Returns: + bool: True if any points overlap within the distance threshold, + otherwise False. + """ + # Create KDTree for the first set of points + tree = cKDTree(np.vstack((lats1, lons1)).T) + + # Query the second set of points against the first tree + distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1) + + # Check if any distance is less than the specified threshold + return np.any(distances < distance_threshold) + + def __getitem__(self, index): + """Retrieves data from the masked LAMs and global dataset based on the + given index. + + Args: + index (int or slice or tuple): Index specifying the data to + retrieve. + + Returns: + np.ndarray: Data array from the masked datasets based on the index. + """ + if isinstance(index, (int, slice)): + index = (index, slice(None), slice(None), slice(None)) + return self._get_tuple(index) + + def _get_tuple(self, index): + """Helper method that applies masks and retrieves data from each dataset + according to the specified index. + + Args: + index (tuple): Index specifying slices to retrieve data. + + Returns: + np.ndarray: Concatenated data array from all datasets based on the + index. + """ + index, changes = index_to_slices(index, self.shape) + # Select data from each LAM + lam_data = [lam[index] for lam in self.lams] + + # First apply spatial indexing on `self.globe` and then apply the mask + globe_data_sliced = self.globe[index[:3]] + globe_data = globe_data_sliced[..., self.global_mask] + + # Concatenate LAM data with global data + result = np.concatenate(lam_data + [globe_data], axis=self.axis) + return apply_index_to_slices_changes(result, changes) def collect_supporting_arrays(self, collected, *path): - collected.append((path, "cutout_mask", self.mask)) + """Collects supporting arrays, including masks for each LAM and the global + dataset. + + Args: + collected (list): List to which the supporting arrays are appended. + *path: Variable length argument list specifying the paths for the masks. + """ + # Append masks for each LAM + for i, (lam, mask) in enumerate(zip(self.lams, self.masks)): + collected.append((path + (f"lam_{i}",), "cutout_mask", mask)) + + # Append the global mask + collected.append((path + ("global",), "cutout_mask", self.global_mask)) @cached_property def shape(self): - shape = self.lam.shape - # Number of non-zero masked values in the globe dataset - nb_globe = np.count_nonzero(self.mask) - return shape[:-1] + (shape[-1] + nb_globe,) + """Returns the shape of the Cutout, accounting for retained grid points + across all LAMs and the global dataset. + + Returns: + tuple: Shape of the concatenated masked datasets. + """ + shapes = [np.sum(mask) for mask in self.masks] + global_shape = np.sum(self.global_mask) + return tuple(self.lams[0].shape[:-1] + (sum(shapes) + global_shape,)) def check_same_resolution(self, d1, d2): # Turned off because we are combining different resolutions pass @property - def latitudes(self): - return np.concatenate([self.lam.latitudes, self.globe.latitudes[self.mask]]) + def grids(self): + """Returns the number of grid points for each LAM and the global dataset + after applying masks. - @property - def longitudes(self): - return np.concatenate([self.lam.longitudes, self.globe.longitudes[self.mask]]) + Returns: + tuple: Count of retained grid points for each dataset. + """ + grids = [np.sum(mask) for mask in self.masks] + grids.append(np.sum(self.global_mask)) + return tuple(grids) - def __getitem__(self, index): - if isinstance(index, (int, slice)): - index = (index, slice(None), slice(None), slice(None)) - return self._get_tuple(index) + @property + def latitudes(self): + """Returns the concatenated latitudes of each LAM and the global dataset + after applying masks. - @debug_indexing - @expand_list_indexing - def _get_tuple(self, index): - assert self.axis >= len(index) or index[self.axis] == slice( - None - ), f"No support for selecting a subset of the 1D values {index} ({self.tree()})" - index, changes = index_to_slices(index, self.shape) + Returns: + np.ndarray: Concatenated latitude array for the masked datasets. + """ + lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)]) - # In case index_to_slices has changed the last slice - index, _ = update_tuple(index, self.axis, slice(None)) + assert ( + len(lam_latitudes) + len(self.globe.latitudes[self.global_mask]) == self.shape[-1] + ), "Mismatch in number of latitudes" - lam_data = self.lam[index] - globe_data = self.globe[index] + latitudes = np.concatenate([lam_latitudes, self.globe.latitudes[self.global_mask]]) + return latitudes - globe_data = globe_data[:, :, :, self.mask] + @property + def longitudes(self): + """Returns the concatenated longitudes of each LAM and the global dataset + after applying masks. - result = np.concatenate([lam_data, globe_data], axis=self.axis) + Returns: + np.ndarray: Concatenated longitude array for the masked datasets. + """ + lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)]) - return apply_index_to_slices_changes(result, changes) + assert ( + len(lam_longitudes) + len(self.globe.longitudes[self.global_mask]) == self.shape[-1] + ), "Mismatch in number of longitudes" - @property - def grids(self): - for d in self.datasets: - if len(d.grids) > 1: - raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs") - shape = self.lam.shape - return (shape[-1], self.shape[-1] - shape[-1]) + longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]]) + return longitudes def tree(self): + """Generates a hierarchical tree structure for the `Cutout` instance and + its associated datasets. + + Returns: + Node: A `Node` object representing the `Cutout` instance as the root + node, with each dataset in `self.datasets` represented as a child + node. + """ return Node(self, [d.tree() for d in self.datasets]) - # def metadata_specific(self): - # return super().metadata_specific( - # mask=serialise_mask(self.mask), - # ) - def grids_factory(args, kwargs): if "ensemble" in kwargs: diff --git a/tools/grids/grids3.yaml b/tools/grids/grids3.yaml new file mode 100644 index 00000000..75f91961 --- /dev/null +++ b/tools/grids/grids3.yaml @@ -0,0 +1,42 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 0.25/0.25 + area: [40, 25, 20, 60] + rotation: [-20, -40] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids4.yaml b/tools/grids/grids4.yaml new file mode 100644 index 00000000..39b72706 --- /dev/null +++ b/tools/grids/grids4.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 0.5/0.5 + area: [30, 90, 10, 120] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids5.yaml b/tools/grids/grids5.yaml new file mode 100644 index 00000000..42aab132 --- /dev/null +++ b/tools/grids/grids5.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 0.2/0.2 + area: [25, 100, 20, 105] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids6.yaml b/tools/grids/grids6.yaml new file mode 100644 index 00000000..641618fd --- /dev/null +++ b/tools/grids/grids6.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 10/10 + area: [90, -40, -40, 180] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z] + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids7.yaml b/tools/grids/grids7.yaml new file mode 100644 index 00000000..6e9cc965 --- /dev/null +++ b/tools/grids/grids7.yaml @@ -0,0 +1,41 @@ +common: + mars_request: &mars_request + expver: "0001" + grid: 2/2 + area: [90, -40, -40, 180] + +dates: + start: 2024-01-01 00:00:00 + end: 2024-01-01 18:00:00 + frequency: 6h + +input: + join: + - mars: + <<: *mars_request + param: [2t, 10u, 10v, lsm] + levtype: sfc + stream: oper + type: an + - mars: + <<: *mars_request + param: [q, t, z]R. E + levtype: pl + level: [50, 100] + stream: oper + type: an + - accumulations: + <<: *mars_request + levtype: sfc + param: [cp, tp] + - forcings: + template: ${input.join.0.mars} + param: + - cos_latitude + - sin_latitude + +output: + order_by: [valid_datetime, param_level, number] + remapping: + param_level: "{param}_{levelist}" + statistics: param_level diff --git a/tools/grids/grids_multilam.ipynb b/tools/grids/grids_multilam.ipynb new file mode 100644 index 00000000..6f7fa9f6 --- /dev/null +++ b/tools/grids/grids_multilam.ipynb @@ -0,0 +1,1060 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from anemoi.datasets import open_dataset\n", + "from anemoi.datasets.data.grids import Cutout" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load the data\n", + "Datasets generated from the grids*.yaml files in tools/grids/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = 'dir_with_your_zarr_data' \n", + "f_global = data_dir + '/grids1.zarr'\n", + "f_lam1 = data_dir + '/grids2.zarr'\n", + "f_lam2 = data_dir + '/grids3.zarr'\n", + "f_lam3 = data_dir + '/grids4.zarr'\n", + "f_lam4 = data_dir + '/grids5.zarr'\n", + "f_lam5 = data_dir + '/grids6.zarr'\n", + "f_lam6 = data_dir + '/grids7.zarr'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "global_dataset = open_dataset(f_global)\n", + "lam_dataset_1 = open_dataset(f_lam1)\n", + "lam_dataset_2 = open_dataset(f_lam2)\n", + "lam_dataset_3 = open_dataset(f_lam3)\n", + "lam_dataset_4 = open_dataset(f_lam4)\n", + "lam_dataset_5 = open_dataset(f_lam5)\n", + "lam_dataset_6 = open_dataset(f_lam6)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# Define and run some tests" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test_cutout_initialization(lam_dataset_1, lam_dataset_2, global_dataset):\n", + " \"\"\"Ensure that the Cutout class correctly initializes with multiple Limited \n", + " Area Models (LAMs) and a global dataset.\"\"\"\n", + " cutout = Cutout(\n", + " [lam_dataset_1, lam_dataset_2, global_dataset], \n", + " axis=3,\n", + " )\n", + " \n", + " assert len(cutout.lams) == 2\n", + " assert cutout.globe is not None\n", + " assert len(cutout.masks) == 2\n", + "\n", + "def test_cutout_mask_generation(lam_dataset, global_dataset):\n", + " \"\"\"\"Ensure that the cutout_mask function correctly generates masks for LAMs \n", + " and excludes overlapping regions.\"\"\"\n", + " cutout = Cutout(\n", + " [lam_dataset, global_dataset], axis=3)\n", + " mask = cutout.masks[0]\n", + " lam = cutout.lams[0]\n", + " \n", + " assert mask is not None\n", + " assert isinstance(mask, np.ndarray)\n", + " assert isinstance(cutout.global_mask, np.ndarray)\n", + " assert mask.shape[-1] == lam.shape[-1]\n", + " assert cutout.global_mask.shape[-1] == global_dataset.shape[-1]\n", + " \n", + " \n", + "def test_cutout_getitem(lam_dataset, global_dataset):\n", + " \"\"\"Verify that the __getitem__ method correctly returns the appropriate \n", + " data when indexing the Cutout object.\"\"\"\n", + " cutout = Cutout([lam_dataset, global_dataset], axis=3)\n", + " \n", + " data = cutout[0, :, :, :]\n", + " expected_shape = cutout.shape[1:]\n", + " assert data is not None\n", + " assert data.shape == expected_shape\n", + " \n", + "def test_latitudes_longitudes_concatenation(lam_dataset_1, lam_dataset_2, global_dataset):\n", + " \"\"\"Ensure that latitudes and longitudes are correctly \n", + " concatenated from all LAMs and the masked global dataset.\"\"\"\n", + " cutout = Cutout(\n", + " [lam_dataset_1, lam_dataset_2, global_dataset], \n", + " axis=3\n", + " )\n", + " \n", + " latitudes = cutout.latitudes\n", + " longitudes = cutout.longitudes\n", + " \n", + " assert latitudes is not None\n", + " assert longitudes is not None\n", + " assert len(latitudes) == cutout.shape[-1]\n", + " assert len(longitudes) == cutout.shape[-1]\n", + " \n", + "def test_overlapping_lams(lam_dataset_1, lam_dataset_2, global_dataset):\n", + " \"\"\"Confirm that overlapping regions between LAMs and the global dataset are \n", + " correctly handled by the masks.\"\"\"\n", + " # lam_dataset_2 has to overlap with lam_dataset_1\n", + " cutout = Cutout(\n", + " [lam_dataset_1, lam_dataset_2, global_dataset], \n", + " axis=3\n", + " )\n", + " \n", + " # Verify that the overlapping region in lam_dataset_2 is excluded\n", + " assert np.count_nonzero(cutout.masks[1] == False) > 0\n", + " \n", + "def test_open_dataset_cutout(lam_dataset_1, global_dataset):\n", + " \"\"\"Ensure that open_dataset(cutout=[...]) works correctly with the new \n", + " Cutout implementation.\"\"\"\n", + " ds = open_dataset(\n", + " cutout=[lam_dataset_1, global_dataset]\n", + " )\n", + "\n", + " assert isinstance(ds, Cutout)\n", + " assert len(ds.lams) == 1\n", + " assert ds.globe is not None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_cutout_initialization(lam_dataset_1, lam_dataset_2, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_cutout_mask_generation(lam_dataset_1, global_dataset)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_cutout_getitem(lam_dataset_1, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_latitudes_longitudes_concatenation(lam_dataset_1, lam_dataset_2, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_overlapping_lams(lam_dataset_1, lam_dataset_2, global_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_open_dataset_cutout(lam_dataset_1, global_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# Plot function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_grid(ds, path, s=0.1, c=\"r\", grids=None, point=None, central_latitude=-20.0, central_longitude=165.0):\n", + " import matplotlib.pyplot as plt\n", + " import cartopy.crs as ccrs\n", + " import numpy as np\n", + "\n", + " lats, lons = ds.latitudes, ds.longitudes\n", + "\n", + " fig = plt.figure(figsize=(9, 9))\n", + " proj = ccrs.NearsidePerspective(\n", + " central_latitude=central_latitude, \n", + " central_longitude=central_longitude, \n", + " satellite_height=4e6\n", + " )\n", + "\n", + " ax = plt.axes(projection=proj)\n", + "\n", + " def fill():\n", + " # Make sure we have a full globe\n", + " lons, lats = np.meshgrid(np.arange(-180, 180, 1), np.arange(-90, 90, 1))\n", + " x, y, _ = proj.transform_points(\n", + " ccrs.PlateCarree(), lons.flatten(), lats.flatten()\n", + " ).T\n", + "\n", + " mask = np.invert(np.logical_or(np.isinf(x), np.isinf(y)))\n", + " x = np.compress(mask, x)\n", + " y = np.compress(mask, y)\n", + "\n", + " # ax.tricontourf(x, y, values)\n", + " ax.scatter(x, y, s=0, c=\"w\")\n", + "\n", + " fill()\n", + "\n", + " def plot(what, s, c):\n", + " x, y, _ = proj.transform_points(ccrs.PlateCarree(), lons[what], lats[what]).T\n", + "\n", + " mask = np.invert(np.logical_or(np.isinf(x), np.isinf(y)))\n", + " x = np.compress(mask, x)\n", + " y = np.compress(mask, y)\n", + "\n", + " # ax.tricontourf(x, y, values)\n", + " ax.scatter(x, y, s=s, c=c)\n", + "\n", + " if grids:\n", + " #print('s: ', s)\n", + " a = 0\n", + " for i, b in enumerate(grids):\n", + " if s[i] is not None:\n", + " plot(slice(a, a + b), s[i], c[i])\n", + " a += b\n", + " else:\n", + " plot(..., s, c)\n", + "\n", + " if point:\n", + " point = np.array(point, dtype=np.float64)\n", + " x, y, _ = proj.transform_points(ccrs.PlateCarree(), point[1], point[0]).T\n", + " ax.scatter(x, y, s=100, c=\"k\")\n", + "\n", + " ax.coastlines()\n", + "\n", + " if isinstance(path, str):\n", + " fig.savefig(path, bbox_inches=\"tight\")\n", + " else:\n", + " for p in path:\n", + " fig.savefig(p, bbox_inches=\"tight\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 1) Plot the datasets separately" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " global_dataset, \n", + " \"global_grids1.png\", \n", + " central_latitude=20.0, \n", + " central_longitude=75.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_1, \n", + " \"lam1_grids2.png\", \n", + " central_latitude=60.0, \n", + " central_longitude=15.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_2, \n", + " \"lam1_grids3.png\", \n", + " central_latitude=50.0, \n", + " central_longitude=75.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_3, \n", + " \"lam1_grids4.png\", \n", + " central_latitude=20.0, \n", + " central_longitude=105.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_4, \n", + " \"lam1_grids5.png\", \n", + " central_latitude=20.0, \n", + " central_longitude=105.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_5, \n", + " \"lam5_grids6.png\", \n", + " central_latitude=-20.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " lam_dataset_6, \n", + " \"lam6_grids7.png\", \n", + " central_latitude=-20.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 2) Test cutout with one LAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_1, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam1.png\",\n", + " s=[0.5, 0.5],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0, \n", + " central_longitude=15.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3) a) Test two overlapping LAMs\n", + "The LAMs have different resolution and are rotated" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_2, lam_dataset_1, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds.grids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam1_lam2.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0, \n", + " central_longitude=65.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## 3) b) The same LAMs but in a different order" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_1, lam_dataset_2, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds.grids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam2_lam1.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0, \n", + " central_longitude=65.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 4) Test two LAMS that are not overlapping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_3, lam_dataset_2, global_dataset]) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam3_lam2.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=40.0, \n", + " central_longitude=95.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5) Test multiple LAMS \n", + "\n", + "- LAMs with different resolutions\n", + "- Rotated LAMs\n", + "- LAMs with no overlap.\n", + "- LAM contained within other LAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " lam_dataset_4, \n", + " lam_dataset_3, \n", + " lam_dataset_2, \n", + " lam_dataset_1, \n", + " global_dataset\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " #c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\"],\n", + " central_latitude=50.0, \n", + " central_longitude=95.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# Test small LAM behind bigger LAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_3, lam_dataset_4, global_dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam3_lam4.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0, \n", + " central_longitude=95.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 6 a) Test cutout with a coarser resolution LAM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using LAMs with very low resolution can be a challenge, depending on how it compares to the resolution of the global dataset and the other LAMs.\n", + "\n", + "TODO: A future implementation could consider a list of `min_distance_km` and `neighbours`, so that there is value one for each LAM." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using default values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(cutout=[lam_dataset_5, global_dataset])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=-30.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Has some issues when using default parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) a) i) Test the parameter `min_distance_km`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[lam_dataset_5, global_dataset], \n", + " min_distance_km=600\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=-30.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) a) ii) Test the parameter `neighbours`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[lam_dataset_5, global_dataset], \n", + " neighbours=200\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=-30.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam5.png\",\n", + " s=[0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\"],\n", + " central_latitude=50.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6) b) A similar example, where the LAM resolution is not so low" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) b) i) Test the parameter `min_distance_km`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " lam_dataset_4, \n", + " lam_dataset_3, \n", + " lam_dataset_2, \n", + " lam_dataset_1, \n", + " lam_dataset_6, \n", + " global_dataset\n", + " ], \n", + " min_distance_km=200\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=-30.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=50.0, \n", + " central_longitude=95.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6) b) ii) Test the parameter `neighbours`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " lam_dataset_4, \n", + " lam_dataset_3, \n", + " lam_dataset_2, \n", + " lam_dataset_1, \n", + " lam_dataset_6, \n", + " global_dataset\n", + " ], \n", + " neighbours=10\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=-30.0, \n", + " central_longitude=165.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam4_lam3_lam2_lam1_lam6.png\",\n", + " s=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\", \"c\", \"y\", \"k\"],\n", + " central_latitude=50.0, \n", + " central_longitude=95.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 7) Test thinning with cutout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " {\"dataset\": lam_dataset_2, \"thinning\": 2}, \n", + " {\"dataset\": lam_dataset_1, \"thinning\": 8}, \n", + " {\"dataset\": global_dataset, \"thinning\": 2}\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_thinning2_global_lam2_lam1.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0, \n", + " central_longitude=65.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "# 8) Test cropping with cutout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds = open_dataset(\n", + " cutout=[\n", + " {\"dataset\": lam_dataset_1, \"area\": (60, 0, 20, 80)}, \n", + " {\"dataset\": lam_dataset_2}, \n", + " {\"dataset\": global_dataset}\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_grid(\n", + " ds,\n", + " \"cutout_global_lam1cropped_lam2.png\",\n", + " s=[0.1, 0.1, 0.1],\n", + " grids=ds.grids,\n", + " c=[\"g\", \"r\", \"b\"],\n", + " central_latitude=50.0, \n", + " central_longitude=65.0\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}