Skip to content

Commit

Permalink
feat: add common version manage
Browse files Browse the repository at this point in the history
  • Loading branch information
cubxxw committed Dec 25, 2024
1 parent c540c5a commit e026d09
Showing 1 changed file with 42 additions and 53 deletions.
95 changes: 42 additions & 53 deletions internal/stt/assemblyai/assemblyai.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,46 @@ func (s *STT) Recognize(audioData []byte, audioURL string) (string, error) {
func (s *STT) transcribeFromURL(audioURL string) (string, error) {
ctx := context.Background()

// 第一次尝试:使用语言检测
params := s.buildParams()
// 第一次尝试:启用语言检测
params := &aai.TranscriptOptionalParams{
LanguageDetection: aai.Bool(true),
LanguageConfidenceThreshold: aai.Float64(0.1), // 设置较低的初始阈值
Punctuate: aai.Bool(s.cfg.AssemblyAI.Punctuate),
FormatText: aai.Bool(s.cfg.AssemblyAI.FormatText),
SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold),
Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel),
}

transcript, err := s.client.Transcripts.TranscribeFromURL(ctx, audioURL, params)
if err != nil {
// 检查是否是语言置信度错误
if s.isLanguageConfidenceError(err) && s.cfg.AssemblyAI.DefaultLanguageCode != "" {
// 使用默认语言重试
logger.Infof("语言置信度低于阈值 %.2f,使用默认语言 %s 重试",
s.cfg.AssemblyAI.LanguageConfidenceThreshold,
logger.Infof("第一次尝试失败(语言置信度低), 使用默认语言 %s 重试",
s.cfg.AssemblyAI.DefaultLanguageCode)

// 构建新的参数,使用默认语言
params = s.buildParamsWithDefaultLanguage()
transcript, err = s.client.Transcripts.TranscribeFromURL(ctx, audioURL, params)
// 第二次尝试:禁用语言检测,使用固定语言
retryParams := &aai.TranscriptOptionalParams{
LanguageDetection: aai.Bool(false), // 明确禁用语言检测
LanguageCode: aai.TranscriptLanguageCode(s.cfg.AssemblyAI.DefaultLanguageCode),
// 基础参数
Punctuate: aai.Bool(s.cfg.AssemblyAI.Punctuate),
FormatText: aai.Bool(s.cfg.AssemblyAI.FormatText),
SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold),
Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel),
// 不再设置 LanguageConfidenceThreshold
}

// 记录重试请求参数
logger.Debugf("重试请求参数: %+v", retryParams)

transcript, err = s.client.Transcripts.TranscribeFromURL(ctx, audioURL, retryParams)
if err != nil {
return "", fmt.Errorf("语言置信度低于阈值 %.2f,使用默认语言 %s 重试失败: %v",
s.cfg.AssemblyAI.LanguageConfidenceThreshold,
s.cfg.AssemblyAI.DefaultLanguageCode,
err)
return "", fmt.Errorf("使用默认语言 %s 重试失败: %v",
s.cfg.AssemblyAI.DefaultLanguageCode, err)
}
} else {
return "", fmt.Errorf("转录请求失败: %v", err)
}
return "", fmt.Errorf("转录请求失败: %v", err)
}

// 使用指数退避策略,轮询转录状态
Expand Down Expand Up @@ -132,34 +150,20 @@ func (s *STT) buildParams() *aai.TranscriptOptionalParams {
aaiCfg := s.cfg.AssemblyAI

params := &aai.TranscriptOptionalParams{
// 将 string 转换为 SpeechModel 类型
SpeechModel: aai.SpeechModel(aaiCfg.Model),
LanguageDetection: aai.Bool(aaiCfg.LanguageDetection),
LanguageConfidenceThreshold: aai.Float64(aaiCfg.LanguageConfidenceThreshold),
Punctuate: aai.Bool(aaiCfg.Punctuate),
FormatText: aai.Bool(aaiCfg.FormatText),
Disfluencies: aai.Bool(aaiCfg.Disfluencies),
FilterProfanity: aai.Bool(aaiCfg.FilterProfanity),
AudioStartFrom: aai.Int64(aaiCfg.AudioStartFrom),
AudioEndAt: aai.Int64(aaiCfg.AudioEndAt),
SpeechThreshold: aai.Float64(aaiCfg.SpeechThreshold),
Multichannel: aai.Bool(aaiCfg.Multichannel),
SpeechModel: aai.SpeechModel(aaiCfg.Model),
Punctuate: aai.Bool(aaiCfg.Punctuate),
FormatText: aai.Bool(aaiCfg.FormatText),
SpeechThreshold: aai.Float64(aaiCfg.SpeechThreshold),
Multichannel: aai.Bool(aaiCfg.Multichannel),
}

// 如果设置了固定的 language_code,则禁用语言检测并指定语言
if aaiCfg.LanguageCode != "" {
params.LanguageDetection = aai.Bool(false)
params.LanguageCode = aai.TranscriptLanguageCode(aaiCfg.LanguageCode)
}

// 如果配置了词汇增强
// 词汇增强设置
if len(aaiCfg.WordBoost) > 0 {
params.WordBoost = aaiCfg.WordBoost
// 将 string 转换为 TranscriptBoostParam 类型
params.BoostParam = aai.TranscriptBoostParam(aaiCfg.BoostParam)
}

// 如果配置了自定义拼写
// 自定义拼写设置
if len(aaiCfg.CustomSpelling) > 0 {
var customSpellings []aai.TranscriptCustomSpelling
for _, cs := range aaiCfg.CustomSpelling {
Expand All @@ -174,24 +178,9 @@ func (s *STT) buildParams() *aai.TranscriptOptionalParams {
return params
}

// 新增:检查是否是语言置信度错误
// isLanguageConfidenceError 优化错误检测逻辑
func (s *STT) isLanguageConfidenceError(err error) bool {
return strings.Contains(err.Error(), "below the requested confidence threshold value")
}

// buildParamsWithDefaultLanguage builds the transcript parameters for a
// request that uses the configured default language.
//
// Language detection is switched off and LanguageCode is taken from
// s.cfg.AssemblyAI.DefaultLanguageCode. LanguageConfidenceThreshold is not
// set on the returned value.
func (s *STT) buildParamsWithDefaultLanguage() *aai.TranscriptOptionalParams {
// Deliberately does not call s.buildParams(), so that no confidence
// threshold is carried over; the fields wanted for the second request
// are listed by hand instead.
return &aai.TranscriptOptionalParams{
LanguageDetection: aai.Bool(false),
LanguageCode: aai.TranscriptLanguageCode(s.cfg.AssemblyAI.DefaultLanguageCode),
// NOTE(review): Punctuate and FormatText are hard-coded to true here
// rather than read from s.cfg.AssemblyAI — confirm this is intended.
Punctuate: aai.Bool(true),
FormatText: aai.Bool(true),
SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold),
Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel),
AudioStartFrom: aai.Int64(s.cfg.AssemblyAI.AudioStartFrom),
AudioEndAt: aai.Int64(s.cfg.AssemblyAI.AudioEndAt),
BoostParam: aai.TranscriptBoostParam(s.cfg.AssemblyAI.BoostParam),
}
}
errMsg := err.Error()
return strings.Contains(errMsg, "below the requested confidence threshold") ||
strings.Contains(errMsg, "confidence threshold value")
}

0 comments on commit e026d09

Please sign in to comment.