From d9eef74de0bc89eb2606ad7ff14ad87a46cbacfc Mon Sep 17 00:00:00 2001 From: Eric Leung Date: Thu, 14 Dec 2023 11:42:08 -0500 Subject: [PATCH] Align casing of NumPyro with canonical spelling --- README.md | 2 +- docs/index.rst | 2 +- lightweight_mmm/core/transformations/lagging.py | 4 ++-- .../core/transformations/saturation.py | 4 ++-- lightweight_mmm/lightweight_mmm.py | 16 ++++++++-------- lightweight_mmm/lightweight_mmm_test.py | 2 +- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 49d5e7d..c8f78d4 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Taking a Bayesian approach to MMM allows an advertiser to integrate prior inform - Report on both parameter and model uncertainty and propagate it to your budget optimisation. - Construct hierarchical models, with generally tighter credible intervals, using breakout dimensions such as geography. -The LightweightMMM package (built using [Numpyro](https://github.com/pyro-ppl/numpyro) and [JAX](https://github.com/google/jax)) helps advertisers easily build Bayesian MMM models by providing the functionality to appropriately scale data, evaluate models, optimise budget allocations and plot common graphs used in the field. +The LightweightMMM package (built using [NumPyro](https://github.com/pyro-ppl/numpyro) and [JAX](https://github.com/google/jax)) helps advertisers easily build Bayesian MMM models by providing the functionality to appropriately scale data, evaluate models, optimise budget allocations and plot common graphs used in the field. ## Theory diff --git a/docs/index.rst b/docs/index.rst index b74c4ff..64c8d2c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,7 +8,7 @@ library that allows users to easily train MMMs and obtain channel attribution information. The library also includes capabilities for optimizing media allocation as well as plotting common graphs in the field. -It is built in python3 and makes use of Numpyro and JAX. +It is built in python3 and makes use of NumPyro and JAX. Installation ------------ diff --git a/lightweight_mmm/core/transformations/lagging.py b/lightweight_mmm/core/transformations/lagging.py index 4739ae7..5902243 100644 --- a/lightweight_mmm/core/transformations/lagging.py +++ b/lightweight_mmm/core/transformations/lagging.py @@ -95,7 +95,7 @@ def carryover( custom_priors: The custom priors we want the model to take instead of the default ones. number_lags: Number of lags for the carryover function. - prefix: Prefix to use in the variable name for Numpyro. + prefix: Prefix to use in the variable name for NumPyro. Returns: The transformed media data. @@ -178,7 +178,7 @@ def adstock( default ones. The possible names of parameters for adstock and exponent are "lag_weight" and "exponent". normalise: Whether to normalise the output values. - prefix: Prefix to use in the variable name for Numpyro. + prefix: Prefix to use in the variable name for NumPyro. Returns: The transformed media data. diff --git a/lightweight_mmm/core/transformations/saturation.py b/lightweight_mmm/core/transformations/saturation.py index 8a6f4df..d7d624a 100644 --- a/lightweight_mmm/core/transformations/saturation.py +++ b/lightweight_mmm/core/transformations/saturation.py @@ -62,7 +62,7 @@ def hill( custom_priors: The custom priors we want the model to take instead of the default ones. The possible names of parameters for hill_adstock and exponent are "lag_weight", "half_max_effective_concentration" and "slope". - prefix: Prefix to use in the variable name for Numpyro. + prefix: Prefix to use in the variable name for NumPyro. Returns: The transformed media data. @@ -112,7 +112,7 @@ def exponent( national models and 3 for geo models. custom_priors: The custom priors we want the model to take instead of the default ones. - prefix: Prefix to use in the variable name for Numpyro. + prefix: Prefix to use in the variable name for NumPyro. Returns: The transformed media data. diff --git a/lightweight_mmm/lightweight_mmm.py b/lightweight_mmm/lightweight_mmm.py index b22d2cc..3a52866 100644 --- a/lightweight_mmm/lightweight_mmm.py +++ b/lightweight_mmm/lightweight_mmm.py @@ -214,9 +214,9 @@ def _create_list_of_attributes_to_compare( def _preprocess_custom_priors( self, custom_priors: Dict[str, Prior]) -> MutableMapping[str, Prior]: - """Preprocesses the user input custom priors to Numpyro distributions. + """Preprocesses the user input custom priors to NumPyro distributions. - If numpyro distributions are given they remains untouched, however if any + If NumPyro distributions are given they remains untouched, however if any other option is passed, it is passed to the default distribution to alter its constructor values. @@ -224,7 +224,7 @@ def _preprocess_custom_priors( custom_priors: Mapping of the name of the prior to its custom value. Returns: - A mapping of names to numpyro distributions based on user input and + A mapping of names to NumPyro distributions based on user input and default values. """ default_priors = { @@ -246,8 +246,8 @@ def _preprocess_custom_priors( **custom_priors[prior_name]) elif not isinstance(custom_priors[prior_name], dist.Distribution): raise ValueError( - "Priors given must be a Numpyro distribution or one of the " - "following to fit in the constructor of our default Numpyro " + "Priors given must be a NumPyro distribution or one of the " + "following to fit in the constructor of our default NumPyro " "distribution. It could be given as args or kwargs as long as it " "is the correct format for such object. Please refer to our " "documentation on custom priors to know more.") @@ -295,7 +295,7 @@ def fit( number_chains: Number of chains to sample. Default is 2. target_accept_prob: Target acceptance probability for step size in the NUTS sampler. Default is .85. - init_strategy: Initialization function for numpyro NUTS. The available + init_strategy: Initialization function for NumPyro NUTS. The available options can be found in https://num.pyro.ai/en/stable/utilities.html#initialization-strategies. Default is numpyro.infer.init_to_median. @@ -395,7 +395,7 @@ def fit( logging.info("Model has been fitted") def print_summary(self) -> None: - """Calls print_summary function from numpyro to print parameters summary. + """Calls print_summary function from NumPyro to print parameters summary. """ # TODO(): add name selection for print. self._mcmc.print_summary() @@ -431,7 +431,7 @@ def _predict( frequency: Frequency of the seasonality. transform_function: Media transform function to use within the model. weekday_seasonality: Allow daily weekday estimation. - model: Numpyro model to use for numpyro.infer.Predictive. + model: NumPyro model to use for numpyro.infer.Predictive. posterior_samples: Mapping of the posterior samples. custom_priors: The custom priors we want the model to take instead of the default ones. Refer to the full documentation on custom priors for diff --git a/lightweight_mmm/lightweight_mmm_test.py b/lightweight_mmm/lightweight_mmm_test.py index 8fba670..54b6e9c 100644 --- a/lightweight_mmm/lightweight_mmm_test.py +++ b/lightweight_mmm/lightweight_mmm_test.py @@ -210,7 +210,7 @@ def test_fit_with_custom_prior_raises_valueerror_if_wrong_format( with self.assertRaisesRegex( ValueError, - "Priors given must be a Numpyro distribution or one of the "): + "Priors given must be a NumPyro distribution or one of the "): mmm_object.fit( media=media, target=target,