Skip to content

Commit

Permalink
update jax array typing and modernize other type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Mar 27, 2024
1 parent 1f66340 commit 5a703d1
Show file tree
Hide file tree
Showing 18 changed files with 763 additions and 687 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "regmod"
version = "0.1.1"
version = "0.1.2"
description = "General regression models"
readme = "README.rst"
requires-python = ">=3.10"
Expand All @@ -18,8 +18,8 @@ dependencies = [
"pandas",
"xspline==0.0.7",
"msca",
"jax[cpu]==0.4.5",
"jaxlib==0.4.4",
"jax",
"jaxlib",
]

[project.optional-dependencies]
Expand Down
6 changes: 6 additions & 0 deletions src/regmod/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Any, Callable
from collections.abc import Iterable
from numpy.typing import ArrayLike, NDArray
from pandas import DataFrame
from jax import Array as JaxArray
from msca.linalg.matrix import Matrix
64 changes: 34 additions & 30 deletions src/regmod/data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""
Data Module
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import numpy as np
from numpy import ndarray
from pandas import DataFrame
from regmod._typing import NDArray, DataFrame


@dataclass
Expand All @@ -15,9 +14,9 @@ class Data:
Parameters
----------
col_obs : Optional[Union[str, List[str]]], optional
col_obs : str, optional
Column name(s) for observation. Default is `None`.
col_covs : List[str], optional
col_covs : list[str], optional
Column names for covariates. Default is an empty list.
col_weights : str, default="weights"
Column name for weights. Default is `'weights'`. If `col_weights` is
Expand All @@ -40,10 +39,10 @@ class Data:
offset
trim_weights
num_obs
col_obs : Optional[Union[str, List[str]]]
col_obs : str, optional
Column name for observation; can be a single string, a list of strings, or
`None`. When it is `None`, you cannot access the property `obs`.
col_covs : List[str]
col_covs : list[str]
A list of column names for covariates.
col_weights : str
Column name for weights. `weights` can be used in the likelihood
Expand All @@ -59,7 +58,7 @@ class Data:
`col_offset` will be added to the data frame filled with 0.
df : pd.DataFrame
Data frame for the object. Default is an empty data frame.
cols : List[str]
cols : list[str]
All the relevant columns, including, `col_obs` (if not `None`),
`col_covs`, `col_weights`, `col_offset` and `'trim_weights'`.
Expand Down Expand Up @@ -89,17 +88,18 @@ class Data:
-----
* This class should be replaced by a subclass of a more general dataclass
* `get_covs` seems redundant; only `get_cols` should be kept.
"""

col_obs: Optional[Union[str, List[str]]] = None
col_covs: List[str] = field(default_factory=list)
col_obs: str | None = None
col_covs: list[str] = field(default_factory=list)
col_weights: str = "weights"
col_offset: str = "offset"
df: DataFrame = field(default_factory=DataFrame)
subset_cols: bool = False

def __post_init__(self):
self.col_covs = list(set(self.col_covs).union({'intercept'}))
self.col_covs = list(set(self.col_covs).union({"intercept"}))
self.cols = self.col_covs + [self.col_weights, self.col_offset, "trim_weights"]
if self.col_obs is not None:
if isinstance(self.col_obs, str):
Expand All @@ -121,6 +121,7 @@ def is_empty(self) -> bool:
-------
bool
Return `True` when `self.df` is empty.
"""
return self.num_obs == 0

Expand All @@ -131,24 +132,26 @@ def check_cols(self) -> None:
------
ValueError
Raised if any col in `self.cols` is not in `self.df`.
"""
for col in self.cols:
if col not in self.df.columns:
raise ValueError(f"Missing columnn {col}.")

def parse_df(self, df: Optional[DataFrame] = None) -> DataFrame:
def parse_df(self, df: DataFrame | None = None) -> DataFrame:
"""Subset `df` with `self.cols`.
Parameters
----------
df : Optional[DataFrame], optional
df : DataFrame, optional
Data Frame used to create subset. When it is `None`, it will use
`self.df`.
Returns
-------
DataFrame
Copy of input data frame with given subset columns.
"""
if df is not None:
if self.subset_cols:
Expand All @@ -159,6 +162,7 @@ def parse_df(self, df: Optional[DataFrame] = None) -> DataFrame:
def fill_df(self) -> None:
"""Automatically add columns `'intercept'`, `col_weights`, `col_offset`
and `'trim_weights'` if they are not present in `self.df`.
"""
if "intercept" not in self.df.columns:
self.df["intercept"] = 1.0
Expand All @@ -184,6 +188,7 @@ def attach_df(self, df: DataFrame):
----------
df : DataFrame
Data frame to be attached.
"""
self.parse_df(df)
self.fill_df()
Expand All @@ -202,26 +207,24 @@ def copy(self, with_df=False) -> "Data":
-------
Data
Copied instance of `Data` class.
"""
df = self.df.copy() if with_df else DataFrame(columns=self.cols)
return Data(self.col_obs,
self.col_covs,
self.col_weights,
self.col_offset,
df)
return Data(self.col_obs, self.col_covs, self.col_weights, self.col_offset, df)

def get_cols(self, cols: Union[List[str], str]) -> ndarray:
def get_cols(self, cols: str | list[str]) -> NDArray:
"""Access columns in `self.df`.
Parameters
----------
cols : Union[List[str], str]
cols : str | list[str]
Column name(s) to be accessed.
Returns
-------
ndarray
NDArray
Numpy array corresponding to the column(s).
"""
return self.df[cols].to_numpy()

Expand All @@ -231,53 +234,54 @@ def num_obs(self) -> int:
return self.df.shape[0]

@property
def obs(self) -> ndarray:
def obs(self) -> NDArray:
"""Observation column(s)."""
if self.col_obs is None:
raise ValueError("This data object does not contain observations.")
return self.get_cols(self.col_obs)

@property
def covs(self) -> Dict[str, ndarray]:
def covs(self) -> dict[str, NDArray]:
"""Covariates dictionary with column names as keys and corresponding
column values as the values.
"""
return self.df[self.col_covs].to_dict(orient="list")

@property
def weights(self) -> ndarray:
def weights(self) -> NDArray:
"""Weights column."""
return self.get_cols(self.col_weights)

@property
def offset(self) -> ndarray:
def offset(self) -> NDArray:
"""Offset column."""
return self.get_cols(self.col_offset)

@property
def trim_weights(self) -> ndarray:
def trim_weights(self) -> NDArray:
"""Trimming weights column."""
return self.get_cols("trim_weights")

@trim_weights.setter
def trim_weights(self, weights: Union[float, ndarray]):
def trim_weights(self, weights: float | NDArray):
if np.any(weights < 0.0) or np.any(weights > 1.0):
raise ValueError("trim_weights has to be between 0 and 1.")
self.df["trim_weights"] = weights

def get_covs(self, col_covs: Union[List[str], str]) -> ndarray:
def get_covs(self, col_covs: str | list[str]) -> NDArray:
"""Access covariates in `self.df`.
Parameters
----------
col_covs : Union[List[str], str]
col_covs : str | list[str]
Column name(s) of the covariates.
Returns
-------
ndarray
NDArray
Return the corresponding column(s) in the data frame. Always returns a
matrix, even if `col_covs` is a single string.
"""
if not isinstance(col_covs, list):
col_covs = [col_covs]
Expand Down
82 changes: 41 additions & 41 deletions src/regmod/function.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""
Function module
"""

from dataclasses import dataclass, field
from typing import Callable

import numpy as np
from regmod._typing import Callable


@dataclass
Expand Down Expand Up @@ -60,8 +61,8 @@ def exp_d2fun(x):

def expit_fun(x):
neg_indices = x < 0
z = np.exp(-np.sqrt(x*x))
y = 1/(1 + z)
z = np.exp(-np.sqrt(x * x))
y = 1 / (1 + z)
if np.isscalar(x):
if neg_indices:
y = 1 - y
Expand All @@ -71,15 +72,15 @@ def expit_fun(x):


def expit_dfun(x):
z = np.exp(-np.sqrt(x*x))
y = z/(1 + z)**2
z = np.exp(-np.sqrt(x * x))
y = z / (1 + z) ** 2
return y


def expit_d2fun(x):
neg_indices = x < 0
z = np.exp(-np.sqrt(x*x))
y = z*(z - 1)/(z + 1)**3
z = np.exp(-np.sqrt(x * x))
y = z * (z - 1) / (z + 1) ** 3
if np.isscalar(x):
if neg_indices:
y = -y
Expand All @@ -93,55 +94,54 @@ def log_fun(x):


def log_dfun(x):
return 1/x
return 1 / x


def log_d2fun(x):
return -1/x**2
return -1 / x**2


def logit_fun(x):
return np.log(x/(1 - x))
return np.log(x / (1 - x))


def logit_dfun(x):
return 1/(x*(1 - x))
return 1 / (x * (1 - x))


def logit_d2fun(x):
return (2*x - 1)/(x*(1 - x))**2
return (2 * x - 1) / (x * (1 - x)) ** 2


fun_list = [
SmoothFunction(name="identity",
fun=identity_fun,
inv_fun=identity_fun,
dfun=identity_dfun,
d2fun=identity_d2fun),
SmoothFunction(name="exp",
fun=exp_fun,
inv_fun=log_fun,
dfun=exp_dfun,
d2fun=exp_d2fun),
SmoothFunction(name="expit",
fun=expit_fun,
inv_fun=logit_fun,
dfun=expit_dfun,
d2fun=expit_d2fun),
SmoothFunction(name="log",
fun=log_fun,
inv_fun=exp_fun,
dfun=log_dfun,
d2fun=log_d2fun),
SmoothFunction(name="logit",
fun=logit_fun,
inv_fun=expit_fun,
dfun=logit_dfun,
d2fun=logit_d2fun),
SmoothFunction(
name="identity",
fun=identity_fun,
inv_fun=identity_fun,
dfun=identity_dfun,
d2fun=identity_d2fun,
),
SmoothFunction(
name="exp", fun=exp_fun, inv_fun=log_fun, dfun=exp_dfun, d2fun=exp_d2fun
),
SmoothFunction(
name="expit",
fun=expit_fun,
inv_fun=logit_fun,
dfun=expit_dfun,
d2fun=expit_d2fun,
),
SmoothFunction(
name="log", fun=log_fun, inv_fun=exp_fun, dfun=log_dfun, d2fun=log_d2fun
),
SmoothFunction(
name="logit",
fun=logit_fun,
inv_fun=expit_fun,
dfun=logit_dfun,
d2fun=logit_d2fun,
),
]


fun_dict = {
fun.name: fun
for fun in fun_list
}
fun_dict = {fun.name: fun for fun in fun_list}
Loading

0 comments on commit 5a703d1

Please sign in to comment.