diff --git a/examples/vector.ipynb b/examples/vector.ipynb index 21e7a22..daba302 100644 --- a/examples/vector.ipynb +++ b/examples/vector.ipynb @@ -7,8 +7,11 @@ "metadata": {}, "outputs": [], "source": [ + "from functools import partial\n", + "from geocube.rasterize import rasterize_image\n", + "from rasterio.enums import MergeAlg\n", "import geopandas as gpd\n", - "from ipyleaflet import LayersControl, Map, WidgetControl, basemaps\n", + "from ipyleaflet import LocalTileLayer, LayersControl, Map, WidgetControl, basemaps\n", "from ipywidgets import FloatSlider\n", "import xarray_leaflet\n", "import matplotlib.pyplot as plt" @@ -32,8 +35,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = gpd.read_file(\"bldg_footprints.shp\")\n", - "df[\"mask\"] = 1" + "df = gpd.read_file(\"bldg_footprints.shp\")" ] }, { @@ -54,7 +56,9 @@ "metadata": {}, "outputs": [], "source": [ - "l = df.leaflet.plot(m, measurement=\"mask\", colormap=plt.cm.inferno)" + "rasterize_function = partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=False)\n", + "layer = partial(LocalTileLayer, max_zoom=20)\n", + "l = df.leaflet.plot(m, measurement=\"Height\", layer=layer, dynamic=False, rasterize_function=rasterize_function, colormap=plt.cm.viridis)" ] }, { @@ -94,7 +98,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.5" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/setup.cfg b/setup.cfg index 3c3751a..f96366e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,13 +23,13 @@ install_requires = jupyter_server >=0.2.0 rioxarray >=0.0.30 ipyleaflet >=0.13.1 + ipywidgets >=7.7.2 pillow >=7 matplotlib >=3 affine >=2 mercantile >=1 ipyspin >=0.1.6 ipyurl >=0.1.3 - jupyterlab-widgets >=1.0.0,<2 geocube <1.0.0 pygeos >=0.12,<1.0.0 zarr >=2.0.0,<3.0.0 diff --git a/xarray_leaflet/vector.py b/xarray_leaflet/vector.py index 5c801a3..468a022 100644 --- a/xarray_leaflet/vector.py +++ b/xarray_leaflet/vector.py @@ -1,11 +1,11 @@ import json +import math from functools import partial from pathlib import Path -from typing import Optional +from typing import Callable, Optional import mercantile import numpy as np -import pyproj import xarray as xr import zarr from geocube.api.core import make_geocube @@ -23,6 +23,7 @@ def __init__( self, df: GeoDataFrame, measurement: str, + rasterize_function: Optional[Callable], width: int, height: int, root_path: str = "", @@ -30,6 +31,9 @@ def __init__( # reproject to Web Mercator self.df = df.to_crs(epsg=3857) self.measurement = measurement + self.rasterize_function = rasterize_function or partial( + rasterize_image, merge_alg=MergeAlg.add, all_touched=True + ) self.width = width self.height = height self.zzarr = Zzarr(root_path, width, height) @@ -56,9 +60,7 @@ def get_da_tile(self, tile: mercantile.Tile) -> Optional[xr.DataArray]: vector_data=df_tile, resolution=(-dy, dx), measurements=[self.measurement], - rasterize_function=partial( - rasterize_image, merge_alg=MergeAlg.add, all_touched=True - ), + rasterize_function=self.rasterize_function, fill=0, geom=geom, ) @@ -82,15 +84,10 @@ def get_da_llbbox( self.tiles.append(tile) if all_none: return None - project = pyproj.Transformer.from_crs( - pyproj.CRS("EPSG:4326"), pyproj.CRS("EPSG:3857"), always_xy=True - ).transform - b = box(*bbox) - polygon = transform(project, b) - left, bottom, right, top = polygon.bounds - return self.zzarr.get_ds(z)["da"].sel( - x=slice(left, right), y=slice(top, bottom) - ) + da = self.get_da(z) + y0, x0 = deg2idx(bbox.north, bbox.west, z, self.height, self.width, math.floor) + y1, x1 = deg2idx(bbox.south, bbox.east, z, self.height, self.width, math.ceil) + return da[y0:y1, x0:x1] def get_da(self, z: int) -> xr.DataArray: return self.zzarr.get_ds(z)["da"] @@ -101,7 +98,7 @@ def __init__(self, root_path: str, width: int, height: int): self.root_path = Path(root_path) self.width = width self.height = height - self.ds = {} + self.z = None def open_zarr(self, mode: str, z: int) -> zarr.Array: path = self.root_path / str(z) @@ -114,32 +111,6 @@ def open_zarr(self, mode: str, z: int) -> zarr.Array: ) if mode == "w": # write Dataset to zarr - mi, ma = mercantile.minmax(z) - ul = mercantile.xy_bounds(mi, mi, z) - lr = mercantile.xy_bounds(ma, ma, z) - bbox = mercantile.Bbox(ul.left, lr.bottom, lr.right, ul.top) - x = zarr.open( - path / "x", - mode="w", - shape=(2**z * self.width,), - chunks=(2**z * self.width,), - dtype=" zarr.Array: ".zgroup": zgroup, "da/.zarray": zarray, "da/.zattrs": zattrs, - "x/.zarray": x_zarray, - "x/.zattrs": x_zattrs, - "y/.zarray": y_zarray, - "y/.zattrs": y_zattrs, }, zarr_consolidated_format=1, ) @@ -172,14 +139,23 @@ def write_to_zarr(self, tile: mercantile.Tile, data: np.ndarray): mode = "a" else: mode = "w" - self.array = self.open_zarr(mode, z) - self.array[ + array = self.open_zarr(mode, z) + array[ y * self.height : (y + 1) * self.height, # noqa x * self.width : (x + 1) * self.width, # noqa ] = data def get_ds(self, z: int) -> xr.Dataset: path = self.root_path / str(z) - if z not in self.ds: - self.ds[z] = xr.open_zarr(path) - return self.ds[z] + if z != self.z: + self.ds_z = xr.open_zarr(path) + self.z = z + return self.ds_z + + +def deg2idx(lat_deg, lon_deg, zoom, height, width, round_fun): + lat_rad = math.radians(lat_deg) + n = 2**zoom + xtile = round_fun(((lon_deg + 180) % 360) / 360 * n * width) + ytile = round_fun((1 - math.asinh(math.tan(lat_rad)) / math.pi) / 2 * n * height) + return ytile, xtile diff --git a/xarray_leaflet/xarray_leaflet.py b/xarray_leaflet/xarray_leaflet.py index 8208a2f..8de570e 100644 --- a/xarray_leaflet/xarray_leaflet.py +++ b/xarray_leaflet/xarray_leaflet.py @@ -44,6 +44,16 @@ def _map_ready_changed(self, change): def plot( self, m, + *, + # raster or vector options: + get_base_url: Optional[Callable] = None, + dynamic: Optional[bool] = None, + persist: bool = True, + tile_dir=None, + tile_height: int = 256, + tile_width: int = 256, + layer: Callable = LocalTileLayer, + # raster-only options: x_dim="x", y_dim="y", fit_bounds=True, @@ -54,15 +64,11 @@ def plot( transform3=passthrough, colormap=None, colorbar_position="topright", - persist=True, - dynamic=False, - tile_dir=None, - tile_height=256, - tile_width=256, resampling=Resampling.nearest, - get_base_url=None, + # vector-only options: measurement: Optional[str] = None, visible_callback: Optional[Callable] = None, + rasterize_function: Optional[Callable] = None, ): """Display an array as an interactive map. @@ -122,6 +128,9 @@ def plot( - the mercantile.LngLatBbox of the visible region and returning True if the layer should be shown, False otherwise. + rasterize_function: callable, optional + A callable passed to make_geocube. Defaults to: + partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=True) Returns ------- @@ -129,23 +138,34 @@ def plot( A handler to the layer that is added to the map. """ - self.layer = LocalTileLayer() + self.layer = layer() if self.is_vector: # source is a GeoDataFrame (vector) self.visible_callback = visible_callback if measurement is None: raise RuntimeError("You must provide a 'measurement'.") + if dynamic is None: + dynamic = True + if not dynamic: + self.vmin = self._df[measurement].min() + self.vmax = self._df[measurement].max() self.measurement = measurement - dynamic = True zarr_temp_dir = tempfile.TemporaryDirectory(prefix="xarray_leaflet_zarr_") self.zvect = Zvect( - self._df, measurement, tile_width, tile_height, zarr_temp_dir.name + self._df, + measurement, + rasterize_function, + tile_width, + tile_height, + zarr_temp_dir.name, ) if colormap is None: colormap = plt.cm.viridis else: # source is a DataArray (raster) + if dynamic is None: + dynamic = False if "proj4def" in m.crs: # it's a custom projection if dynamic: @@ -363,6 +383,7 @@ def _get_vector_tiles(self, change=None): tiles = mercantile.tiles(west, south, east, north, z) if self.dynamic: + # get DataArray for the visible map llbbox = mercantile.LngLatBbox(west, south, east, north) da_visible = self.zvect.get_da_llbbox(llbbox, z) # check if we must show the layer @@ -372,32 +393,42 @@ def _get_vector_tiles(self, change=None): self.m.remove_control(self.spinner_control) return if da_visible is None: - self.max_value = 0 + vmin = vmax = 0 else: - self.max_value = da_visible.max() + vmin = da_visible.min() + vmax = da_visible.max() + else: + vmin = self.vmin + vmax = self.vmax + da_visible_computed = False for tile in tiles: x, y, z = tile path = f"{self.tile_path}/{z}/{x}/{y}.png" if self.dynamic or not os.path.exists(path): - xy_bbox = mercantile.xy_bounds(tile) - if self.dynamic: - if da_visible is not None: - da_tile = self.zvect.get_da(z).sel( - y=slice(xy_bbox.top, xy_bbox.bottom), - x=slice(xy_bbox.left, xy_bbox.right), - ) - else: - da_tile = None + if not self.dynamic and not da_visible_computed: + # get DataArray for the visible map + llbbox = mercantile.LngLatBbox(west, south, east, north) + da_visible = self.zvect.get_da_llbbox(llbbox, z) + da_visible_computed = True + if self.dynamic and da_visible is None: + da_tile = None + else: + da_tile = self.zvect.get_da(z)[ + y * self.tile_height : (y + 1) * self.tile_height, + x * self.tile_width : (x + 1) * self.tile_width, + ] if da_tile is None: write_image(path, None) else: - da_tile /= self.max_value + # normalize + da_tile = (da_tile - vmin) / (vmax - vmin) da_tile = self.colormap(da_tile) write_image(path, da_tile * 255) if self.dynamic: self.layer.redraw() + self.m.remove_control(self.spinner_control) def _get_raster_tiles(self, change=None):