Skip to content

Commit

Permalink
fix: adapt to new jax API
Browse files Browse the repository at this point in the history
  • Loading branch information
GaetanLepage committed Nov 9, 2023
1 parent f0be4de commit 22ffccf
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/dalle_mini/model/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,8 +1599,8 @@ def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jnp.DeviceArray] = None,
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
attention_mask: Optional[jax.Array] = None,
decoder_attention_mask: Optional[jax.Array] = None,
encoder_outputs=None,
**kwargs,
):
Expand Down
2 changes: 1 addition & 1 deletion src/dalle_mini/model/partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental import PartitionSpec as P
from jax.sharding import PartitionSpec as P

# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
# Sentinels
Expand Down

0 comments on commit 22ffccf

Please sign in to comment.