Skip to content

Commit

Permalink
fix for auto splitting of args and kwargs into imagen, for #301
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 13, 2023
1 parent 8e0e97c commit bd303fb
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion imagen_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs):
split_kwargs_index = len_all_args - dict_len

split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
chunk_sizes = tuple(map(len, split_all_args[0]))
chunk_sizes = tuple(map(len, first_tensor.split(split_size, dim = 0)))

for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.18.7'
__version__ = '1.18.9'

0 comments on commit bd303fb

Please sign in to comment.