diff --git a/imagen_pytorch/trainer.py b/imagen_pytorch/trainer.py index f74487a..c0ace12 100644 --- a/imagen_pytorch/trainer.py +++ b/imagen_pytorch/trainer.py @@ -178,7 +178,7 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs): split_kwargs_index = len_all_args - dict_len split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args] - chunk_sizes = tuple(map(len, split_all_args[0])) + chunk_sizes = tuple(map(len, first_tensor.split(split_size, dim = 0))) for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)): chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:] diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index 74de0ac..c22f0b5 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.18.7' +__version__ = '1.18.9'