Skip to content

Commit

Permalink
remove all Union type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jun 25, 2024
1 parent 07d9cc1 commit 8dccc93
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 36 deletions.
46 changes: 22 additions & 24 deletions src/mrtool/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import warnings
from dataclasses import dataclass, field
from typing import Any, Union
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -208,9 +208,7 @@ def _assert_not_empty(self):
if self.is_empty():
raise ValueError("MRData object is empty.")

def is_cov_normalized(
self, covs: Union[list[str], str, None] = None
) -> bool:
def is_cov_normalized(self, covs: list[str] | str | None = None) -> bool:
"""Return true when covariates are normalized."""
if covs is None:
covs = list(self.covs.keys())
Expand All @@ -237,11 +235,11 @@ def reset(self):
def load_df(
self,
data: pd.DataFrame,
col_obs: Union[str, None] = None,
col_obs_se: Union[str, None] = None,
col_covs: Union[list[str], None] = None,
col_study_id: Union[str, None] = None,
col_data_id: Union[str, None] = None,
col_obs: str | None = None,
col_obs_se: str | None = None,
col_covs: list[str] | None = None,
col_study_id: str | None = None,
col_data_id: str | None = None,
):
"""Load data from data frame."""
self.reset()
Expand Down Expand Up @@ -273,10 +271,10 @@ def load_df(
def load_xr(
self,
data,
var_obs: Union[str, None] = None,
var_obs_se: Union[str, None] = None,
var_covs: Union[list[str], None] = None,
coord_study_id: Union[str, None] = None,
var_obs: str | None = None,
var_obs_se: str | None = None,
var_covs: list[str] | None = None,
coord_study_id: str | None = None,
):
"""Load data from xarray."""
self.reset()
Expand Down Expand Up @@ -314,11 +312,11 @@ def to_df(self) -> pd.DataFrame:

return df

def has_covs(self, covs: Union[list[str], str]) -> bool:
def has_covs(self, covs: list[str] | str) -> bool:
"""If the data has the provided covariates.
Args:
covs (Union[list[str], str]):
covs (list[str] | str):
list of covariate names or one covariate name.
Returns:
Expand All @@ -330,11 +328,11 @@ def has_covs(self, covs: Union[list[str], str]) -> bool:
else:
return all([cov in self.covs for cov in covs])

def has_studies(self, studies: Union[list[Any], Any]) -> bool:
def has_studies(self, studies: list[Any] | Any) -> bool:
"""If the data has provided study_id
Args:
studies Union[list[Any], Any]:
studies list[Any] | Any:
list of studies or one study.
Returns:
Expand All @@ -346,7 +344,7 @@ def has_studies(self, studies: Union[list[Any], Any]) -> bool:
else:
return all([study in self.studies for study in studies])

def _assert_has_covs(self, covs: Union[list[str], str]):
def _assert_has_covs(self, covs: list[str] | str):
"""Assert has covariates otherwise raise ValueError."""
if not self.has_covs(covs):
covs = to_list(covs)
Expand All @@ -355,7 +353,7 @@ def _assert_has_covs(self, covs: Union[list[str], str]):
f"MRData object do not contain covariates: {missing_covs}."
)

def _assert_has_studies(self, studies: Union[list[Any], Any]):
def _assert_has_studies(self, studies: list[Any] | Any):
"""Assert has studies otherwise raise ValueError."""
if not self.has_studies(studies):
studies = to_list(studies)
Expand All @@ -366,11 +364,11 @@ def _assert_has_studies(self, studies: Union[list[Any], Any]):
f"MRData object do not contain studies: {missing_studies}."
)

def get_covs(self, covs: Union[list[str], str]) -> np.ndarray:
def get_covs(self, covs: list[str] | str) -> np.ndarray:
"""Get covariate matrix.
Args:
covs (Union[list[str], str]):
covs (list[str] | str):
list of covariate names or one covariate name.
Returns:
Expand All @@ -385,11 +383,11 @@ def get_covs(self, covs: Union[list[str], str]) -> np.ndarray:
[self.covs[cov_names][:, None] for cov_names in covs]
)

def get_study_data(self, studies: Union[list[Any], Any]) -> "MRData":
def get_study_data(self, studies: list[Any] | Any) -> "MRData":
"""Get study specific data.
Args:
studies (Union[list[Any], Any]): list of studies or one study.
studies (list[Any] | Any): list of studies or one study.
Returns
MRData: Data object contains the study specific data.
Expand All @@ -399,7 +397,7 @@ def get_study_data(self, studies: Union[list[Any], Any]) -> "MRData":
index = np.array([study in studies for study in self.study_id])
return self._get_data(index)

def normalize_covs(self, covs: Union[list[str], str, None] = None):
def normalize_covs(self, covs: list[str] | str | None = None):
"""Normalize covariates by the largest absolute value for each covariate."""
if covs is None:
covs = list(self.covs.keys())
Expand Down
5 changes: 2 additions & 3 deletions src/mrtool/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

from copy import deepcopy
from typing import Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -453,7 +452,7 @@ def __init__(
data: MRData,
ensemble_cov_model: CovModel,
ensemble_knots: NDArray,
cov_models: Union[list[CovModel], None] = None,
cov_models: list[CovModel] | None = None,
inlier_pct: float = 1.0,
):
"""Constructor of `MRBeRT`
Expand All @@ -462,7 +461,7 @@ def __init__(
data (MRData): Data for meta-regression.
ensemble_cov_model (CovModel):
Covariates model which will be used with ensemble.
cov_models (Union[list[CovModel], None], optional):
cov_models (list[CovModel] | None, optional):
Other covariate models, assume to be mutual exclusive with ensemble_cov_mdoel.
inlier_pct (float): A float number between 0 and 1 indicate the percentage of inliers.
"""
Expand Down
8 changes: 3 additions & 5 deletions src/mrtool/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
`utils` module of the `mrtool` package.
"""

from typing import Any, Union
from typing import Any

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -294,7 +294,7 @@ def avg_integral(mat, spline=None, use_spline_intercept=False):
def sample_knots(
num_knots: int,
knot_bounds: np.ndarray,
min_dist: Union[float, np.ndarray],
min_dist: float | np.ndarray,
num_samples: int = 1,
) -> np.ndarray:
"""Sample knot vectors given a set of rules.
Expand Down Expand Up @@ -367,9 +367,7 @@ def _check_knot_bounds(num_knots: int, knot_bounds: np.ndarray) -> np.ndarray:
return knot_bounds


def _check_min_dist(
num_knots: int, min_dist: Union[float, np.ndarray]
) -> np.ndarray:
def _check_min_dist(num_knots: int, min_dist: float | np.ndarray) -> np.ndarray:
"""Check knot min_dist."""
if np.isscalar(min_dist):
min_dist = np.tile(min_dist, num_knots + 1)
Expand Down
7 changes: 3 additions & 4 deletions src/mrtool/cov_selection/covfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import warnings
from copy import deepcopy
from typing import Union

import numpy as np

Expand All @@ -23,7 +22,7 @@ def __init__(
self,
data: MRData,
covs: list[str],
pre_selected_covs: Union[list[str], None] = None,
pre_selected_covs: list[str] | None = None,
normalized_covs: bool = True,
num_samples: int = 1000,
laplace_threshold: float = 1e-5,
Expand All @@ -34,7 +33,7 @@ def __init__(
beta_gprior: dict[str, np.ndarray] = None,
beta_gprior_std: float = 1.0,
bias_zero: bool = False,
use_re: Union[dict, None] = None,
use_re: dict | None = None,
):
"""Covariate Finder.
Expand All @@ -59,7 +58,7 @@ def __init__(
beta_gprior_std (float, optional): Loose beta Gaussian prior standard deviation. Default to 1.
bias_zero (bool, optional):
If `True`, fit when specify the Gaussian prior it will be mean zero. Default to `False`.
use_re (Union[dict, None], optional):
use_re (dict | None, optional):
A dictionary of use_re for each covariate. When `None` we have an uninformative prior
for the random effects variance. Default to `None`.
"""
Expand Down

0 comments on commit 8dccc93

Please sign in to comment.