Description
I was trying to run the example notebook end_to_end_demo_with_multiple_geos as is, without any modification.
I run into a problem when executing the following cell:
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    seed=SEED)
It throws the following error:
TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'
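For context, this message is Python's standard error for calling a function whose parameters are declared positional-only (with the / marker, Python 3.8+) using keyword arguments. The minimal sketch below uses a hypothetical stand-in named where (not jax.numpy.where) that only mimics the parameter names from the error, and reproduces the same message. Since the media_transforms.py snippet calls jnp.where positionally, a plausible guess (an assumption, not a confirmed diagnosis) is that some layer in between passes condition, x and y by keyword to a where() that only accepts them positionally, for example because of mismatched installed library versions.

```python
# Hypothetical stand-in, NOT jax.numpy.where: its parameters are
# positional-only because of the "/" marker (Python 3.8+).
def where(condition, x, y, /):
    return x if condition else y


def call_with_keywords():
    """Calls the stand-in with keyword arguments and returns the error text."""
    try:
        where(condition=True, x=1, y=0)
    except TypeError as err:
        return str(err)
    return None


print(where(True, 1, 0))  # a positional call works and prints 1
print(call_with_keywords())
# where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'
```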
The traceback points to the following function in media_transforms.py:
@jax.jit
def apply_exponent_safe(
    data: jnp.ndarray,
    exponent: jnp.ndarray,
    ) -> jnp.ndarray:
  """Applies an exponent to given data in a gradient safe way.

  More info on the double jnp.where can be found:
  https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

  Args:
    data: Input data to use.
    exponent: Exponent required for the operations.

  Returns:
    The result of the exponent operation with the inputs provided.
  """
  exponent_safe = jnp.where(data == 0, 1, data) ** exponent
  return jnp.where(data == 0, 0, exponent_safe)
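The docstring only links to the explanation of the double jnp.where. As a sketch, assuming a recent JAX install, the example below contrasts a hypothetical single-where version (apply_exponent_naive, not part of the library) with the double-where pattern from the snippet, differentiating with respect to data at a zero entry with exponent 0.5. With a single where, the derivative of data ** exponent is still evaluated at 0 (where it is infinite) and then multiplied by a zero cotangent, which typically yields nan; the inner where avoids ever evaluating that derivative at 0.

```python
import jax
import jax.numpy as jnp


def apply_exponent_naive(data, exponent):
    # Hypothetical single-where version: the forward value is correct, but the
    # derivative of data ** exponent is still computed at data == 0.
    return jnp.where(data == 0, 0.0, data ** exponent)


@jax.jit
def apply_exponent_safe(data, exponent):
    # Inner where: replace zeros with 1 so the power and its derivative are
    # only evaluated at a safe point.
    exponent_safe = jnp.where(data == 0, 1.0, data) ** exponent
    # Outer where: restore the correct value (0) for the zero entries.
    return jnp.where(data == 0, 0.0, exponent_safe)


data = jnp.array([0.0, 4.0])
exponent = jnp.array(0.5)

naive_grad = jax.grad(lambda d: apply_exponent_naive(d, exponent).sum())(data)
safe_grad = jax.grad(lambda d: apply_exponent_safe(d, exponent).sum())(data)

print(apply_exponent_safe(data, exponent))  # approximately [0, 2]
print(naive_grad)  # the entry at data == 0 is typically nan
print(safe_grad)   # approximately [0, 0.25]
```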
I would appreciate any help troubleshooting this. Has anyone run into this issue before?