From bd303fb79a9991c92d38b94f3ca9f195fd3b60a0 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 12 Jan 2023 16:21:02 -0800 Subject: [PATCH] fix for auto splitting of args and kwargs into imagen, for https://github.com/lucidrains/imagen-pytorch/issues/301 --- imagen_pytorch/trainer.py | 2 +- imagen_pytorch/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/imagen_pytorch/trainer.py b/imagen_pytorch/trainer.py index f74487a..c0ace12 100644 --- a/imagen_pytorch/trainer.py +++ b/imagen_pytorch/trainer.py @@ -178,7 +178,7 @@ def split_args_and_kwargs(*args, split_size = None, **kwargs): split_kwargs_index = len_all_args - dict_len split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args] - chunk_sizes = tuple(map(len, split_all_args[0])) + chunk_sizes = tuple(map(len, first_tensor.split(split_size, dim = 0))) for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)): chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:] diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index 74de0ac..c22f0b5 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.18.7' +__version__ = '1.18.9'