diff --git a/README.md b/README.md
index 1e533a6..a7d725a 100644
--- a/README.md
+++ b/README.md
@@ -69,8 +69,8 @@ sampled = cfm_wrapper.sample(
 - [x] basic loss
 - [x] get neural ode working with torchdyn
 - [x] get basic mask generation logic with the p_drop of 0.2-0.3 for ICL
-- [x] just use torchdiffeq, nothing else is mature. torchode looks promising but cannot support ndim > 2
 - [x] take care of p_drop, different between voicebox and duration model
+- [x] support torchdiffeq and torchode
 - [ ] consider switching to adaptive rmsnorm for time conditioning
 - [ ] integrate with either hifi-gan and soundstream / encodec
 
diff --git a/setup.py b/setup.py
index 1dac4ca..b642810 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.15',
+  version = '0.0.16',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
@@ -20,7 +20,8 @@
     'beartype',
     'einops>=0.6.1',
     'torch>=2.0',
-    'torchdiffeq'
+    'torchdiffeq',
+    'torchode'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py
index 6288d97..66b56c6 100644
--- a/voicebox_pytorch/voicebox_pytorch.py
+++ b/voicebox_pytorch/voicebox_pytorch.py
@@ -6,6 +6,9 @@ from torch.nn import Module
 import torch.nn.functional as F
 
+import torchode as to
+from torchode.single_step_methods import SingleStepMethod
+
 from torchdiffeq import odeint
 
 from beartype import beartype
@@ -556,8 +559,10 @@ def __init__(
         sigma = 0.,
         ode_atol = 1e-5,
         ode_rtol = 1e-5,
-        ode_method = 'midpoint',
         ode_step_size = 0.0625,
+        use_torchode = False,
+        torchdiffeq_ode_method = 'midpoint',                 # use midpoint for torchdiffeq, as in paper
+        torchode_method_klass: SingleStepMethod = to.Tsit5,  # use tsit5 for torchode, as torchode does not have midpoint (recommended by Bryan @b-chiang)
         cond_drop_prob = 0.
     ):
         super().__init__()
@@ -567,10 +572,13 @@ def __init__(
 
         self.cond_drop_prob = cond_drop_prob
 
+        self.use_torchode = use_torchode
+        self.torchode_method_klass = torchode_method_klass
+
         self.odeint_kwargs = dict(
             atol = ode_atol,
             rtol = ode_rtol,
-            method = ode_method,
+            method = torchdiffeq_ode_method,
             options = dict(step_size = ode_step_size)
         )
 
@@ -585,13 +593,18 @@ def sample(
         phoneme_ids,
         cond,
         mask = None,
-        steps = 2,
+        steps = 3,
         cond_scale = 1.
     ):
+        shape = cond.shape
+        batch = shape[0]
+
         self.voicebox.eval()
 
         def fn(t, x):
-            return self.voicebox.forward_with_cond_scale(
+            x = x.reshape(*shape)
+
+            out = self.voicebox.forward_with_cond_scale(
                 x,
                 times = t,
                 phoneme_ids = phoneme_ids,
@@ -599,16 +612,41 @@ def fn(t, x):
                 cond_scale = cond_scale
             )
 
-        batch = cond.shape[0]
+            return rearrange(out, 'b ... -> b (...)')
 
         y0 = torch.randn_like(cond)
         t = torch.linspace(0, 1, steps, device = self.device)
 
-        print('sampling')
+        if not self.use_torchode:
+            print('sampling with torchdiffeq')
+
+            trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
+            sampled = trajectory[-1]
+        else:
+            print('sampling with torchode')
+
+            term = to.ODETerm(fn)
+            step_method = self.torchode_method_klass(term = term)
+
+            step_size_controller = to.IntegralController(
+                atol = self.odeint_kwargs['atol'],
+                rtol = self.odeint_kwargs['rtol'],
+                term = term
+            )
+
+            solver = to.AutoDiffAdjoint(step_method, step_size_controller)
+            jit_solver = torch.compile(solver)
+
+            t = repeat(t, 'n -> b n', b = batch)
+            y0 = rearrange(y0, 'b ... -> b (...)')
+
+            init_value = to.InitialValueProblem(y0 = y0, t_eval = t)
+
+            sol = jit_solver.solve(init_value)
 
-        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
+            sampled = sol.ys[:, -1]
+            sampled = sampled.reshape(*shape)
 
-        sampled = trajectory[-1] # last in trajectory
 
         return sampled
 
     def forward(
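For reference, a minimal usage sketch of the new `use_torchode` flag. The `sample` keyword arguments match the signature shown in the diff above; the `VoiceBox` constructor hyperparameters and tensor shapes are illustrative placeholders following the pattern of the README's existing `cfm_wrapper.sample(` example, not values taken from this diff:

```python
import torch
from voicebox_pytorch import VoiceBox, ConditionalFlowMatcherWrapper

# hypothetical model hyperparameters, for illustration only
model = VoiceBox(
    dim = 512,
    num_phoneme_tokens = 256,
    depth = 2
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model,
    use_torchode = True  # False (the default) samples with torchdiffeq's fixed-step midpoint solver instead
)

phoneme_ids = torch.randint(0, 256, (1, 1024))
cond = torch.randn(1, 1024, 512)  # (batch, seq, dim) conditioning signal

sampled = cfm_wrapper.sample(
    phoneme_ids = phoneme_ids,
    cond = cond,
    steps = 3
)
```

Note the flattening in the diff (`'b ... -> b (...)'`): torchode solves batched problems over a two-dimensional state of shape `(batch, features)`, so the sampler reshapes the noise to flat vectors before solving and restores the original `(batch, seq, dim)` shape afterwards. This is what lifts the "cannot support ndim > 2" limitation recorded in the old README checklist item.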