Commit: test
sliu008 committed Feb 19, 2025
1 parent 281afad commit 65ed844
Showing 5 changed files with 216 additions and 54 deletions.
165 changes: 143 additions & 22 deletions podaac/subsetter/new_new_tree.py
@@ -211,7 +211,15 @@ def _create_subset_dataset(

return DataTree(name='root', dataset=result)

def get_condition(condition_dict, path):
"""Look up the condition that applies to a group path, walking up through parent groups."""
while path:
cond = condition_dict.get(path)
if cond is not None:
return cond
path = "/".join(path.rstrip("/").split("/")[:-1])  # Drop the last path segment
return condition_dict.get("/", None)  # Final fallback to the root condition
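get_condition resolves the most specific condition for a group: it tries the full path first, then each ancestor, and finally the root entry. A minimal sketch of the lookup order, using strings as stand-ins for real condition objects (the dictionary contents here are hypothetical):

condition_dict = {"/": "root_cond", "/group_a": "group_a_cond"}
get_condition(condition_dict, "/group_a/sub")  # -> "group_a_cond" (inherited from /group_a)
get_condition(condition_dict, "/group_b")      # -> "root_cond" (falls back to "/")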

def where_tree(tree: DataTree, cond: Union[xr.Dataset, xr.DataArray], cut: bool, condition_dict: Dict[str, Union[xr.Dataset, xr.DataArray]]) -> DataTree:
"""
Return a DataTree which meets the given condition, processing all nodes in the tree.
@@ -229,20 +237,25 @@ def where_tree(tree: DataTree, cond: Union[xr.Dataset, xr.DataArray], cut: bool)
xarray.DataTree
The filtered DataTree with all nodes processed
"""
def process_node(node: DataTree, path: str) -> Tuple[xr.Dataset, Dict[str, DataTree]]:
"""
Process a single node and its children in the tree.
Parameters
----------
node : DataTree
The node to process
path : str
The current path of the node
Returns
-------
Tuple[xr.Dataset, Dict[str, DataTree]]
Processed dataset and dictionary of processed child nodes
"""
# Look up the subsetting condition that applies to this node's path
cond = get_condition(condition_dict, path)

# Get the dataset directly from the node
dataset = node.ds

@@ -265,23 +278,33 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]:

# Apply indexing to condition and dataset
indexed_cond = cond.isel(**indexers)
indexed_ds = dataset.isel(**indexers)

# Get variables with and without indexers
subset_vars, non_subset_vars = get_variables_with_indexers(dataset, indexers)

# dataset with variables that need to be subsetted
new_dataset_sub = indexed_ds[subset_vars].where(indexed_cond)
# data with variables that shouldn't be subsetted
new_dataset_non_sub = indexed_ds[non_subset_vars]

"""
print(subset_vars)
print(non_subset_vars)
# Process variables
sub_ds = create_subset_dataset(dataset, subset_vars, indexers, cond)
non_sub_ds = create_subset_dataset(dataset, non_subset_vars, indexers, cond)

print("############################")
print(path)
print('applying subsetting')
print("############################")
"""

# Merge the datasets
merged_ds = xr.merge([non_sub_ds, sub_ds])

processed_ds = merged_ds
processed_ds.attrs.update(dataset.attrs)
#merged_ds = xr.merge([non_sub_ds, sub_ds])
merged_ds = xr.merge([new_dataset_non_sub, new_dataset_sub])
merged_ds.attrs.update(dataset.attrs)

# Restore original data types
for var, dtype in original_dtypes.items():
@@ -305,7 +328,7 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]:
processed_children = {}
for child_name, child_node in node.children.items():
# Process the child node
child_ds, child_children = process_node(child_node, f"{path}/{child_name}")

# Create new DataTree for the processed child
child_tree = DataTree(name=child_name, dataset=child_ds)
@@ -319,7 +342,7 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]:
return processed_ds, processed_children

# Start processing from root
root_ds, children = process_node(tree, '')

# Create new root tree preserving the original name and attributes
result_tree = DataTree(name=tree.name, dataset=root_ds)
@@ -331,6 +354,7 @@ def process_node(node: DataTree) -> Tuple[xr.Dataset, Dict[str, DataTree]]:
# Copy over root attributes
result_tree.attrs.update(tree.attrs)

return result_tree
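A hedged usage sketch for where_tree, assuming a small two-group tree and a per-path condition dictionary (all data and names below are illustrative; DataTree.from_dict is the standard xarray constructor):

import xarray as xr
from xarray import DataTree

root = xr.Dataset({"temp": ("x", [1.0, 2.0, 3.0])})
child = xr.Dataset({"sal": ("x", [10.0, 20.0, 30.0])})
tree = DataTree.from_dict({"/": root, "/child": child})

cond = root["temp"] > 1.0
condition_dict = {"/": cond, "/child": child["sal"] > 10.0}
filtered = where_tree(tree, cond, cut=True, condition_dict=condition_dict)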

"""
@@ -486,9 +510,12 @@ def create_subset_dataset(
if var_indexers:
var_cond = cond.isel(**var_indexers)
subset_dict[var_name] = indexed_var.where(var_cond)
else:
subset_dict[var_name] = indexed_var
else:
# Variable has no subsettable dimensions; copy it through unchanged
subset_dict[var_name] = dataset[var_name]

return xr.Dataset(subset_dict)
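For reference, a hypothetical call to create_subset_dataset following the four-argument shape used in this module (dataset, variable names, dimension indexers, condition); the variable and dimension names are illustrative:

subset = create_subset_dataset(dataset, ["temp"], {"x": slice(0, 2)}, cond)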
@@ -503,8 +530,6 @@ def get_variables_with_indexers(dataset, indexers):
subset_vars = []
no_subset_vars = []

for i in list(dataset.variables.keys()):
variable_dims = list(dataset[i].dims)
if any(item in index_list for item in variable_dims):
@@ -829,8 +854,6 @@ def find_matching_coords(dataset: xr.Dataset, match_list: List[str]) -> List[str
return match_coord_vars


def compute_time_variable_name_tree(tree, lat_var, total_time_vars):

time_coord_name = []
@@ -915,16 +938,12 @@ def get_squeezed_dims(var_name: str) -> tuple:
return None

def traverse_tree(node, path):
"""Recursively search the tree for a dataset containing a time coordinate."""
if node.ds is not None:
return_time = find_time_in_dataset(node.ds, lat_var, path, total_time_vars)
if return_time:
time_var = f"{path}/{return_time}"
return time_var

for child_name, child_node in node.children.items():
new_path = f"{path}/{child_name}" if path else child_name
@@ -982,4 +1001,106 @@ def get_variable_from_path(datatree: Any, path: str) -> Optional[Union[xr.DataAr

except (AttributeError, TypeError):
# Return None if any error occurs during traversal
return None

def get_path(s):
"""Return the group path of a variable by dropping the final path segment (the variable name)."""
path = s.rsplit('/', 1)[0] if '/' in s else s
return f"/{path.lstrip('/')}"  # Normalize to a single leading slash
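Illustrative behavior of get_path (a quick sketch, not exhaustive):

get_path("group_a/lat")  # -> "/group_a"
get_path("lat")          # -> "/lat" (a root-level name is returned as-is, with a leading slash)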

def tree_get_spatial_bounds(datatree: DataTree, lat_var_names: List[str], lon_var_names: List[str]) -> Union[np.ndarray, None]:
"""
Get the spatial bounds for this dataset tree. These values are masked and scaled.
Parameters
----------
datatree : DataTree
DataTree to retrieve spatial bounds for
lat_var_names : List[str]
List of paths to latitude variables
lon_var_names : List[str]
List of paths to longitude variables
Returns
-------
np.ndarray or None
[[lon min, lon max], [lat min, lat max]]
"""
if len(lat_var_names) != len(lon_var_names):
raise ValueError("Number of latitude and longitude paths must match")

min_lats, max_lats, min_lons, max_lons = [], [], [], []

for lat_var_name, lon_var_name in zip(lat_var_names, lon_var_names):
try:
# Get variables from paths
lat_data = get_variable_from_path(datatree, lat_var_name)
lon_data = get_variable_from_path(datatree, lon_var_name)

if get_path(lat_var_name) != get_path(lon_var_name):
continue

# Get metadata attributes efficiently
lat_attrs = lat_data.attrs
lon_attrs = lon_data.attrs

# Extract metadata with defaults
lat_scale = lat_attrs.get('scale_factor', 1.0)
lon_scale = lon_attrs.get('scale_factor', 1.0)
lat_offset = lat_attrs.get('add_offset', 0.0)
lon_offset = lon_attrs.get('add_offset', 0.0)
lon_valid_min = lon_attrs.get('valid_min', None)

# Flatten and mask data
lats = lat_data.values.flatten()
lons = lon_data.values.flatten()

# Apply fill value masks if present
lat_fill = lat_attrs.get('_FillValue')
lon_fill = lon_attrs.get('_FillValue')

if lat_fill is not None:
lats = lats[lats != lat_fill]
if lon_fill is not None:
lons = lons[lons != lon_fill]

if len(lats) == 0 or len(lons) == 0:
continue

# Calculate bounds efficiently using vectorized operations
min_lat = round((np.nanmin(lats) * lat_scale) - lat_offset, 1)
max_lat = round((np.nanmax(lats) * lat_scale) - lat_offset, 1)
min_lon = round((np.nanmin(lons) * lon_scale) - lon_offset, 1)
max_lon = round((np.nanmax(lons) * lon_scale) - lon_offset, 1)

# Handle longitude conversion to [-180, 180] format
if lon_valid_min == 0 or 0 <= min_lon <= max_lon <= 360:
if min_lon > 180:
min_lon -= 360
if max_lon > 180:
max_lon -= 360
if min_lon == max_lon:
min_lon = -180
max_lon = 180

min_lats.append(min_lat)
max_lats.append(max_lat)
min_lons.append(min_lon)
max_lons.append(max_lon)

except (KeyError, AttributeError) as e:
print(f"Warning: Error processing {lat_var_name}/{lon_var_name}: {str(e)}")
continue

if not min_lats: # If no valid bounds were found
return None

# Aggregate the overall bounds across all coordinate pairs
return np.array([
[min(min_lons), max(max_lons)],
[min(min_lats), max(max_lats)]
])
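A usage sketch for tree_get_spatial_bounds, assuming the tree stores coordinates under a hypothetical group named geo:

bounds = tree_get_spatial_bounds(tree, lat_var_names=["geo/lat"], lon_var_names=["geo/lon"])
if bounds is not None:
    (lon_min, lon_max), (lat_min, lat_max) = bounds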
