From 8dccc930bc8c2649d730f5062ba0cfc4c44878e5 Mon Sep 17 00:00:00 2001 From: zhengp0 Date: Tue, 25 Jun 2024 09:02:43 -0700 Subject: [PATCH] remove all Union type hints --- src/mrtool/core/data.py | 46 +++++++++++++-------------- src/mrtool/core/model.py | 5 ++- src/mrtool/core/utils.py | 8 ++--- src/mrtool/cov_selection/covfinder.py | 7 ++-- 4 files changed, 30 insertions(+), 36 deletions(-) diff --git a/src/mrtool/core/data.py b/src/mrtool/core/data.py index 57468be..2f6ee65 100644 --- a/src/mrtool/core/data.py +++ b/src/mrtool/core/data.py @@ -8,7 +8,7 @@ import warnings from dataclasses import dataclass, field -from typing import Any, Union +from typing import Any import numpy as np import pandas as pd @@ -208,9 +208,7 @@ def _assert_not_empty(self): if self.is_empty(): raise ValueError("MRData object is empty.") - def is_cov_normalized( - self, covs: Union[list[str], str, None] = None - ) -> bool: + def is_cov_normalized(self, covs: list[str] | str | None = None) -> bool: """Return true when covariates are normalized.""" if covs is None: covs = list(self.covs.keys()) @@ -237,11 +235,11 @@ def reset(self): def load_df( self, data: pd.DataFrame, - col_obs: Union[str, None] = None, - col_obs_se: Union[str, None] = None, - col_covs: Union[list[str], None] = None, - col_study_id: Union[str, None] = None, - col_data_id: Union[str, None] = None, + col_obs: str | None = None, + col_obs_se: str | None = None, + col_covs: list[str] | None = None, + col_study_id: str | None = None, + col_data_id: str | None = None, ): """Load data from data frame.""" self.reset() @@ -273,10 +271,10 @@ def load_df( def load_xr( self, data, - var_obs: Union[str, None] = None, - var_obs_se: Union[str, None] = None, - var_covs: Union[list[str], None] = None, - coord_study_id: Union[str, None] = None, + var_obs: str | None = None, + var_obs_se: str | None = None, + var_covs: list[str] | None = None, + coord_study_id: str | None = None, ): """Load data from xarray.""" self.reset() @@ -314,11 +312,11 @@ def to_df(self) -> pd.DataFrame: return df - def has_covs(self, covs: Union[list[str], str]) -> bool: + def has_covs(self, covs: list[str] | str) -> bool: """If the data has the provided covariates. Args: - covs (Union[list[str], str]): + covs (list[str] | str): list of covariate names or one covariate name. Returns: @@ -330,11 +328,11 @@ def has_covs(self, covs: Union[list[str], str]) -> bool: else: return all([cov in self.covs for cov in covs]) - def has_studies(self, studies: Union[list[Any], Any]) -> bool: + def has_studies(self, studies: list[Any] | Any) -> bool: """If the data has provided study_id Args: - studies Union[list[Any], Any]: + studies list[Any] | Any: list of studies or one study. Returns: @@ -346,7 +344,7 @@ def has_studies(self, studies: Union[list[Any], Any]) -> bool: else: return all([study in self.studies for study in studies]) - def _assert_has_covs(self, covs: Union[list[str], str]): + def _assert_has_covs(self, covs: list[str] | str): """Assert has covariates otherwise raise ValueError.""" if not self.has_covs(covs): covs = to_list(covs) @@ -355,7 +353,7 @@ def _assert_has_covs(self, covs: Union[list[str], str]): f"MRData object do not contain covariates: {missing_covs}." ) - def _assert_has_studies(self, studies: Union[list[Any], Any]): + def _assert_has_studies(self, studies: list[Any] | Any): """Assert has studies otherwise raise ValueError.""" if not self.has_studies(studies): studies = to_list(studies) @@ -366,11 +364,11 @@ def _assert_has_studies(self, studies: Union[list[Any], Any]): f"MRData object do not contain studies: {missing_studies}." ) - def get_covs(self, covs: Union[list[str], str]) -> np.ndarray: + def get_covs(self, covs: list[str] | str) -> np.ndarray: """Get covariate matrix. Args: - covs (Union[list[str], str]): + covs (list[str] | str): list of covariate names or one covariate name. Returns: @@ -385,11 +383,11 @@ def get_covs(self, covs: Union[list[str], str]) -> np.ndarray: [self.covs[cov_names][:, None] for cov_names in covs] ) - def get_study_data(self, studies: Union[list[Any], Any]) -> "MRData": + def get_study_data(self, studies: list[Any] | Any) -> "MRData": """Get study specific data. Args: - studies (Union[list[Any], Any]): list of studies or one study. + studies (list[Any] | Any): list of studies or one study. Returns MRData: Data object contains the study specific data. @@ -399,7 +397,7 @@ def get_study_data(self, studies: Union[list[Any], Any]) -> "MRData": index = np.array([study in studies for study in self.study_id]) return self._get_data(index) - def normalize_covs(self, covs: Union[list[str], str, None] = None): + def normalize_covs(self, covs: list[str] | str | None = None): """Normalize covariates by the largest absolute value for each covariate.""" if covs is None: covs = list(self.covs.keys()) diff --git a/src/mrtool/core/model.py b/src/mrtool/core/model.py index c25eac5..4298b62 100644 --- a/src/mrtool/core/model.py +++ b/src/mrtool/core/model.py @@ -7,7 +7,6 @@ """ from copy import deepcopy -from typing import Union import numpy as np import pandas as pd @@ -453,7 +452,7 @@ def __init__( data: MRData, ensemble_cov_model: CovModel, ensemble_knots: NDArray, - cov_models: Union[list[CovModel], None] = None, + cov_models: list[CovModel] | None = None, inlier_pct: float = 1.0, ): """Constructor of `MRBeRT` @@ -462,7 +461,7 @@ def __init__( data (MRData): Data for meta-regression. ensemble_cov_model (CovModel): Covariates model which will be used with ensemble. - cov_models (Union[list[CovModel], None], optional): + cov_models (list[CovModel] | None, optional): Other covariate models, assume to be mutual exclusive with ensemble_cov_mdoel. inlier_pct (float): A float number between 0 and 1 indicate the percentage of inliers. """ diff --git a/src/mrtool/core/utils.py b/src/mrtool/core/utils.py index 859610d..d620f97 100644 --- a/src/mrtool/core/utils.py +++ b/src/mrtool/core/utils.py @@ -5,7 +5,7 @@ `utils` module of the `mrtool` package. """ -from typing import Any, Union +from typing import Any import numpy as np import pandas as pd @@ -294,7 +294,7 @@ def avg_integral(mat, spline=None, use_spline_intercept=False): def sample_knots( num_knots: int, knot_bounds: np.ndarray, - min_dist: Union[float, np.ndarray], + min_dist: float | np.ndarray, num_samples: int = 1, ) -> np.ndarray: """Sample knot vectors given a set of rules. @@ -367,9 +367,7 @@ def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray: return knot_bounds -def _check_min_dist( - num_knots: int, min_dist: Union[float, np.ndarray] -) -> np.ndarray: +def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray: """Check knot min_dist.""" if np.isscalar(min_dist): min_dist = np.tile(min_dist, num_knots + 1) diff --git a/src/mrtool/cov_selection/covfinder.py b/src/mrtool/cov_selection/covfinder.py index 04423e4..1125fc6 100644 --- a/src/mrtool/cov_selection/covfinder.py +++ b/src/mrtool/cov_selection/covfinder.py @@ -6,7 +6,6 @@ import warnings from copy import deepcopy -from typing import Union import numpy as np @@ -23,7 +22,7 @@ def __init__( self, data: MRData, covs: list[str], - pre_selected_covs: Union[list[str], None] = None, + pre_selected_covs: list[str] | None = None, normalized_covs: bool = True, num_samples: int = 1000, laplace_threshold: float = 1e-5, @@ -34,7 +33,7 @@ def __init__( beta_gprior: dict[str, np.ndarray] = None, beta_gprior_std: float = 1.0, bias_zero: bool = False, - use_re: Union[dict, None] = None, + use_re: dict | None = None, ): """Covariate Finder. @@ -59,7 +58,7 @@ def __init__( beta_gprior_std (float, optional): Loose beta Gaussian prior standard deviation. Default to 1. bias_zero (bool, optional): If `True`, fit when specify the Gaussian prior it will be mean zero. Default to `False`. - use_re (Union[dict, None], optional): + use_re (dict | None, optional): A dictionary of use_re for each covariate. When `None` we have an uninformative prior for the random effects variance. Default to `None`. """