From 65ed844e8a7294bd541924c79bb84b629de2ef5f Mon Sep 17 00:00:00 2001 From: Simon Liu Date: Wed, 19 Feb 2025 11:06:05 -0800 Subject: [PATCH] test --- podaac/subsetter/new_new_tree.py | 165 ++++++++++++++++++++---- podaac/subsetter/subset.py | 96 +++++++++----- podaac/subsetter/xarray_enhancements.py | 5 + pytest.ini | 2 - tests/test_subset.py | 2 + 5 files changed, 216 insertions(+), 54 deletions(-) diff --git a/podaac/subsetter/new_new_tree.py b/podaac/subsetter/new_new_tree.py index 5e6892ad..ac7b5739 100644 --- a/podaac/subsetter/new_new_tree.py +++ b/podaac/subsetter/new_new_tree.py @@ -211,7 +211,15 @@ def _create_subset_dataset( return DataTree(name='root', dataset=result) -def where_tree(tree: DataTree, cond: Union[xr.Dataset, xr.DataArray], cut: bool) -> DataTree: +def get_condition(condition_dict, path): + while path: + cond = condition_dict.get(path) + if cond is not None: + return cond + path = "/".join(path.rstrip("/").split("/")[:-1]) # Remove last segment + return condition_dict.get("/", None) # Final fallback to root + +def where_tree(tree: DataTree, cond: Union[xr.Dataset, xr.DataArray], cut: bool, condition_dict) -> DataTree: """ Return a DataTree which meets the given condition, processing all nodes in the tree. @@ -229,7 +237,7 @@ def where_tree(tree: DataTree, cond: Union[xr.Dataset, xr.DataArray], cut: bool) xarray.DataTree The filtered DataTree with all nodes processed """ - def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]: + def process_node(node: DataTree, path: str) -> Tuple[xr.Dataset, Dict[str, DataTree]]: """ Process a single node and its children in the tree. @@ -237,12 +245,17 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]: ---------- node : DataTree The node to process + path : str + The current path of the node Returns ------- Tuple[xr.Dataset, Dict[str, DataTree]] Processed dataset and dictionary of processed child nodes """ + # Print the current path + cond = get_condition(condition_dict, path) + # Get the dataset directly from the node dataset = node.ds @@ -265,23 +278,33 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]: # Apply indexing to condition and dataset indexed_cond = cond.isel(**indexers) - try: - indexed_ds = dataset.isel(**indexers) - except Exception as ex: - indexed_ds = dataset + indexed_ds = dataset.isel(**indexers) # Get variables with and without indexers subset_vars, non_subset_vars = get_variables_with_indexers(dataset, indexers) + # dataset with variables that need to be subsetted + new_dataset_sub = indexed_ds[subset_vars].where(indexed_cond) + # data with variables that shouldn't be subsetted + new_dataset_non_sub = indexed_ds[non_subset_vars] + + """ + print(subset_vars) + print(non_subset_vars) # Process variables sub_ds = create_subset_dataset(dataset, subset_vars, indexers, cond) non_sub_ds = create_subset_dataset(dataset, non_subset_vars, indexers, cond) - + + print("############################") + print(path) + print('applying subsetting') + print("############################") + """ + # Merge the datasets - merged_ds = xr.merge([non_sub_ds, sub_ds]) - - processed_ds = merged_ds - processed_ds.attrs.update(dataset.attrs) + #merged_ds = xr.merge([non_sub_ds, sub_ds]) + merged_ds = xr.merge([new_dataset_non_sub, new_dataset_sub]) + merged_ds.attrs.update(dataset.attrs) # Restore original data types for var, dtype in original_dtypes.items(): @@ -305,7 +328,7 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]: processed_children = {} for child_name, child_node in node.children.items(): # Process the child node - child_ds, child_children = process_node(child_node) + child_ds, child_children = process_node(child_node, f"{path}/{child_name}") # Create new DataTree for the processed child child_tree = DataTree(name=child_name, dataset=child_ds) @@ -319,7 +342,7 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]: return processed_ds, processed_children # Start processing from root - root_ds, children = process_node(tree) + root_ds, children = process_node(tree, '') # Create new root tree preserving the original name and attributes result_tree = DataTree(name=tree.name, dataset=root_ds) @@ -331,6 +354,7 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]: # Copy over root attributes result_tree.attrs.update(tree.attrs) + print(result_tree) return result_tree """ @@ -486,9 +510,12 @@ def create_subset_dataset( if var_indexers: var_cond = cond.isel(**var_indexers) subset_dict[var_name] = indexed_var.where(var_cond) + print("Subsetting") else: subset_dict[var_name] = indexed_var + print("NOT WHAT") else: + print("WHAT") subset_dict[var_name] = dataset[var_name] return xr.Dataset(subset_dict) @@ -503,8 +530,6 @@ def get_variables_with_indexers(dataset, indexers): subset_vars = [] no_subset_vars = [] - - for i in list(dataset.variables.keys()): variable_dims = list(dataset[i].dims) if any(item in index_list for item in variable_dims): @@ -829,8 +854,6 @@ def find_matching_coords(dataset: xr.Dataset, match_list: List[str]) -> List[str return match_coord_vars - - def compute_time_variable_name_tree(tree, lat_var, total_time_vars): time_coord_name = [] @@ -915,16 +938,12 @@ def get_squeezed_dims(var_name: str) -> tuple: return None def traverse_tree(node, path): - print(node) """Recursively search through the tree for latitude and longitude coordinates.""" if node.ds is not None: return_time = find_time_in_dataset(node.ds, lat_var, path, total_time_vars) if return_time: - print("FOUND TIME") time_var = f"{path}/{return_time}" return time_var - #print(time_var) - #time_coord_name.append(time_var) for child_name, child_node in node.children.items(): new_path = f"{path}/{child_name}" if path else child_name @@ -982,4 +1001,106 @@ def get_variable_from_path(datatree: Any, path: str) -> Optional[Union[xr.DataAr except (AttributeError, TypeError): # Return None if any error occurs during traversal - return None \ No newline at end of file + return None + +def get_path(s): + """Extracts the path by removing the last part after the final '/'.""" + path = s.rsplit('/', 1)[0] if '/' in s else s + return f"/{path}" + +def tree_get_spatial_bounds(datatree: xr.Dataset, lat_var_names: List[str], lon_var_names: List[str]) -> Union[np.ndarray, None]: + """ + Get the spatial bounds for this dataset tree. These values are masked and scaled. + + Parameters + ---------- + datatree : xr.Dataset + Dataset tree to retrieve spatial bounds for + lat_var_names : List[str] + List of paths to latitude variables + lon_var_names : List[str] + List of paths to longitude variables + + Returns + ------- + np.array + [[lon min, lon max], [lat min, lat max]] + """ + if len(lat_var_names) != len(lon_var_names): + raise ValueError("Number of latitude and longitude paths must match") + + min_lats, max_lats, min_lons, max_lons = [], [], [], [] + + for lat_var_name, lon_var_name in zip(lat_var_names, lon_var_names): + try: + # Get variables from paths + lat_data = get_variable_from_path(datatree, lat_var_name) + lon_data = get_variable_from_path(datatree, lon_var_name) + + if get_path(lat_var_name) != get_path(lon_var_name): + continue + + # Get metadata attributes efficiently + lat_attrs = lat_data.attrs + lon_attrs = lon_data.attrs + + # Extract metadata with defaults + lat_scale = lat_attrs.get('scale_factor', 1.0) + lon_scale = lon_attrs.get('scale_factor', 1.0) + lat_offset = lat_attrs.get('add_offset', 0.0) + lon_offset = lon_attrs.get('add_offset', 0.0) + lon_valid_min = lon_attrs.get('valid_min', None) + + # Flatten and mask data + lats = lat_data.values.flatten() + lons = lon_data.values.flatten() + + # Apply fill value masks if present + lat_fill = lat_attrs.get('_FillValue') + lon_fill = lon_attrs.get('_FillValue') + + if lat_fill is not None: + lats = lats[lats != lat_fill] + if lon_fill is not None: + lons = lons[lons != lon_fill] + + if len(lats) == 0 or len(lons) == 0: + continue + + # Calculate bounds efficiently using vectorized operations + min_lat = round((np.nanmin(lats) * lat_scale) - lat_offset, 1) + max_lat = round((np.nanmax(lats) * lat_scale) - lat_offset, 1) + min_lon = round((np.nanmin(lons) * lon_scale) - lon_offset, 1) + max_lon = round((np.nanmax(lons) * lon_scale) - lon_offset, 1) + + # Handle longitude conversion to [-180, 180] format + if lon_valid_min == 0 or 0 <= min_lon <= max_lon <= 360: + if min_lon > 180: + min_lon -= 360 + if max_lon > 180: + max_lon -= 360 + if min_lon == max_lon: + min_lon = -180 + max_lon = 180 + + min_lats.append(min_lat) + max_lats.append(max_lat) + min_lons.append(min_lon) + max_lons.append(max_lon) + + except (KeyError, AttributeError) as e: + print(f"Warning: Error processing {lat_var_name}/{lon_var_name}: {str(e)}") + continue + + if not min_lats: # If no valid bounds were found + return None + + print(min_lons) + print(max_lons) + print(min_lats) + print(max_lats) + # Calculate overall bounds using numpy operations + return np.array([ + [min(min_lons), max(max_lons)], + [min(min_lats), max(max_lats)] + ]) \ No newline at end of file diff --git a/podaac/subsetter/subset.py b/podaac/subsetter/subset.py index d9fc4fe4..4467166b 100644 --- a/podaac/subsetter/subset.py +++ b/podaac/subsetter/subset.py @@ -164,8 +164,12 @@ def convert_bbox(bbox: np.ndarray, dataset: xr.Dataset, lat_var_name: str, lon_v Assumption that the provided bounding box is always between -180 --> 180 for longitude and -90, 90 for latitude. """ - return np.array([convert_bound(bbox[0], 360, dataset[lon_var_name]), - convert_bound(bbox[1], 180, dataset[lat_var_name])]) + + lon_data = new_new_tree.get_variable_from_path(dataset, lon_var_name) + lat_data = new_new_tree.get_variable_from_path(dataset, lat_var_name) + + return np.array([convert_bound(bbox[0], 360, lon_data), + convert_bound(bbox[1], 180, lat_data)]) def set_json_history(dataset: xr.Dataset, cut: bool, file_to_subset: str, @@ -947,12 +951,10 @@ def build_cond(str_timestamp, compare): return temporal_cond - - - - - - +def get_path(s): + """Extracts the path by removing the last part after the final '/'.""" + path = s.rsplit('/', 1)[0] if '/' in s else s + return f"/{path}" def subset_with_bbox(dataset: xr.Dataset, # pylint: disable=too-many-branches @@ -1011,6 +1013,48 @@ def subset_with_bbox(dataset: xr.Dataset, # pylint: disable=too-many-branches datasets = [] total_list = [] # don't include repeated variables + + print(lat_bounds[0]) + print(lat_bounds[1]) + print(lon_bounds[0]) + print(lon_bounds[1]) + + subset_dictionary = {} + for lat_var_name, lon_var_name, time_var_name in zip(lat_var_names, lon_var_names, time_var_names): + + lat_path = get_path(lat_var_name) + lon_path = get_path(lon_var_name) + time_path = get_path(time_var_name) + + temporal_cond = new_build_temporal_cond(min_time, max_time, dataset, time_var_name) + + lon_data = new_new_tree.get_variable_from_path(dataset, lon_var_name) + lat_data = new_new_tree.get_variable_from_path(dataset, lat_var_name) + + operation = ( + oper((lon_data >= lon_bounds[0]), (lon_data <= lon_bounds[1])) & + (lat_data >= lat_bounds[0]) & + (lat_data <= lat_bounds[1]) & + temporal_cond + ) + + + if lat_path == lon_path and lat_path == time_path and lon_path == time_path: + print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") + print(lon_data) + print(lat_data) + print(operation) + print(lat_path) + print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%") + subset_dictionary[lat_path] = operation + + print("##########################") + print(subset_dictionary) + print("##########################") + + return_dataset = xre.tree_where(dataset, subset_dictionary, cut) + return return_dataset + """ for lat_var_name, lon_var_name, time_var_name, diffs in zip( # pylint: disable=too-many-nested-blocks lat_var_names, lon_var_names, time_var_names, diff_count ): @@ -1074,18 +1118,11 @@ def subset_with_bbox(dataset: xr.Dataset, # pylint: disable=too-many-branches lon_data = new_new_tree.get_variable_from_path(dataset, lon_var_name) lat_data = new_new_tree.get_variable_from_path(dataset, lat_var_name) - print(lon_var_name) - print(lat_var_name) - print(lon_data) - print(lat_data) - #print(lon_bounds[0]) - #print(lon_bounds[1]) - #print(lat_bounds[0]) - #print(lat_bounds[1]) + #print(lon_var_name) + #print(lat_var_name) #print(lon_data) #print(lat_data) - - print(group_dataset) + #print(group_dataset) group_dataset = xre.where( group_dataset, @@ -1099,7 +1136,7 @@ def subset_with_bbox(dataset: xr.Dataset, # pylint: disable=too-many-branches cut ) return group_dataset - """ + group_dataset = xre.where( group_dataset, oper( @@ -1111,14 +1148,13 @@ def subset_with_bbox(dataset: xr.Dataset, # pylint: disable=too-many-branches temporal_cond, cut ) - """ - + datasets.append(group_dataset) total_list.extend(group_vars) if diffs == -1: return datasets - - dim_cleaned_datasets = dc.recreate_pixcore_dimensions(datasets) + """ + #dim_cleaned_datasets = dc.recreate_pixcore_dimensions(datasets) return dim_cleaned_datasets @@ -1485,9 +1521,9 @@ def subset(file_to_subset: str, bbox: np.ndarray, output_file: str, else: raise ValueError('Either bbox or shapefile must be provided') - print("##############################") - print(datasets) - print("##############################") + #print("##############################") + #print(datasets) + #print("##############################") datasets.to_netcdf(output_file) spatial_bounds = [] @@ -1536,8 +1572,8 @@ def subset(file_to_subset: str, bbox: np.ndarray, output_file: str, ]]) """ - return get_spatial_bounds( - dataset=dataset, - lat_var_names=lat_var_names, - lon_var_names=lon_var_names + return new_new_tree.tree_get_spatial_bounds( + dataset, + lat_var_names, + lon_var_names ) diff --git a/podaac/subsetter/xarray_enhancements.py b/podaac/subsetter/xarray_enhancements.py index e9c6384b..c0351a57 100644 --- a/podaac/subsetter/xarray_enhancements.py +++ b/podaac/subsetter/xarray_enhancements.py @@ -184,6 +184,11 @@ def get_variables_with_indexers(dataset, indexers): return subset_vars, no_subset_vars +def tree_where(dataset, cond_dictionary, cut) -> xr.Dataset: + + return new_new_tree.where_tree(dataset, None, cut, cond_dictionary) + + def where(dataset: xr.Dataset, cond: Union[xr.Dataset, xr.DataArray], cut: bool) -> xr.Dataset: """ Return a dataset which meets the given condition. diff --git a/pytest.ini b/pytest.ini index 85c12f3f..e69de29b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +0,0 @@ -[pytest] -addopts = --reruns 3 --reruns-delay 15 diff --git a/tests/test_subset.py b/tests/test_subset.py index d143c4aa..ddc2dddd 100644 --- a/tests/test_subset.py +++ b/tests/test_subset.py @@ -976,6 +976,8 @@ def test_group_subset(data_dir, subset_output_dir): output_file=os.path.join(subset_output_dir, s6_output_file_name) ) + print(bounds) + # Check that bounds are within requested bbox assert bounds[0][0] >= bbox[0][0] assert bounds[0][1] <= bbox[0][1]