
Encountering a TypeError while fitting the LightweightMMM model #345

Open
@IG2804


I was trying to run the example notebook `end_to_end_demo_with_multiple_geos` as is, without any modifications. The failure occurs when I run the following cell:

```python
mmm.fit(
    media=media_data_train,
    media_prior=costs,
    target=target_train,
    extra_features=extra_features_train,
    number_warmup=number_warmup,
    number_samples=number_samples,
    seed=SEED)
```

It raises the following error:

```
TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'
```

The traceback points to the following function in `media_transforms.py`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_exponent_safe(
    data: jnp.ndarray,
    exponent: jnp.ndarray,
) -> jnp.ndarray:
  """Applies an exponent to given data in a gradient safe way.

  More info on the double jnp.where can be found:
  https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

  Args:
    data: Input data to use.
    exponent: Exponent required for the operations.

  Returns:
    The result of the exponent operation with the inputs provided.
  """
  # Swap zeros for 1 before exponentiating so the gradient stays finite.
  exponent_safe = jnp.where(data == 0, 1, data) ** exponent
  # Put the zeros back in the output.
  return jnp.where(data == 0, 0, exponent_safe)
```
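For context, the message above is the generic error Python raises when parameters declared positional-only (before a `/` in the signature) are passed by keyword. Recent JAX releases appear to declare `jnp.where` with `condition, x, y` as positional-only; that is an assumption worth checking against the docs for the installed JAX version. Since `apply_exponent_safe` already passes its arguments positionally, the keyword call may come from another frame in the traceback. The sketch below reproduces the message with a hypothetical pure-Python stand-in named `where` (not JAX's function):

```python
def where(condition, x=None, y=None, /):
    """Hypothetical stand-in for jnp.where: everything before '/' is positional-only."""
    return x if condition else y


# Positional call: accepted.
positional_result = where(True, 1, 0)

# Keyword call: rejected with the same kind of TypeError reported in this issue.
try:
    where(condition=True, x=1, y=0)
    keyword_error = None
except TypeError as err:
    keyword_error = str(err)

print(positional_result)
print(keyword_error)
```

If the keyword call turns out to live in a dependency, a mismatch between the installed JAX version and the version that dependency was written against would be a plausible place to look, though that is a guess rather than a confirmed fix.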

I would appreciate any help troubleshooting or debugging this. Has anyone run into this issue before?
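When triaging a suspected version mismatch like this one, it can help to include the installed package versions. A minimal sketch using the standard library's `importlib.metadata`; the distribution names listed (`jax`, `jaxlib`, `numpyro`, `lightweight_mmm`) are assumptions about what is installed, and missing packages are reported as `None` rather than raising:

```python
from importlib import metadata

# Assumed distribution names; adjust to match your environment.
PACKAGES = ("jax", "jaxlib", "numpyro", "lightweight_mmm")


def installed_versions(names=PACKAGES):
    """Return {name: version string, or None if the package is not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```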
