Commit 47d1fe7: custom types removed.
bnb32 committed Aug 7, 2024
1 parent 3d4114e commit 47d1fe7
Showing 25 changed files with 165 additions and 151 deletions.
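
The change is mechanical: every annotation that used the project-wide `T_Array` alias, previously imported from `sup3r.typing`, now spells out `Union[np.ndarray, da.core.Array]` inline, and each touched module imports `Union`, `numpy`, and `dask.array` directly. A minimal sketch of the pattern, assuming the alias was defined roughly as below (the contents of `sup3r/typing.py` are not part of the loaded diff):

```python
# Presumed former definition in sup3r/typing.py (an assumption; the
# file itself is not shown in this diff):
from typing import Union

import dask.array as da
import numpy as np

# sup3r arrays may be eager (numpy) or lazy (dask), so annotations
# accept either backend.
T_Array = Union[np.ndarray, da.core.Array]
```

Removing the alias trades brevity for explicitness: readers of any single module no longer need to chase the definition through `sup3r.typing`.
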
4 changes: 2 additions & 2 deletions sup3r/bias/base.py
@@ -541,9 +541,9 @@ def _match_zero_rate(bias_data, base_data):
     Parameters
     ----------
-    bias_data : T_Array
+    bias_data : Union[np.ndarray, da.core.Array]
         1D array of biased data observations.
-    base_data : T_Array
+    base_data : Union[np.ndarray, da.core.Array]
         1D array of base data observations.

     Returns
9 changes: 5 additions & 4 deletions sup3r/bias/bias_transforms.py
@@ -9,15 +9,16 @@
 """

 import logging
+from typing import Union
 from warnings import warn

+import dask.array as da
 import numpy as np
 import pandas as pd
 from rex.utilities.bc_utils import QuantileDeltaMapping
 from scipy.ndimage import gaussian_filter

 from sup3r.preprocessing import Rasterizer
-from sup3r.typing import T_Array

 logger = logging.getLogger(__name__)

@@ -114,7 +115,7 @@ def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1):


 def get_spatial_bc_quantiles(
-    lat_lon: T_Array,
+    lat_lon: Union[np.ndarray, da.core.Array],
     base_dset: str,
     feature_name: str,
     bias_fp: str,
@@ -131,7 +132,7 @@

     Parameters
     ----------
-    lat_lon : T_Array
+    lat_lon : Union[np.ndarray, da.core.Array]
         Array of latitudes and longitudes for the domain to bias correct
         (n_lats, n_lons, 2)
     base_dset : str
@@ -480,7 +481,7 @@ def local_qdm_bc(

     Parameters
     ----------
-    data : T_Array
+    data : Union[np.ndarray, da.core.Array]
         Sup3r input data to be bias corrected, assumed to be 3D with shape
         (spatial, spatial, temporal) for a single feature.
     lat_lon : np.ndarray
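
Functions annotated with the inline union accept either array backend. A small illustration of the calling convention, using hypothetical inputs rather than repository code:

```python
import dask.array as da
import numpy as np

# Hypothetical (n_lats, n_lons, 2) lat/lon grids, matching the shape
# described in the get_spatial_bc_quantiles docstring above.
lat_lon_np = np.zeros((10, 20, 2))                     # eager numpy
lat_lon_da = da.zeros((10, 20, 2), chunks=(5, 10, 2))  # lazy dask

# Both satisfy Union[np.ndarray, da.core.Array]. Dask inputs stay lazy
# until something forces evaluation:
shifted = lat_lon_da + 1.0
print(type(shifted))              # <class 'dask.array.core.Array'>
print(np.asarray(shifted).shape)  # (10, 20, 2), now a concrete ndarray
```
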
2 changes: 1 addition & 1 deletion sup3r/bias/utilities.py
@@ -197,7 +197,7 @@ def bias_correct_feature(

     Returns
     -------
-    data : T_Array
+    data : Union[np.ndarray, da.core.Array]
         Data corrected by the bias_correct_method ready for input to the
         forward pass through the generative model.
     """
8 changes: 4 additions & 4 deletions sup3r/pipeline/forward_pass.py
@@ -143,7 +143,7 @@ def pad_source_data(self, input_data, pad_width, exo_data, mode='reflect'):

     Parameters
     ----------
-    input_data : T_Array
+    input_data : Union[np.ndarray, da.core.Array]
         Source input data from data handler class, shape is:
         (spatial_1, spatial_2, temporal, features)
     pad_width : tuple
@@ -158,7 +158,7 @@ def pad_source_data(self, input_data, pad_width, exo_data, mode='reflect'):

     Returns
     -------
-    out : T_Array
+    out : Union[np.ndarray, da.core.Array]
         Padded copy of source input data from data handler class, shape is:
         (spatial_1, spatial_2, temporal, features)
     exo_data : dict
@@ -287,7 +287,7 @@ def _reshape_data_chunk(model, data_chunk, exo_data):
     ----------
     model : Sup3rGan
         Sup3rGan or similar sup3r model
-    data_chunk : T_Array
+    data_chunk : Union[np.ndarray, da.core.Array]
         Low resolution data for a single spatiotemporal chunk that is going
         to be passed to the model generate function.
     exo_data : dict | None
@@ -296,7 +296,7 @@

     Returns
     -------
-    data_chunk : T_Array
+    data_chunk : Union[np.ndarray, da.core.Array]
         Same as input but reshaped to (temporal, spatial_1, spatial_2,
         features) if the model is a spatial-first model or
         (n_obs, spatial_1, spatial_2, temporal, features) if the
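
For the padding described in `pad_source_data`, a toy sketch of the underlying idea (illustrative only, not the repository implementation): reflect-pad the spatial and temporal dimensions while leaving the feature dimension untouched.

```python
import numpy as np

data = np.random.rand(4, 5, 6, 2)             # (spatial_1, spatial_2, temporal, features)
pad_width = ((1, 1), (2, 2), (3, 3), (0, 0))  # per-dimension (before, after) pads
padded = np.pad(data, pad_width, mode='reflect')
print(padded.shape)                           # (6, 9, 12, 2)
```
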
2 changes: 1 addition & 1 deletion sup3r/pipeline/slicer.py
@@ -326,7 +326,7 @@ def hr_crop_slices(self):
         list has a crop slice for each spatial dimension and temporal
         dimension and then slice(None) for the feature dimension.
         model.generate()[hr_crop_slice] gives the cropped generator output
-        corresponding to output_array[hr_slice]
+        corresponding to outpuUnion[np.ndarray, da.core.Array][hr_slice]
         """
         if self._hr_crop_slices is None:
             self._hr_crop_slices = []
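
The relationship the `hr_crop_slices` docstring describes can be shown with a one-dimensional toy (illustrative, not repository code): the padded generator output is cropped so it lands exactly in its slot of the full output array.

```python
import numpy as np

padded_out = np.arange(10.0)  # generator output including padding
hr_crop_slice = slice(2, 8)   # drop 2 padded cells on each side
hr_slice = slice(12, 18)      # this chunk's slot in the full domain
output_array = np.zeros(30)
output_array[hr_slice] = padded_out[hr_crop_slice]
```
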
37 changes: 19 additions & 18 deletions sup3r/pipeline/strategy.py
@@ -10,6 +10,7 @@
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union

+import dask.array as da
 import numpy as np
 import pandas as pd
@@ -27,7 +28,6 @@
     get_input_handler_class,
     log_args,
 )
-from sup3r.typing import T_Array
 from sup3r.utilities.utilities import Timer

 logger = logging.getLogger(__name__)
@@ -38,13 +38,13 @@ class ForwardPassChunk:
     """Structure storing chunk data and attributes for a specific chunk going
     through the generator."""

-    input_data: T_Array
+    input_data: Union[np.ndarray, da.core.Array]
     exo_data: Dict
     hr_crop_slice: slice
     lr_pad_slice: slice
-    hr_lat_lon: T_Array
+    hr_lat_lon: Union[np.ndarray, da.core.Array]
     hr_times: pd.DatetimeIndex
-    gids: T_Array
+    gids: Union[np.ndarray, da.core.Array]
     out_file: str
     pad_width: Tuple[tuple, tuple, tuple]
     index: int
@@ -76,14 +76,14 @@ class ForwardPassStrategy:
         string with a unix-style file path which will be passed through
         glob.glob
     model_kwargs : str | list
-        Keyword arguments to send to `model_class.load(**model_kwargs)` to
+        Keyword arguments to send to ``model_class.load(**model_kwargs)`` to
         initialize the GAN. Typically this is just the string path to the
         model directory, but can be multiple models or arguments for more
         complex models.
     fwp_chunk_shape : tuple
         Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse chunk
         to use for a forward pass. The number of nodes that the
-        :class:`ForwardPassStrategy` is set to distribute to is calculated by
+        :class:`.ForwardPassStrategy` is set to distribute to is calculated by
         dividing up the total time index from all file_paths by the temporal
         part of this chunk shape. Each node will then be parallelized across
         parallel processes by the spatial chunk shape. If temporal_pad /
@@ -100,25 +100,26 @@ class ForwardPassStrategy:
         the fwp_chunk_shape.
     model_class : str
         Name of the sup3r model class for the GAN model to load. The default is
-        the basic spatial / spatiotemporal Sup3rGan model. This will be loaded
-        from sup3r.models
+        the basic spatial / spatiotemporal ``Sup3rGan`` model. This will be
+        loaded from ``sup3r.models``
     out_pattern : str
         Output file pattern. Must include {file_id} format key. Each output
         file will have a unique file_id filled in and the ext determines the
         output type. If pattern is None then data will be returned in an array
         and not saved.
     input_handler_name : str | None
         Class to use for input data. Provide a string name to match a
-        rasterizer or handler class in `sup3r.preprocessing`
+        rasterizer or handler class in ``sup3r.preprocessing``
     input_handler_kwargs : dict | None
-        Any kwargs for initializing the `input_handler_name` class.
+        Any kwargs for initializing the ``input_handler_name`` class.
     exo_handler_kwargs : dict | None
-        Dictionary of args to pass to :class:`ExoDataHandler` for extracting
-        exogenous features for multistep foward pass. This should be a nested
-        dictionary with keys for each exogenous feature. The dictionaries
-        corresponding to the feature names should include the path to exogenous
-        data source, the resolution of the exogenous data, and how the
-        exogenous data should be used in the model. e.g. {'topography':
+        Dictionary of args to pass to
+        :class:`~sup3r.preprocessing.data_handlers.ExoDataHandler` for
+        extracting exogenous features for multistep forward pass. This should
+        be a nested dictionary with keys for each exogenous feature. The
+        dictionaries corresponding to the feature names should include the path
+        to exogenous data source, the resolution of the exogenous data, and how
+        the exogenous data should be used in the model. e.g. {'topography':
         {'file_paths': 'path to input files', 'source_file': 'path to exo
         data', 'steps': [..]}.
     bias_correct_method : str | None
@@ -153,13 +154,13 @@ class ForwardPassStrategy:
         node. If 1 then all forward passes on chunks distributed to a single
         node will be run serially. pass_workers=2 is the minimum number of
         workers required to run the ForwardPass initialization and
-        :meth:`ForwardPass.run_chunk()` methods concurrently.
+        :meth:`~.forward_pass.ForwardPass.run_chunk()` methods concurrently.
     max_nodes : int | None
         Maximum number of nodes to distribute spatiotemporal chunks across. If
         None then a node will be used for each temporal chunk.
     head_node : bool
         Whether initialization is taking place on the head node of a multi node
-        job launch. When this is true :class:`ForwardPassStrategy` is only
+        job launch. When this is true :class:`.ForwardPassStrategy` is only
         partially initialized to provide the head node enough information for
         how to distribute jobs across nodes. Preflight tasks like bias
         correction will be skipped because they will be performed on the nodes
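
The nested structure that the `exo_handler_kwargs` docstring describes looks like the following sketch. All paths are placeholders, and the keys inside each step dict are assumptions based on typical sup3r usage rather than a schema confirmed by this diff:

```python
# Illustrative exo_handler_kwargs following the docstring's
# {'topography': {...}} example; paths and step keys are hypothetical.
exo_handler_kwargs = {
    'topography': {
        'file_paths': '/data/lr_input_*.h5',      # low-res model inputs
        'source_file': '/data/high_res_topo.h5',  # high-res exo source
        'steps': [{'model': 0, 'combine_type': 'input'}],
    }
}
```
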
2 changes: 1 addition & 1 deletion sup3r/postprocessing/collectors/h5.py
@@ -137,7 +137,7 @@ def get_data(

         Returns
         -------
-        f_data : T_Array
+        f_data : Union[np.ndarray, da.core.Array]
             Data array from the fpath cast as input dtype.
         row_slice : slice
             final_time_index[row_slice] = new_time_index
2 changes: 1 addition & 1 deletion sup3r/postprocessing/writers/base.py
@@ -205,7 +205,7 @@ def _ensure_dset_in_output(cls, out_file, dset, data=None):
             Pre-existing H5 file output path
         dset : str
             Dataset name
-        data : T_Array | None
+        data : Union[np.ndarray, da.core.Array] | None
            Optional data to write to dataset if initializing.
         """
17 changes: 10 additions & 7 deletions sup3r/preprocessing/accessor.py
@@ -23,7 +23,6 @@
     parse_ellipsis,
     parse_to_list,
 )
-from sup3r.typing import T_Array

 logger = logging.getLogger(__name__)
@@ -111,7 +110,9 @@ def parse_keys(self, keys):
         dim_keys = parse_ellipsis(dim_keys, dim_num=len(self._ds.dims))
         return features, dict(zip(ordered_dims(self._ds.dims), dim_keys))

-    def __getitem__(self, keys) -> Union[T_Array, Self]:
+    def __getitem__(
+        self, keys
+    ) -> Union[Union[np.ndarray, da.core.Array], Self]:
         """Method for accessing variables. keys can optionally include a
         feature name or list of feature names as the first entry of a keys
         tuple.
@@ -164,7 +165,7 @@ def __setitem__(self, keys, data):
            keys to set. This can be a string like 'temperature' or a list
            like ``['u', 'v']``. ``data`` will be iterated over in the latter
            case.
-        data : T_Array | xr.DataArray
+        data : Union[np.ndarray, da.core.Array] | xr.DataArray
            array object used to set variable data. If ``variable`` is a list
            then this is expected to have a trailing dimension with length
            equal to the length of the list.
@@ -212,7 +213,7 @@ def values(self, *args, **kwargs):
         ..., features)``"""
         return np.asarray(self.to_array(*args, **kwargs))

-    def to_dataarray(self) -> T_Array:
+    def to_dataarray(self) -> Union[np.ndarray, da.core.Array]:
         """Return xr.DataArray for the contained xr.Dataset."""
         if not self.features:
             coords = [self._ds[f] for f in Dimension.coords_2d()]
@@ -390,7 +391,7 @@ def interpolate_na(self, **kwargs):
         return type(self)(self._ds)

     @staticmethod
-    def _needs_fancy_indexing(keys) -> T_Array:
+    def _needs_fancy_indexing(keys) -> Union[np.ndarray, da.core.Array]:
         """We use `.vindex` if keys require fancy indexing."""
         where_list = [
             ind for ind in keys if isinstance(ind, np.ndarray) and ind.ndim > 0
@@ -444,7 +445,9 @@ def add_dims_to_data_vars(self, vals):
             new_vals[k] = v
         return new_vals

-    def assign(self, vals: Dict[str, Union[T_Array, tuple]]):
+    def assign(
+        self, vals: Dict[str, Union[Union[np.ndarray, da.core.Array], tuple]]
+    ):
         """Override xarray assign and assign_coords methods to enable update
         without explicitly providing dimensions if variable already exists.
@@ -506,7 +509,7 @@ def time_step(self):
         return float(mode(sec_diff, keepdims=False).mode)

     @property
-    def lat_lon(self) -> T_Array:
+    def lat_lon(self) -> Union[np.ndarray, da.core.Array]:
         """Base lat lon for contained data."""
         coords = [self._ds[d] for d in Dimension.coords_2d()]
         return self._stack_features(coords)
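
A side effect of the mechanical substitution shows up in `__getitem__` and `assign` above: `Union[T_Array, ...]` becomes the nested `Union[Union[np.ndarray, da.core.Array], ...]`. This is harmless because `typing` flattens nested unions; a quick check (illustrative, not repository code):

```python
from typing import Union

import dask.array as da
import numpy as np

nested = Union[Union[np.ndarray, da.core.Array], str]
flat = Union[np.ndarray, da.core.Array, str]
print(nested == flat)  # True: nested unions collapse to one flat union
```
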
6 changes: 3 additions & 3 deletions sup3r/preprocessing/batch_queues/base.py
@@ -40,7 +40,7 @@ def transform(

         Parameters
         ----------
-        samples : T_Array
+        samples : Union[np.ndarray, da.core.Array]
             High resolution batch of samples.
             4D | 5D array
             (batch_size, spatial_1, spatial_2, features)
@@ -60,11 +60,11 @@
         Returns
         -------
-        low_res : T_Array
+        low_res : Union[np.ndarray, da.core.Array]
             4D | 5D array
             (batch_size, spatial_1, spatial_2, features)
             (batch_size, spatial_1, spatial_2, temporal, features)
-        high_res : T_Array
+        high_res : Union[np.ndarray, da.core.Array]
             4D | 5D array
             (batch_size, spatial_1, spatial_2, features)
             (batch_size, spatial_1, spatial_2, temporal, features)
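
The `transform` docstring above describes producing a low-resolution batch from high-resolution samples. A toy version of spatial coarsening in that spirit (illustrative only; the repository's actual transform is not shown in this diff):

```python
import numpy as np

# 4D batch: (batch_size, spatial_1, spatial_2, features)
high_res = np.random.rand(8, 16, 16, 3)
# Average non-overlapping 2x2 spatial blocks to halve each spatial dim.
low_res = high_res.reshape(8, 8, 2, 8, 2, 3).mean(axis=(2, 4))
print(low_res.shape)  # (8, 8, 8, 3)
```
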
(Diffs for the remaining 15 changed files were not loaded.)
