diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index 1ac6e17..b8248c5 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -1153,10 +1153,14 @@ def forward(self, texts: List[str]) -> Tensor: return_tensors="pt", ) + device = next(self.transformer.parameters()).device + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device) + self.transformer.eval() embedding = self.transformer( - input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"] + input_ids=input_ids, attention_mask=attention_mask )["last_hidden_state"] return embedding diff --git a/setup.py b/setup.py index 20608e1..18d770d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.63", + version="0.0.64", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",