diff --git a/.env.example b/.env.example index 4ee7670..a6cce84 100644 --- a/.env.example +++ b/.env.example @@ -36,3 +36,7 @@ VOICEFLOW_VOLCENGINE_TTS_TOKEN='' # 语音服务端口配置 VOICEFLOW_SERVER_PORT=18080 # 语音服务端口 + +# Whisper 配置 +VOICEFLOW_WHISPER_API_KEY='' +VOICEFLOW_WHISPER_ENDPOINT="https://audio-turbo.us-virginia-1.direct.fireworks.ai/v1/audio/transcriptions" \ No newline at end of file diff --git a/configs/config.yaml b/configs/config.yaml index 1cc89ca..2a42f14 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -12,7 +12,7 @@ minio: storage_path: "voiceflow/audio/" stt: - # 可选值:azure、 google、 local、 assemblyai-ws、 volcengine、 aws、 assemblyai + # 可选值:azure、 google、 local、 assemblyai-ws、 volcengine、 aws、 assemblyai、 whisper-v3 provider: assemblyai tts: @@ -121,4 +121,17 @@ logging: max_backups: 3 max_age: 28 compress: true - report_caller: true \ No newline at end of file + report_caller: true + +# 添加 Whisper 配置段 +whisper: + api_key: "your-api-key" + endpoint: "https://audio-turbo.us-virginia-1.direct.fireworks.ai/v1/audio/transcriptions" + model: "accounts/fireworks/models/whisper-v3-turbo" + temperature: 0 + vad_model: "silero" + max_retries: 3 + timeout: 30 + language: "auto" + task: "transcribe" + batch_size: 30 \ No newline at end of file diff --git a/internal/server/handlers.go b/internal/server/handlers.go index 4912e59..76d3a2b 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -30,7 +30,10 @@ func InitServices() { if err != nil { logger.Fatalf("配置初始化失败: %v", err) } - sttService = stt.NewService(cfg.STT.Provider) + sttService, err = stt.NewService(cfg.STT.Provider) + if err != nil { + logger.Fatalf("STT 服务初始化失败: %v", err) + } ttsService = tts.NewService(cfg.TTS.Provider) // llmService = llm.NewService(cfg.LLM.Provider) storageService = storage.NewService() diff --git a/internal/stt/stt.go b/internal/stt/stt.go index 336b444..8aeaca8 100644 --- a/internal/stt/stt.go +++ b/internal/stt/stt.go @@ -2,12 +2,15 @@ package stt import ( + "fmt" + assemblyai "github.com/telepace/voiceflow/internal/stt/assemblyai" aaiws "github.com/telepace/voiceflow/internal/stt/assemblyai-ws" "github.com/telepace/voiceflow/internal/stt/azure" "github.com/telepace/voiceflow/internal/stt/google" "github.com/telepace/voiceflow/internal/stt/local" "github.com/telepace/voiceflow/internal/stt/volcengine" + "github.com/telepace/voiceflow/internal/stt/whisper" "github.com/telepace/voiceflow/pkg/logger" ) @@ -17,22 +20,24 @@ type Service interface { } // NewService 根据配置返回相应的 STT 服务实现 -func NewService(provider string) Service { +func NewService(provider string) (Service, error) { logger.Debugf("Using STT provider: %s", provider) switch provider { case "azure": - return azure.NewAzureSTT() + return azure.NewAzureSTT(), nil case "google": - return google.NewGoogleSTT() + return google.NewGoogleSTT(), nil case "assemblyai-ws": - return aaiws.NewAssemblyAI() + return aaiws.NewAssemblyAI(), nil case "volcengine": - return volcengine.NewVolcengineSTT() + return volcengine.NewVolcengineSTT(), nil case "local": - return local.NewLocalSTT() + return local.NewLocalSTT(), nil case "assemblyai": - return assemblyai.NewAssemblyAI() + return assemblyai.NewAssemblyAI(), nil + case "whisper-v3": + return whisper.NewWhisperSTT(), nil default: - return local.NewLocalSTT() + return nil, fmt.Errorf("未知的 STT 提供商: %s", provider) } } diff --git a/internal/stt/whisper/whisper.go b/internal/stt/whisper/whisper.go new file mode 100644 index 0000000..112b176 --- /dev/null +++ b/internal/stt/whisper/whisper.go @@ -0,0 +1,110 @@ +package whisper + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + + "github.com/telepace/voiceflow/pkg/config" + "github.com/telepace/voiceflow/pkg/logger" +) + +type WhisperSTT struct { + apiKey string + endpoint string + model string + temperature float64 + vadModel string +} + +type WhisperResponse struct { + Text string `json:"text"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` +} + +func NewWhisperSTT() *WhisperSTT { + cfg, err := config.GetConfig() + if err != nil { + logger.Fatalf("配置初始化失败: %v", err) + } + + return &WhisperSTT{ + apiKey: cfg.Whisper.APIKey, + endpoint: cfg.Whisper.Endpoint, + model: cfg.Whisper.Model, + temperature: cfg.Whisper.Temperature, + vadModel: cfg.Whisper.VADModel, + } +} + +func (w *WhisperSTT) Recognize(audioData []byte, audioURL string) (string, error) { + // 创建一个 buffer 来写入 multipart 数据 + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // 写入音频文件 + part, err := writer.CreateFormFile("file", "audio.mp3") + if err != nil { + return "", fmt.Errorf("创建表单文件失败: %v", err) + } + if _, err := io.Copy(part, bytes.NewReader(audioData)); err != nil { + return "", fmt.Errorf("写入音频数据失败: %v", err) + } + + // 添加其他参数 + writer.WriteField("model", w.model) + writer.WriteField("temperature", fmt.Sprintf("%f", w.temperature)) + writer.WriteField("vad_model", w.vadModel) + + // 关闭 multipart writer + if err := writer.Close(); err != nil { + return "", fmt.Errorf("关闭 writer 失败: %v", err) + } + + // 创建请求 + req, err := http.NewRequest("POST", w.endpoint, body) + if err != nil { + return "", fmt.Errorf("创建请求失败: %v", err) + } + + // 设置请求头 + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", w.apiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + // 发送请求 + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("发送请求失败: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API 请求失败,状态码: %d,响应: %s", + resp.StatusCode, string(bodyBytes)) + } + + // 解析响应 + var result WhisperResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("解析响应失败: %v", err) + } + + logger.Infof("语音识别完成,语言: %s, 时长: %.2f秒", + result.Language, result.Duration) + + return result.Text, nil +} + +func (w *WhisperSTT) StreamRecognize(ctx context.Context, audioDataChan <-chan []byte, + transcriptChan chan<- string) error { + // Whisper V3 Turbo 目前不支持流式处理 + return fmt.Errorf("Whisper V3 Turbo 不支持流式处理") +} diff --git a/pkg/config/config.go b/pkg/config/config.go index e5a81c4..8bcb0ee 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -51,6 +51,19 @@ type AWSConfig struct { Region string `yaml:"region"` } +type WhisperConfig struct { + APIKey string `mapstructure:"api_key"` + Endpoint string `mapstructure:"endpoint"` + Model string `mapstructure:"model"` + Temperature float64 `mapstructure:"temperature"` + VADModel string `mapstructure:"vad_model"` + MaxRetries int `mapstructure:"max_retries"` + Timeout int `mapstructure:"timeout"` + Language string `mapstructure:"language"` // 可选的指定语言 + Task string `mapstructure:"task"` // transcribe 或 translate + BatchSize int `mapstructure:"batch_size"` // 音频分段大小(秒) +} + type Config struct { Server struct { Port int @@ -129,6 +142,7 @@ type Config struct { Compress bool `mapstructure:"compress"` ReportCaller bool `mapstructure:"report_caller"` } + Whisper WhisperConfig `mapstructure:"whisper"` } var (