From 7fabeb387336d5a018bc218d5d83e3a4d60585b5 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Thu, 13 Oct 2022 17:03:57 +0200 Subject: [PATCH] fix: t5 embedder device --- audio_diffusion_pytorch/modules.py | 6 +++++- setup.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index 1ac6e17..b8248c5 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -1153,10 +1153,14 @@ def forward(self, texts: List[str]) -> Tensor: return_tensors="pt", ) + device = next(self.transformer.parameters()).device + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device) + self.transformer.eval() embedding = self.transformer( - input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"] + input_ids=input_ids, attention_mask=attention_mask )["last_hidden_state"] return embedding diff --git a/setup.py b/setup.py index 20608e1..18d770d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.63", + version="0.0.64", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",