Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add max_gen_duration_s to waveform generation #447

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bark/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def semantic_to_waveform(
temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
max_gen_duration_s=None,
):
"""Generate audio array from semantic input.

Expand All @@ -47,6 +48,7 @@ def semantic_to_waveform(
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
silent: disable progress bar
output_full: return full generation to be used as a history prompt
max_gen_duration_s: maximum duration of generated audio in seconds

Returns:
numpy audio array at sample frequency 24khz
Expand All @@ -56,7 +58,8 @@ def semantic_to_waveform(
history_prompt=history_prompt,
temp=temp,
silent=silent,
use_kv_caching=True
use_kv_caching=True,
max_gen_duration_s=max_gen_duration_s,
)
fine_tokens = generate_fine(
coarse_tokens,
Expand Down
12 changes: 12 additions & 0 deletions bark/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ def generate_coarse(
max_coarse_history=630, # min 60 (faster), max 630 (more context)
sliding_window_len=60,
use_kv_caching=False,
max_gen_duration_s=None,
):
"""Generate coarse audio codes from semantic tokens."""
assert (
Expand Down Expand Up @@ -605,6 +606,17 @@ def generate_coarse(
* N_COARSE_CODEBOOKS
)
)

if max_gen_duration_s is not None:
n_steps = min(
n_steps,
int(
np.floor(
round(max_gen_duration_s * COARSE_RATE_HZ)
) * N_COARSE_CODEBOOKS
)
)

assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
x_coarse = x_coarse_history.astype(np.int32)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "suno-bark"
version = "0.0.1a"
version = "0.1.0"
description = "Bark text to audio model"
readme = "README.md"
requires-python = ">=3.8"
Expand Down