I was just trying to run the example notebook `end_to_end_demo_with_multiple_geos` as is, without any modification.
When I run the following cell:

```python
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    seed=SEED)
```
It throws the following error:

```
TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'
```
The traceback points to the following function in `media_transforms.py`:
```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_exponent_safe(
    data: jnp.ndarray,
    exponent: jnp.ndarray,
) -> jnp.ndarray:
  """Applies an exponent to given data in a gradient safe way.

  More info on the double jnp.where can be found:
  https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

  Args:
    data: Input data to use.
    exponent: Exponent required for the operations.

  Returns:
    The result of the exponent operation with the inputs provided.
  """
  exponent_safe = jnp.where(data == 0, 1, data) ** exponent
  return jnp.where(data == 0, 0, exponent_safe)
```
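For context on what this function is for: the double `jnp.where` keeps the gradient finite where `data == 0`. Below is a minimal sketch of that idea, assuming only a recent JAX install. `naive_power` and `single_where_power` are hypothetical helpers written just for comparison and are not part of lightweight_mmm. The sketch differentiates all three versions at zero with an exponent of 0.5.

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_exponent_safe(data, exponent):
    # Inner where: swap zeros for 1 so the power and its derivative stay finite.
    exponent_safe = jnp.where(data == 0, 1, data) ** exponent
    # Outer where: restore the exact 0 output for the masked entries.
    return jnp.where(data == 0, 0, exponent_safe)


def naive_power(data, exponent):
    # Hypothetical comparison: derivative of data ** 0.5 is infinite at 0.
    return data ** exponent


def single_where_power(data, exponent):
    # Hypothetical comparison: masks the forward value only. The masked branch
    # still has an infinite derivative at 0, and 0 * inf in the backward pass is NaN.
    return jnp.where(data == 0, 0, data ** exponent)


exponent = jnp.array(0.5)
x0 = jnp.array(0.0)

for name, fn in [
    ("naive", naive_power),
    ("single where", single_where_power),
    ("double where", apply_exponent_safe),
]:
    grad_at_zero = jax.grad(fn)(x0, exponent)
    print(f"{name:>12}: value={fn(x0, exponent)}, grad at 0={grad_at_zero}")
```

The naive and single-`where` versions both return a non-finite gradient at 0. The double-`where` version returns 0, because the power is only ever evaluated at a base of 1 for masked entries.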
I would appreciate any help troubleshooting this. Has anyone run into this issue before?
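Python produces exactly this error wording when parameters declared positional-only are passed by name. The `apply_exponent_safe` shown above calls `jnp.where` positionally, so it may be worth checking which calling style the installed JAX accepts and which package versions are present. Here is a minimal sketch for that check. The `check_where_calls` helper is hypothetical and exists only for this diagnosis.

```python
from importlib import metadata

import jax.numpy as jnp


def check_where_calls():
    """Hypothetical helper: reports which jnp.where calling styles work here."""
    data = jnp.array([0.0, 2.0])
    results = {}
    try:
        jnp.where(data == 0, 1, data)  # positional, as in apply_exponent_safe
        results["positional"] = "ok"
    except TypeError as err:
        results["positional"] = f"TypeError: {err}"
    try:
        # Keyword form, matching the argument names listed in the error message.
        jnp.where(condition=data == 0, x=1, y=data)
        results["keyword"] = "ok"
    except TypeError as err:
        results["keyword"] = f"TypeError: {err}"
    return results


# Print installed versions of the packages involved, if present.
for pkg in ("jax", "jaxlib", "numpyro", "lightweight_mmm"):
    try:
        print(f"{pkg}: {metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")

for style, outcome in check_where_calls().items():
    print(f"{style:>10}: {outcome}")
```

Suppose the positional call works but the keyword call reproduces the same `TypeError`. That would suggest something else in the installed stack is calling `jnp.where` by keyword, although this is only a guess at the cause.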