diff --git a/CHANGELOG.md b/CHANGELOG.md index d2701e09..6c9c92bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - SDAP-473: Added support for matchup job prioritization - SDAP-483: Added `.asf.yaml` to configure Jira auto-linking. - SDAP-487: Added script to migrate existing `doms.doms_data` data to new schema. +- SDAP-440: Added `CassandraSwathProxy` to process tiles in a manner optimized for swath format. +- SDAP-440: Set up framework to roll out changes to SDAP algorithms to work with swath formatted data for both tile types rather than having tiles formatted as gridded, which is very memory inefficient. ### Changed - SDAP-453: Updated results storage and retrieval to support output JSON from `/cdmsresults` that matches output from `/match_spark`. - **NOTE:** Deploying these changes to an existing SDAP deployment will require modifying the Cassandra database with stored results. There is a script to do so at `/tools/update-doms-data-schema/update.py` diff --git a/analysis/webservice/algorithms/NexusCalcHandler.py b/analysis/webservice/algorithms/NexusCalcHandler.py index 13edd0c8..2ba6d893 100644 --- a/analysis/webservice/algorithms/NexusCalcHandler.py +++ b/analysis/webservice/algorithms/NexusCalcHandler.py @@ -39,7 +39,10 @@ def validate(cls): def __init__(self, tile_service_factory, **kwargs): self._tile_service_factory = tile_service_factory - self._tile_service = tile_service_factory() + if 'desired_projection' in kwargs: + self._tile_service = tile_service_factory(desired_projection=kwargs['desired_projection']) + else: + self._tile_service = tile_service_factory() def _get_tile_service(self): return self._tile_service diff --git a/analysis/webservice/algorithms/doms/BaseDomsHandler.py b/analysis/webservice/algorithms/doms/BaseDomsHandler.py index 84c91633..4ebee3de 100644 --- a/analysis/webservice/algorithms/doms/BaseDomsHandler.py +++ 
b/analysis/webservice/algorithms/doms/BaseDomsHandler.py @@ -48,8 +48,8 @@ class BaseDomsQueryCalcHandler(NexusCalcHandler): - def __init__(self, tile_service_factory): - NexusCalcHandler.__init__(self, tile_service_factory) + def __init__(self, tile_service_factory, **kwargs): + NexusCalcHandler.__init__(self, tile_service_factory, **kwargs) def getDataSourceByName(self, source): for s in config.ENDPOINTS: diff --git a/data-access/nexustiles/dao/CassandraProxy.py b/data-access/nexustiles/dao/CassandraProxy.py index 96f7c4c6..ae6e77ad 100644 --- a/data-access/nexustiles/dao/CassandraProxy.py +++ b/data-access/nexustiles/dao/CassandraProxy.py @@ -31,6 +31,7 @@ logger = logging.getLogger(__name__) + class NexusTileData(Model): __table_name__ = 'sea_surface_temp' tile_id = columns.UUID(primary_key=True) @@ -53,7 +54,7 @@ def get_raw_data_array(self): return from_shaped_array(the_tile_data.variable_data) - def get_lat_lon_time_data_meta(self): + def get_lat_lon_time_data_meta(self, projection='grid'): """ Retrieve data from data store and metadata from metadata store for this tile. For gridded tiles, the tile shape of the data diff --git a/data-access/nexustiles/dao/CassandraSwathProxy.py b/data-access/nexustiles/dao/CassandraSwathProxy.py new file mode 100644 index 00000000..dcfbc47e --- /dev/null +++ b/data-access/nexustiles/dao/CassandraSwathProxy.py @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import uuid +from configparser import NoOptionError + +import nexusproto.DataTile_pb2 as nexusproto +import numpy as np +from cassandra.auth import PlainTextAuthProvider +from cassandra.cqlengine import columns, connection, CQLEngineException +from cassandra.cluster import NoHostAvailable +from cassandra.cqlengine.models import Model +from cassandra.policies import TokenAwarePolicy, DCAwareRoundRobinPolicy, WhiteListRoundRobinPolicy +from multiprocessing.synchronize import Lock +from nexusproto.serialization import from_shaped_array + +INIT_LOCK = Lock(ctx=None) + +logger = logging.getLogger(__name__) + + +class NexusTileData(Model): + __table_name__ = 'sea_surface_temp' + tile_id = columns.UUID(primary_key=True) + tile_blob = columns.Blob() + + __nexus_tile = None + + def _get_nexus_tile(self): + if self.__nexus_tile is None: + self.__nexus_tile = nexusproto.TileData.FromString(self.tile_blob) + + return self.__nexus_tile + + def get_raw_data_array(self): + + nexus_tile = self._get_nexus_tile() + the_tile_type = nexus_tile.tile.WhichOneof("tile_type") + + the_tile_data = getattr(nexus_tile.tile, the_tile_type) + + return from_shaped_array(the_tile_data.variable_data) + + def get_lat_lon_time_data_meta(self, projection='grid'): + """ + Retrieve data from data store and metadata from metadata store + for this tile. For swath tiles, the tile shape of the data + will match the input shape. For example, if the input was a + 30x30 tile, all variables will also be 30x30. 
However, if the + tile is a gridded tile, the lat and lon arrays will be reflected + into 2 dimensions to reflect how the equivalent data would be + represented in swath format. + + Multi-variable tile will also include an extra dimension in the + data array. For example, a 30 x 30 x 30 array would be + transformed to N x 30 x 30 x 30 where N is the number of + variables in this tile. + + latitude_data, longitude_data, np.array([grid_tile.time]), grid_tile_data, meta_data, is_multi_var + + :return: latitude data + :return: longitude data + :return: time data + :return: data + :return: meta data dictionary + :return: boolean flag, True if this tile has more than one variable + """ + is_multi_var = False + + if self._get_nexus_tile().HasField('grid_tile'): + grid_tile = self._get_nexus_tile().grid_tile + + grid_tile_data = np.ma.masked_invalid(from_shaped_array(grid_tile.variable_data)) + latitude_data = np.ma.masked_invalid(from_shaped_array(grid_tile.latitude)) + longitude_data = np.ma.masked_invalid(from_shaped_array(grid_tile.longitude)) + + reflected_lon_array = np.broadcast_to(longitude_data, (len(latitude_data), len(longitude_data))) + reflected_lat_array = np.broadcast_to(latitude_data, (len(longitude_data), len(latitude_data))) + reflected_lat_array = np.transpose(reflected_lat_array) + + time_array = np.broadcast_to(grid_tile.time, grid_tile_data.shape) + + # if len(grid_tile_data.shape) == 2: + # grid_tile_data = grid_tile_data[np.newaxis, :] + + # Extract the meta data + meta_data = {} + for meta_data_obj in grid_tile.meta_data: + name = meta_data_obj.name + meta_array = np.ma.masked_invalid(from_shaped_array(meta_data_obj.meta_data)) + if len(meta_array.shape) == 2: + meta_array = meta_array[np.newaxis, :] + meta_data[name] = meta_array + + return reflected_lat_array, reflected_lon_array, time_array, grid_tile_data, meta_data, is_multi_var + elif self._get_nexus_tile().HasField('swath_tile'): + swath_tile = self._get_nexus_tile().swath_tile + + 
latitude_data = np.ma.masked_invalid(from_shaped_array(swath_tile.latitude)) + longitude_data = np.ma.masked_invalid(from_shaped_array(swath_tile.longitude)) + time_data = np.ma.masked_invalid(from_shaped_array(swath_tile.time)) + + # Simplify the tile if the time dimension is the same value repeated + # if np.all(time_data == np.min(time_data)): + # time_data = np.array([np.min(time_data)]) + + swath_tile_data = np.ma.masked_invalid(from_shaped_array(swath_tile.variable_data)) + tile_data = swath_tile_data + + # Extract the metadata + meta_data = {} + for meta_data_obj in swath_tile.meta_data: + name = meta_data_obj.name + actual_meta_array = np.ma.masked_invalid(from_shaped_array(meta_data_obj.meta_data)) + meta_data[name] = actual_meta_array + + return latitude_data, longitude_data, time_data, tile_data, meta_data, is_multi_var + # TODO: Do we use this? + # elif self._get_nexus_tile().HasField('time_series_tile'): + # time_series_tile = self._get_nexus_tile().time_series_tile + # + # time_series_tile_data = np.ma.masked_invalid(from_shaped_array(time_series_tile.variable_data)) + # time_data = np.ma.masked_invalid(from_shaped_array(time_series_tile.time)).reshape(-1) + # latitude_data = np.ma.masked_invalid(from_shaped_array(time_series_tile.latitude)) + # longitude_data = np.ma.masked_invalid(from_shaped_array(time_series_tile.longitude)) + # + # reshaped_array = np.ma.masked_all((len(time_data), len(latitude_data), len(longitude_data))) + # idx = np.arange(len(latitude_data)) + # reshaped_array[:, idx, idx] = time_series_tile_data + # tile_data = reshaped_array + # # Extract the meta data + # meta_data = {} + # for meta_data_obj in time_series_tile.meta_data: + # name = meta_data_obj.name + # meta_array = np.ma.masked_invalid(from_shaped_array(meta_data_obj.meta_data)) + # + # reshaped_meta_array = np.ma.masked_all((len(time_data), len(latitude_data), len(longitude_data))) + # idx = np.arange(len(latitude_data)) + # reshaped_meta_array[:, idx, idx] = 
meta_array + # + # meta_data[name] = reshaped_meta_array + # + # return latitude_data, longitude_data, time_data, tile_data, meta_data, is_multi_var + elif self._get_nexus_tile().HasField('swath_multi_variable_tile'): + swath_tile = self._get_nexus_tile().swath_multi_variable_tile + is_multi_var = True + + latitude_data = np.ma.masked_invalid(from_shaped_array(swath_tile.latitude)) + longitude_data = np.ma.masked_invalid(from_shaped_array(swath_tile.longitude)) + time_data = np.ma.masked_invalid(from_shaped_array(swath_tile.time)) + + # Simplify the tile if the time dimension is the same value repeated + # if np.all(time_data == np.min(time_data)): + # time_data = np.array([np.min(time_data)]) + + swath_tile_data = np.ma.masked_invalid(from_shaped_array(swath_tile.variable_data)) + swath_tile_data = np.moveaxis(swath_tile_data, -1, 0) + + tile_data = [] + + for variable_array in swath_tile_data: + tile_data.append(variable_array) + + # Extract the metadata + + meta_data = {} + for meta_data_obj in swath_tile.meta_data: + name = meta_data_obj.name + actual_meta_array = np.ma.masked_invalid(from_shaped_array(meta_data_obj.meta_data)) + meta_data[name] = actual_meta_array + + return latitude_data, longitude_data, time_data, np.ma.array(tile_data), meta_data, is_multi_var + elif self._get_nexus_tile().HasField('grid_multi_variable_tile'): + grid_multi_variable_tile = self._get_nexus_tile().grid_multi_variable_tile + is_multi_var = True + + grid_tile_data = np.ma.masked_invalid(from_shaped_array(grid_multi_variable_tile.variable_data)) + latitude_data = np.ma.masked_invalid(from_shaped_array(grid_multi_variable_tile.latitude)) + longitude_data = np.ma.masked_invalid(from_shaped_array(grid_multi_variable_tile.longitude)) + + reflected_lon_array = np.broadcast_to(longitude_data, (len(latitude_data), len(longitude_data))) + reflected_lat_array = np.broadcast_to(latitude_data, (len(longitude_data), len(latitude_data))) + reflected_lat_array = 
np.transpose(reflected_lat_array) + + reflected_lat_array = np.expand_dims(reflected_lat_array, axis=0) + reflected_lon_array = np.expand_dims(reflected_lon_array, axis=0) + + time = np.array([grid_multi_variable_tile.time]) + reflected_time_array = np.broadcast_to(time, (len(latitude_data), len(longitude_data))) + + reflected_time_array = np.expand_dims(reflected_time_array, axis=0) + + # If there are 3 dimensions, that means the time dimension + # was squeezed. Add back in + if len(grid_tile_data.shape) == 3: + grid_tile_data = np.expand_dims(grid_tile_data, axis=0) + # If there are 4 dimensions, that means the time dimension + # is present. Move the multivar dimension. + if len(grid_tile_data.shape) == 4: + grid_tile_data = np.moveaxis(grid_tile_data, -1, 0) + + # Extract the meta data + meta_data = {} + for meta_data_obj in grid_multi_variable_tile.meta_data: + name = meta_data_obj.name + meta_array = np.ma.masked_invalid(from_shaped_array(meta_data_obj.meta_data)) + if len(meta_array.shape) == 2: + meta_array = meta_array[np.newaxis, :] + meta_data[name] = meta_array + + return reflected_lat_array, reflected_lon_array, reflected_time_array, grid_tile_data, meta_data, is_multi_var + else: + raise NotImplementedError("Only supports grid_tile, swath_tile, swath_multi_variable_tile, and time_series_tile") + + +class CassandraSwathProxy(object): + def __init__(self, config): + self.config = config + self.__cass_url = config.get("cassandra", "host") + self.__cass_username = config.get("cassandra", "username") + self.__cass_password = config.get("cassandra", "password") + self.__cass_keyspace = config.get("cassandra", "keyspace") + self.__cass_local_DC = config.get("cassandra", "local_datacenter") + self.__cass_protocol_version = config.getint("cassandra", "protocol_version") + self.__cass_dc_policy = config.get("cassandra", "dc_policy") + + try: + self.__cass_port = config.getint("cassandra", "port") + except NoOptionError: + self.__cass_port = 9042 + + with 
INIT_LOCK: + try: + connection.get_cluster() + except CQLEngineException: + self.__open() + + def __open(self): + if self.__cass_dc_policy == 'DCAwareRoundRobinPolicy': + dc_policy = DCAwareRoundRobinPolicy(self.__cass_local_DC) + token_policy = TokenAwarePolicy(dc_policy) + elif self.__cass_dc_policy == 'WhiteListRoundRobinPolicy': + token_policy = WhiteListRoundRobinPolicy([self.__cass_url]) + + if self.__cass_username and self.__cass_password: + auth_provider = PlainTextAuthProvider(username=self.__cass_username, password=self.__cass_password) + else: + auth_provider = None + try: + connection.setup( + [host for host in self.__cass_url.split(',')], self.__cass_keyspace, + protocol_version=self.__cass_protocol_version, load_balancing_policy=token_policy, + port=self.__cass_port, + auth_provider=auth_provider + ) + except NoHostAvailable as e: + logger.error("Cassandra is not accessible, SDAP will not serve local datasets", e) + + def fetch_nexus_tiles(self, *tile_ids): + tile_ids = [uuid.UUID(str(tile_id)) for tile_id in tile_ids if + (isinstance(tile_id, str) or isinstance(tile_id, str))] + + res = [] + for tile_id in tile_ids: + filterResults = NexusTileData.objects.filter(tile_id=tile_id) + if len(filterResults) > 0: + res.append(filterResults[0]) + + return res diff --git a/data-access/nexustiles/model/nexusmodel.py b/data-access/nexustiles/model/nexusmodel.py index 7db4d614..a5f67058 100644 --- a/data-access/nexustiles/model/nexusmodel.py +++ b/data-access/nexustiles/model/nexusmodel.py @@ -91,6 +91,7 @@ class Tile(object): data: np.array = None is_multi: bool = None meta_data: dict = None + projection: str = 'grid' def __str__(self): return str(self.get_summary()) @@ -128,27 +129,53 @@ def get_summary(self): def nexus_point_generator(self, include_nan=False): indices = self.get_indices(include_nan) + if len(indices) == 0 or (isinstance(indices, np.ndarray) and indices.size == 0): + return + if self.projection == 'grid': + lat_slice = slice(1, 2) + 
lon_slice = slice(2, 3) + time_slice = slice(0, 1) + else: + lat_slice = slice(None) + lon_slice = slice(None) + time_slice = slice(None) + if include_nan: for index in indices: - time = self.times[index[0]] - lat = self.latitudes[index[1]] - lon = self.longitudes[index[2]] + time = self.times[index[time_slice]] + lat = self.latitudes[index[lat_slice]] + lon = self.longitudes[index[lon_slice]] + if self.is_multi: - data_vals = [data[index] for data in self.data] + data_vals = [] + + for data in self.data: + val = data[index] + + data_vals.append(val if val is not np.ma.masked else np.nan) else: data_vals = self.data[index] + point = NexusPoint(lat, lon, None, time, index, data_vals) yield point else: for index in indices: index = tuple(index) - time = self.times[index[0]] - lat = self.latitudes[index[1]] - lon = self.longitudes[index[2]] + + time = self.times[index[time_slice]] + lat = self.latitudes[index[lat_slice]] + lon = self.longitudes[index[lon_slice]] + if self.is_multi: - data_vals = [data[index] for data in self.data] + data_vals = [] + + for data in self.data: + val = data[index] + + data_vals.append(val if val is not np.ma.masked else np.nan) else: data_vals = self.data[index] + point = NexusPoint(lat, lon, None, time, index, data_vals) yield point @@ -156,7 +183,7 @@ def get_indices(self, include_nan=False): if include_nan: return list(np.ndindex(self.data.shape)) if self.is_multi: - combined_data_inv_mask = reduce(np.logical_and, [data.mask for data in self.data]) + combined_data_inv_mask = reduce(np.logical_and, [np.ma.getmaskarray(data) for data in self.data]) return np.argwhere(np.logical_not(combined_data_inv_mask)) else: return np.transpose(np.where(np.ma.getmaskarray(self.data) == False)).tolist() diff --git a/data-access/nexustiles/nexustiles.py b/data-access/nexustiles/nexustiles.py index a3aa61e9..9dd3bf80 100644 --- a/data-access/nexustiles/nexustiles.py +++ b/data-access/nexustiles/nexustiles.py @@ -14,9 +14,9 @@ # limitations under the 
License. import configparser +import json import logging import sys -import json from datetime import datetime from functools import wraps, reduce @@ -27,11 +27,11 @@ from shapely.geometry import MultiPolygon, box from .dao import CassandraProxy +from .dao import CassandraSwathProxy from .dao import DynamoProxy +from .dao import ElasticsearchProxy from .dao import S3Proxy from .dao import SolrProxy -from .dao import ElasticsearchProxy - from .model.nexusmodel import Tile, BBox, TileStats, TileVariable EPOCH = timezone('UTC').localize(datetime(1970, 1, 1)) @@ -40,7 +40,7 @@ level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt="%Y-%m-%dT%H:%M:%S", stream=sys.stdout) -logger = logging.getLogger("testing") +logger = logging.getLogger(__name__) def tile_data(default_fetch=True): @@ -56,7 +56,7 @@ def fetch_data_for_func(*args, **kwargs): if ('fetch_data' in kwargs and kwargs['fetch_data']) or ('fetch_data' not in kwargs and default_fetch): if len(tiles) > 0: cassandra_start = datetime.now() - args[0].fetch_data_for_tiles(*tiles) + args[0].fetch_data_for_tiles(*tiles, desired_projection=args[0].desired_projection) cassandra_duration += (datetime.now() - cassandra_start).total_seconds() if 'metrics_callback' in kwargs and kwargs['metrics_callback'] is not None: @@ -79,20 +79,28 @@ class NexusTileServiceException(Exception): class NexusTileService(object): - def __init__(self, skipDatastore=False, skipMetadatastore=False, config=None): + def __init__(self, skipDatastore=False, skipMetadatastore=False, config=None, desired_projection='grid'): self._datastore = None self._metadatastore = None self._config = configparser.RawConfigParser() self._config.read(NexusTileService._get_config_files('config/datastores.ini')) + if desired_projection not in ['grid', 'swath']: + raise ValueError(f'Invalid value provided for NexusTileService desired_projection: {desired_projection}') + + self.desired_projection = desired_projection + if config: 
self.override_config(config) if not skipDatastore: datastore = self._config.get("datastore", "store") if datastore == "cassandra": - self._datastore = CassandraProxy.CassandraProxy(self._config) + if desired_projection == "grid": + self._datastore = CassandraProxy.CassandraProxy(self._config) + else: + self._datastore = CassandraSwathProxy.CassandraSwathProxy(self._config) elif datastore == "s3": self._datastore = S3Proxy.S3Proxy(self._config) elif datastore == "dynamo": @@ -107,6 +115,9 @@ def __init__(self, skipDatastore=False, skipMetadatastore=False, config=None): elif metadatastore == "elasticsearch": self._metadatastore = ElasticsearchProxy.ElasticsearchProxy(self._config) + logger.info(f'Created new NexusTileService with data store {type(self._datastore)} and metadata ' + f'store {type(self._metadatastore)}. Desired projection: {desired_projection}') + def override_config(self, config): for section in config.sections(): if self._config.has_section(section): # only override preexisting section, ignores the other @@ -352,21 +363,39 @@ def get_distinct_bounding_boxes_in_polygon(self, bounding_polygon, ds, start_tim bounds = self._metadatastore.find_distinct_bounding_boxes_in_polygon(bounding_polygon, ds, start_time, end_time) return [box(*b) for b in bounds] + def _data_mask_logical_or(self, tile): + # Or together the masks of the individual arrays to create the new mask + if self.desired_projection == 'grid': + data_mask = ma.getmaskarray(tile.times)[:, np.newaxis, np.newaxis] \ + | ma.getmaskarray(tile.latitudes)[np.newaxis, :, np.newaxis] \ + | ma.getmaskarray(tile.longitudes)[np.newaxis, np.newaxis, :] + else: + if len(tile.times.shape) == 1: + data_mask = ma.getmaskarray(tile.times)[:, np.newaxis, np.newaxis] \ + | ma.getmaskarray(tile.latitudes)[np.newaxis, :, :] \ + | ma.getmaskarray(tile.longitudes)[np.newaxis, :, :] + else: + data_mask = ma.getmaskarray(tile.times) \ + | ma.getmaskarray(tile.latitudes) \ + | ma.getmaskarray(tile.longitudes) + + return 
data_mask + def mask_tiles_to_bbox(self, min_lat, max_lat, min_lon, max_lon, tiles): for tile in tiles: tile.latitudes = ma.masked_outside(tile.latitudes, min_lat, max_lat) tile.longitudes = ma.masked_outside(tile.longitudes, min_lon, max_lon) - # Or together the masks of the individual arrays to create the new mask - data_mask = ma.getmaskarray(tile.times)[:, np.newaxis, np.newaxis] \ - | ma.getmaskarray(tile.latitudes)[np.newaxis, :, np.newaxis] \ - | ma.getmaskarray(tile.longitudes)[np.newaxis, np.newaxis, :] + data_mask = self._data_mask_logical_or(tile) # If this is multi-var, need to mask each variable separately. if tile.is_multi: # Combine space/time mask with existing mask on data - data_mask = reduce(np.logical_or, [tile.data[0].mask, data_mask]) + # Data masks are ANDed because we want to mask out only when ALL data vars are invalid + combined_data_mask = reduce(np.logical_and, [ma.getmaskarray(d) for d in tile.data]) + # We now OR in the bounds mask because out of bounds data must be excluded regardless of validity + data_mask = np.logical_or(combined_data_mask, data_mask) num_vars = len(tile.data) multi_data_mask = np.repeat(data_mask[np.newaxis, ...], num_vars, axis=0) @@ -384,10 +413,7 @@ def mask_tiles_to_bbox_and_time(self, min_lat, max_lat, min_lon, max_lon, start_ tile.latitudes = ma.masked_outside(tile.latitudes, min_lat, max_lat) tile.longitudes = ma.masked_outside(tile.longitudes, min_lon, max_lon) - # Or together the masks of the individual arrays to create the new mask - data_mask = ma.getmaskarray(tile.times)[:, np.newaxis, np.newaxis] \ - | ma.getmaskarray(tile.latitudes)[np.newaxis, :, np.newaxis] \ - | ma.getmaskarray(tile.longitudes)[np.newaxis, np.newaxis, :] + data_mask = self._data_mask_logical_or(tile) tile.data = ma.masked_where(data_mask, tile.data) @@ -418,10 +444,7 @@ def mask_tiles_to_time_range(self, start_time, end_time, tiles): for tile in tiles: tile.times = ma.masked_outside(tile.times, start_time, end_time) - # Or 
together the masks of the individual arrays to create the new mask - data_mask = ma.getmaskarray(tile.times)[:, np.newaxis, np.newaxis] \ - | ma.getmaskarray(tile.latitudes)[np.newaxis, :, np.newaxis] \ - | ma.getmaskarray(tile.longitudes)[np.newaxis, np.newaxis, :] + data_mask = self._data_mask_logical_or(tile) # If this is multi-var, need to mask each variable separately. if tile.is_multi: @@ -450,7 +473,7 @@ def get_tile_count(self, ds, bounding_polygon=None, start_time=0, end_time=-1, m """ return self._metadatastore.get_tile_count(ds, bounding_polygon, start_time, end_time, metadata, **kwargs) - def fetch_data_for_tiles(self, *tiles): + def fetch_data_for_tiles(self, *tiles, **kwargs): nexus_tile_ids = set([tile.tile_id for tile in tiles]) matched_tile_data = self._datastore.fetch_nexus_tiles(*nexus_tile_ids) @@ -461,8 +484,12 @@ def fetch_data_for_tiles(self, *tiles): if len(missing_data) > 0: raise Exception("Missing data for tile_id(s) %s." % missing_data) + desired_projection = kwargs['desired_projection'] if 'desired_projection' in kwargs else self.desired_projection + for a_tile in tiles: - lats, lons, times, data, meta, is_multi_var = tile_data_by_id[a_tile.tile_id].get_lat_lon_time_data_meta() + lats, lons, times, data, meta, is_multi_var = tile_data_by_id[a_tile.tile_id].get_lat_lon_time_data_meta( + projection=desired_projection + ) a_tile.latitudes = lats a_tile.longitudes = lons @@ -470,6 +497,7 @@ def fetch_data_for_tiles(self, *tiles): a_tile.data = data a_tile.meta_data = meta a_tile.is_multi = is_multi_var + a_tile.projection = desired_projection del (tile_data_by_id[a_tile.tile_id])