Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align casing of NumPyro with canonical spelling #291

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Taking a Bayesian approach to MMM allows an advertiser to integrate prior inform
- Report on both parameter and model uncertainty and propagate it to your budget optimisation.
- Construct hierarchical models, with generally tighter credible intervals, using breakout dimensions such as geography.

The LightweightMMM package (built using [Numpyro](https://github.com/pyro-ppl/numpyro) and [JAX](https://github.com/google/jax)) helps advertisers easily build Bayesian MMM models by providing the functionality to appropriately scale data, evaluate models, optimise budget allocations and plot common graphs used in the field.
The LightweightMMM package (built using [NumPyro](https://github.com/pyro-ppl/numpyro) and [JAX](https://github.com/google/jax)) helps advertisers easily build Bayesian MMM models by providing the functionality to appropriately scale data, evaluate models, optimise budget allocations and plot common graphs used in the field.

## Theory

Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ library that allows users to easily train MMMs and obtain channel attribution
information. The library also includes capabilities for optimizing media
allocation as well as plotting common graphs in the field.

It is built in python3 and makes use of Numpyro and JAX.
It is built in python3 and makes use of NumPyro and JAX.

Installation
------------
Expand Down
4 changes: 2 additions & 2 deletions lightweight_mmm/core/transformations/lagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def carryover(
custom_priors: The custom priors we want the model to take instead of the
default ones.
number_lags: Number of lags for the carryover function.
prefix: Prefix to use in the variable name for Numpyro.
prefix: Prefix to use in the variable name for NumPyro.

Returns:
The transformed media data.
Expand Down Expand Up @@ -178,7 +178,7 @@ def adstock(
default ones. The possible names of parameters for adstock and exponent
are "lag_weight" and "exponent".
normalise: Whether to normalise the output values.
prefix: Prefix to use in the variable name for Numpyro.
prefix: Prefix to use in the variable name for NumPyro.

Returns:
The transformed media data.
Expand Down
4 changes: 2 additions & 2 deletions lightweight_mmm/core/transformations/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def hill(
custom_priors: The custom priors we want the model to take instead of the
default ones. The possible names of parameters for hill_adstock and
exponent are "lag_weight", "half_max_effective_concentration" and "slope".
prefix: Prefix to use in the variable name for Numpyro.
prefix: Prefix to use in the variable name for NumPyro.

Returns:
The transformed media data.
Expand Down Expand Up @@ -112,7 +112,7 @@ def exponent(
national models and 3 for geo models.
custom_priors: The custom priors we want the model to take instead of the
default ones.
prefix: Prefix to use in the variable name for Numpyro.
prefix: Prefix to use in the variable name for NumPyro.

Returns:
The transformed media data.
Expand Down
16 changes: 8 additions & 8 deletions lightweight_mmm/lightweight_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,17 +214,17 @@ def _create_list_of_attributes_to_compare(
def _preprocess_custom_priors(
self,
custom_priors: Dict[str, Prior]) -> MutableMapping[str, Prior]:
"""Preprocesses the user input custom priors to Numpyro distributions.
"""Preprocesses the user input custom priors to NumPyro distributions.

If numpyro distributions are given they remains untouched, however if any
If NumPyro distributions are given they remains untouched, however if any
other option is passed, it is passed to the default distribution to alter
its constructor values.

Args:
custom_priors: Mapping of the name of the prior to its custom value.

Returns:
A mapping of names to numpyro distributions based on user input and
A mapping of names to NumPyro distributions based on user input and
default values.
"""
default_priors = {
Expand All @@ -246,8 +246,8 @@ def _preprocess_custom_priors(
**custom_priors[prior_name])
elif not isinstance(custom_priors[prior_name], dist.Distribution):
raise ValueError(
"Priors given must be a Numpyro distribution or one of the "
"following to fit in the constructor of our default Numpyro "
"Priors given must be a NumPyro distribution or one of the "
"following to fit in the constructor of our default NumPyro "
"distribution. It could be given as args or kwargs as long as it "
"is the correct format for such object. Please refer to our "
"documentation on custom priors to know more.")
Expand Down Expand Up @@ -295,7 +295,7 @@ def fit(
number_chains: Number of chains to sample. Default is 2.
target_accept_prob: Target acceptance probability for step size in the
NUTS sampler. Default is .85.
init_strategy: Initialization function for numpyro NUTS. The available
init_strategy: Initialization function for NumPyro NUTS. The available
options can be found in
https://num.pyro.ai/en/stable/utilities.html#initialization-strategies.
Default is numpyro.infer.init_to_median.
Expand Down Expand Up @@ -395,7 +395,7 @@ def fit(
logging.info("Model has been fitted")

def print_summary(self) -> None:
"""Calls print_summary function from numpyro to print parameters summary.
"""Calls print_summary function from NumPyro to print parameters summary.
"""
# TODO(): add name selection for print.
self._mcmc.print_summary()
Expand Down Expand Up @@ -431,7 +431,7 @@ def _predict(
frequency: Frequency of the seasonality.
transform_function: Media transform function to use within the model.
weekday_seasonality: Allow daily weekday estimation.
model: Numpyro model to use for numpyro.infer.Predictive.
model: NumPyro model to use for numpyro.infer.Predictive.
posterior_samples: Mapping of the posterior samples.
custom_priors: The custom priors we want the model to take instead of the
default ones. Refer to the full documentation on custom priors for
Expand Down
2 changes: 1 addition & 1 deletion lightweight_mmm/lightweight_mmm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_fit_with_custom_prior_raises_valueerror_if_wrong_format(

with self.assertRaisesRegex(
ValueError,
"Priors given must be a Numpyro distribution or one of the "):
"Priors given must be a NumPyro distribution or one of the "):
mmm_object.fit(
media=media,
target=target,
Expand Down