Skip to content

Commit

Permalink
Fix ServeReferenceAudio to allow base64 reference data in json (#777)
Browse files Browse the repository at this point in the history
* Update schema.py to fix ServeStreamResponse

* Update schema.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
MithrilMan and pre-commit-ci[bot] authored Dec 21, 2024
1 parent 40665e1 commit b8bdcd4
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tools/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import base64
import os
import queue
from dataclasses import dataclass
from typing import Literal

import torch
from pydantic import BaseModel, Field, conint, conlist
from pydantic import BaseModel, Field, conint, conlist, model_validator
from pydantic.functional_validators import SkipValidation
from typing_extensions import Annotated

Expand Down Expand Up @@ -140,6 +141,19 @@ class ServeReferenceAudio(BaseModel):
audio: bytes
text: str

@model_validator(mode="before")
def decode_audio(cls, values):
audio = values.get("audio")
if (
isinstance(audio, str) and len(audio) > 255
): # Check if audio is a string (Base64)
try:
values["audio"] = base64.b64decode(audio)
except Exception as e:
# If the audio is not a valid base64 string, we will just ignore it and let the server handle it
pass
return values

def __repr__(self) -> str:
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"

Expand Down

0 comments on commit b8bdcd4

Please sign in to comment.