Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the option to put everything on device for faster inference speed #628

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions tortoise/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class TextToSpeech:

def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
enable_redaction=True, kv_cache=False, use_deepspeed=False, half=False, device=None,
tokenizer_vocab_file=None, tokenizer_basic=False):
tokenizer_vocab_file=None, tokenizer_basic=False, device_only=False):

"""
Constructor
Expand Down Expand Up @@ -263,16 +263,27 @@ def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR,
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
self.vocoder.eval(inference=True)

self.device_only = device_only

if self.device_only:
self.autoregressive = self.autoregressive.to(self.device)
self.diffusion = self.diffusion.to(self.device)
self.clvp = self.clvp.to(self.device)
self.vocoder = self.vocoder.to(self.device)

# Random latent generators (RLGs) are loaded lazily.
self.rlg_auto = None
self.rlg_diffusion = None

@contextmanager
def temporary_cuda(self, model):
    """
    Context manager that makes `model` available on `self.device` for the
    duration of a `with` block.

    :param model: Object exposing `.to(device)` and `.cpu()`.
    :return: Yields the model to use inside the `with` block.

    When `self.device_only` is true the model is yielded untouched and is
    never moved; the constructor moves the main models to `self.device` up
    front in that mode. Otherwise the model is moved to `self.device` on
    entry and moved back to the CPU on exit.
    """
    if self.device_only:
        # No transfer in either direction: nothing to undo on exit.
        yield model
    else:
        m = model.to(self.device)
        try:
            yield m
        finally:
            # Runs even if the body of the `with` block raises, so a failed
            # inference call does not leave the model occupying device memory.
            model.cpu()


def load_cvvp(self):
"""Load CVVP model."""
self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
Expand All @@ -295,22 +306,31 @@ def get_conditioning_latents(self, voice_samples, return_mels=False):
for vs in voice_samples:
auto_conds.append(format_conditioning(vs, device=self.device))
auto_conds = torch.stack(auto_conds, dim=1)
self.autoregressive = self.autoregressive.to(self.device)
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu()

if not self.device_only:
self.autoregressive = self.autoregressive.to(self.device)
auto_latent = self.autoregressive.get_conditioning(auto_conds)
self.autoregressive = self.autoregressive.cpu()
else:
auto_latent = self.autoregressive.get_conditioning(auto_conds)

diffusion_conds = []

for sample in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
cond_mel = wav_to_univnet_mel(sample.to(self.device), do_normalization=False, device=self.device)
diffusion_conds.append(cond_mel)

diffusion_conds = torch.stack(diffusion_conds, dim=1)

self.diffusion = self.diffusion.to(self.device)
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu()
if not self.device_only:
self.diffusion = self.diffusion.to(self.device)
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)
self.diffusion = self.diffusion.cpu()
else:
diffusion_latent = self.diffusion.get_conditioning(diffusion_conds)

if return_mels:
return auto_latent, diffusion_latent, auto_conds, diffusion_conds
Expand Down