Skip to content

Commit

Permalink
Room: AI-Talk allow disable ASR/TTS, enable text. v5.13.19
Browse files Browse the repository at this point in the history
  • Loading branch information
winlinvip committed Jan 29, 2024
1 parent 7bf7220 commit f3cef62
Show file tree
Hide file tree
Showing 8 changed files with 428 additions and 159 deletions.
1 change: 1 addition & 0 deletions DEVELOPER.md
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,7 @@ The following are the update records for the SRS Stack server.
* Room: AI-Talk support popout AI assistant. v5.13.17
* Room: AI-Talk support multiple assistant in a room. v5.13.18
* Room: AI-Talk support user different languages. v5.13.18
* Room: AI-Talk allow disable ASR/TTS, enable text. v5.13.19
* v5.12
* Refine local variable name conf to config. v5.12.1
* Add forced exit on timeout for program termination. v5.12.1
Expand Down
164 changes: 118 additions & 46 deletions platform/ai-talk.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,42 +555,42 @@ func (v *StageRequest) total() float64 {
}

func (v *StageRequest) upload() float64 {
if v.lastUploadAudio.After(v.lastSentence) {
if v.lastUploadAudio.After(v.lastSentence.Add(100 * time.Millisecond)) {
return float64(v.lastUploadAudio.Sub(v.lastSentence)) / float64(time.Second)
}
return 0
}

func (v *StageRequest) exta() float64 {
if v.lastExtractAudio.After(v.lastUploadAudio) {
if v.lastExtractAudio.After(v.lastUploadAudio.Add(100 * time.Millisecond)) {
return float64(v.lastExtractAudio.Sub(v.lastUploadAudio)) / float64(time.Second)
}
return 0
}

func (v *StageRequest) asr() float64 {
if v.lastRequestASR.After(v.lastExtractAudio) {
if v.lastRequestASR.After(v.lastExtractAudio.Add(100 * time.Millisecond)) {
return float64(v.lastRequestASR.Sub(v.lastExtractAudio)) / float64(time.Second)
}
return 0
}

func (v *StageRequest) chat() float64 {
if v.lastRequestChat.After(v.lastRequestASR) {
if v.lastRequestChat.After(v.lastRequestASR.Add(100 * time.Millisecond)) {
return float64(v.lastRequestChat.Sub(v.lastRequestASR)) / float64(time.Second)
}
return 0
}

func (v *StageRequest) tts() float64 {
if v.lastRequestTTS.After(v.lastRequestChat) {
if v.lastRequestTTS.After(v.lastRequestChat.Add(100 * time.Millisecond)) {
return float64(v.lastRequestTTS.Sub(v.lastRequestChat)) / float64(time.Second)
}
return 0
}

func (v *StageRequest) download() float64 {
if v.lastDownloadAudio.After(v.lastRequestTTS) {
if v.lastDownloadAudio.After(v.lastRequestTTS.Add(100 * time.Millisecond)) {
return float64(v.lastDownloadAudio.Sub(v.lastRequestTTS)) / float64(time.Second)
}
return 0
Expand All @@ -614,6 +614,8 @@ type StageMessage struct {
// For role audio.
// The audio segment uuid.
SegmentUUID string `json:"asid"`
// Whether has audio file.
HasAudioFile bool `json:"hasAudio"`
// The audio tts file for audio message.
audioFile string

Expand Down Expand Up @@ -694,18 +696,26 @@ func (v *StageSubscriber) addUserTextMessage(rid, name, msg string) {
func (v *StageSubscriber) createRobotEmptyMessage() *StageMessage {
message := &StageMessage{
finished: false, MessageUUID: uuid.NewString(), subscriber: v,
Role: "robot",
}
v.messages = append(v.messages, message)
return message
}

func (v *StageSubscriber) completeRobotAudioMessage(ctx context.Context, sreq *StageRequest, segment *AnswerSegment, message *StageMessage) {
// Build a new copy file of ttsFile.
ttsExt := path.Ext(segment.ttsFile)
copyFile := fmt.Sprintf("%v-copy-%v%v", segment.ttsFile[:len(segment.ttsFile)-len(ttsExt)], v.spid, ttsExt)
var copyFile string
if !segment.noTTS && segment.ttsFile != "" {
ttsExt := path.Ext(segment.ttsFile)
copyFile = fmt.Sprintf("%v-copy-%v%v", segment.ttsFile[:len(segment.ttsFile)-len(ttsExt)], v.spid, ttsExt)
}

// Copy the ttsFile to copyFile.
if err := func() error {
if copyFile == "" {
return nil
}

// If segment is error, ignore.
if segment.err != nil {
return nil
Expand Down Expand Up @@ -736,7 +746,11 @@ func (v *StageSubscriber) completeRobotAudioMessage(ctx context.Context, sreq *S

message.finished, message.segment = true, segment
message.RequestUUID, message.SegmentUUID = sreq.rid, segment.asid
message.Role, message.Message, message.audioFile = "robot", segment.text, copyFile
message.Message, message.audioFile = segment.text, copyFile
message.Username = v.stage.room.AIName

// User may disable TTS, we only ship the text message to user.
message.HasAudioFile = !segment.noTTS

// Always close message if timeout.
go func() {
Expand Down Expand Up @@ -866,6 +880,11 @@ type Stage struct {
// AI Chat message window.
chatWindow int

// Whether enabled AI services.
aiASREnabled bool
aiChatEnabled bool
aiTtsEnabled bool

// The AI configuration.
aiConfig openai.ClientConfig
// The room it belongs to. Note that it's a caching object, update when updating the room. The room object
Expand Down Expand Up @@ -924,6 +943,11 @@ func helloVoiceFromLanguage(language string) string {
}

func (v *Stage) UpdateFromRoom(room *SrsLiveRoom) {
// Whether enabled.
v.aiASREnabled = room.AIASREnabled
v.aiChatEnabled = room.AIChatEnabled
v.aiTtsEnabled = room.AITTSEnabled

// Create robot for the stage, which attach to a special room.
v.voice = helloVoiceFromLanguage(room.AIASRLanguage)
v.prompt = room.AIChatPrompt
Expand Down Expand Up @@ -1028,6 +1052,8 @@ type AnswerSegment struct {
text string
// The TTS file path.
ttsFile string
// Whether no tts file, as user disabled TTS for example.
noTTS bool
// Whether TTS is done, ready to play.
ready bool
// Whether TTS is error, failed.
Expand Down Expand Up @@ -1181,18 +1207,25 @@ func (v *TTSWorker) SubmitSegment(ctx context.Context, stage *Stage, sreq *Stage
go func() {
defer v.wg.Done()

ttsService := NewOpenAITTSService(stage.aiConfig)
if err := ttsService.RequestTTS(ctx, func(ext string) string {
segment.ttsFile = path.Join(aiTalkWorkDir,
fmt.Sprintf("assistant-%v-sentence-%v-tts.%v", sreq.rid, segment.asid, ext),
)
return segment.ttsFile
}, segment.text); err != nil {
segment.err = err
if stage.aiTtsEnabled {
ttsService := NewOpenAITTSService(stage.aiConfig)
if err := ttsService.RequestTTS(ctx, func(ext string) string {
segment.ttsFile = path.Join(aiTalkWorkDir,
fmt.Sprintf("assistant-%v-sentence-%v-tts.%v", sreq.rid, segment.asid, ext),
)
return segment.ttsFile
}, segment.text); err != nil {
segment.err = err
} else {
segment.ready, segment.noTTS = true, false
sreq.onSegmentReady(segment)
logger.Tf(ctx, "TTS: Complete rid=%v, asid=%v, file saved to %v, %v",
sreq.rid, segment.asid, segment.ttsFile, segment.text)
}
} else {
segment.ready = true
segment.ready, segment.noTTS = true, true
sreq.onSegmentReady(segment)
logger.Tf(ctx, "File saved to %v, %v", segment.ttsFile, segment.text)
logger.Tf(ctx, "TTS: Skip rid=%v, asid=%v, %v", sreq.rid, segment.asid, segment.text)
}

// Update all messages.
Expand Down Expand Up @@ -1358,11 +1391,15 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
StageID string `json:"sid"`
RoomToken string `json:"roomToken"`
UserID string `json:"userId"`
// AI Configurations.
AIASREnabled bool `json:"aiAsrEnabled"`
}
r0 := &StageResult{
StageID: stage.sid,
RoomToken: stage.room.RoomToken,
UserID: user.UserID,
// AI Configurations.
AIASREnabled: room.AIASREnabled,
}

ohttp.WriteData(ctx, w, r, r0)
Expand Down Expand Up @@ -1447,21 +1484,23 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
handler.HandleFunc(ep, func(w http.ResponseWriter, r *http.Request) {
if err := func() error {
var token string
var sid, rid, userID, audioBase64Data string
var sid, rid, userID string
var roomUUID, roomToken string
var userMayInput float64
var audioBase64Data, textMessage string
if err := ParseBody(ctx, r.Body, &struct {
Token *string `json:"token"`
StageUUID *string `json:"sid"`
UserID *string `json:"userId"`
RequestUUID *string `json:"rid"`
UserMayInput *float64 `json:"umi"`
AudioData *string `json:"audio"`
TextMessage *string `json:"text"`
RoomUUID *string `json:"room"`
RoomToken *string `json:"roomToken"`
}{
Token: &token, StageUUID: &sid, UserID: &userID, RequestUUID: &rid,
UserMayInput: &userMayInput, AudioData: &audioBase64Data,
UserMayInput: &userMayInput, TextMessage: &textMessage, AudioData: &audioBase64Data,
RoomUUID: &roomUUID, RoomToken: &roomToken,
}); err != nil {
return errors.Wrapf(err, "parse body")
Expand All @@ -1484,6 +1523,9 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
if userID == "" {
return errors.Errorf("empty userId")
}
if audioBase64Data == "" && textMessage == "" {
return errors.Errorf("empty audio and text")
}

stage := talkServer.QueryStage(sid)
if stage == nil {
Expand Down Expand Up @@ -1522,18 +1564,33 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
logger.Tf(ctx, "Stage: Got question sid=%v, rid=%v, user=%v, umi=%v, input=%v",
sid, sreq.rid, userID, userMayInput, sreq.inputFile)

// Save audio input to file.
if err := sreq.receiveInputFile(ctx, audioBase64Data); err != nil {
return errors.Wrapf(err, "save %vB audio to file %v", len(audioBase64Data), sreq.inputFile)
// Whether user input audio.
if audioBase64Data != "" {
// Save audio input to file.
if err := sreq.receiveInputFile(ctx, audioBase64Data); err != nil {
return errors.Wrapf(err, "save %vB audio to file %v", len(audioBase64Data), sreq.inputFile)
}

// Do ASR, convert to text.
asrLanguage := ChooseNotEmpty(user.Language, stage.asrLanguage)
if err := sreq.asrAudioToText(ctx, stage.aiConfig, asrLanguage, user.previousAsrText); err != nil {
return errors.Wrapf(err, "asr lang=%v, previous=%v", asrLanguage, user.previousAsrText)
}
logger.Tf(ctx, "ASR ok, sid=%v, rid=%v, user=%v, lang=%v, prompt=<%v>, resp is <%v>",
sid, sreq.rid, userID, asrLanguage, user.previousAsrText, sreq.asrText)
} else {
// Directly update the time for stat.
sreq.lastUploadAudio = time.Now()
sreq.lastExtractAudio = time.Now()
sreq.lastRequestASR = time.Now()
}

// Do ASR, convert to text.
asrLanguage := ChooseNotEmpty(user.Language, stage.asrLanguage)
if err := sreq.asrAudioToText(ctx, stage.aiConfig, asrLanguage, user.previousAsrText); err != nil {
return errors.Wrapf(err, "asr lang=%v, previous=%v", asrLanguage, user.previousAsrText)
// Handle user input text.
if textMessage != "" {
sreq.asrText = textMessage
logger.Tf(ctx, "Text ok, sid=%v, rid=%v, user=%v, text=%v",
sid, sreq.rid, userID, sreq.asrText)
}
logger.Tf(ctx, "ASR ok, sid=%v, rid=%v, user=%v, lang=%v, prompt=<%v>, resp is <%v>",
sid, sreq.rid, userID, asrLanguage, user.previousAsrText, sreq.asrText)

// Important trace log.
user.previousAsrText = sreq.asrText
Expand Down Expand Up @@ -1638,7 +1695,7 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
ohttp.WriteData(ctx, w, r, struct {
Finished bool `json:"finished"`
}{
Finished: !sreq.finished,
Finished: sreq.finished,
})

return nil
Expand Down Expand Up @@ -1892,6 +1949,30 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
}
})

finishAudioSegment := func(segment *AnswerSegment) {
if segment == nil || segment.logged {
return
}

// Only log the first segment.
segment.logged = true
if !segment.first {
return
}

// Time cost logging.
sreq := segment.request
sreq.lastDownloadAudio = time.Now()
speech := float64(sreq.lastAsrDuration) / float64(time.Second)
logger.Tf(ctx, "Elapsed cost total=%.1fs, steps=[upload=%.1fs,exta=%.1fs,asr=%.1fs,chat=%.1fs,tts=%.1fs,download=%.1fs], ask=%v, speech=%.1fs, answer=%v",
sreq.total(), sreq.upload(), sreq.exta(), sreq.asr(), sreq.chat(), sreq.tts(), sreq.download(),
sreq.lastRequestAsrText, speech, sreq.lastRobotFirstText)

// Important trace log. Note that browser may request multiple times, so we only log for the first
// request to reduce logs.
logger.Tf(ctx, "Bot: %v", segment.text)
}

ep = "/terraform/v1/ai-talk/subscribe/tts"
logger.Tf(ctx, "Handle %v", ep)
handler.HandleFunc(ep, func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -1954,21 +2035,7 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
logger.Tf(ctx, "Stage: Download sid=%v, spid=%v, asid=%v", sid, spid, asid)

// When the first subscriber got the segment, we log the elapsed time.
if segment := answer.segment; !segment.logged {
sreq := segment.request
if segment.first {
sreq.lastDownloadAudio = time.Now()
speech := float64(sreq.lastAsrDuration) / float64(time.Second)
logger.Tf(ctx, "Elapsed cost total=%.1fs, steps=[upload=%.1fs,exta=%.1fs,asr=%.1fs,chat=%.1fs,tts=%.1fs,download=%.1fs], ask=%v, speech=%.1fs, answer=%v",
sreq.total(), sreq.upload(), sreq.exta(), sreq.asr(), sreq.chat(), sreq.tts(), sreq.download(),
sreq.lastRequestAsrText, speech, sreq.lastRobotFirstText)
}

// Important trace log. Note that browser may request multiple times, so we only log for the first
// request to reduce logs.
segment.logged = true
logger.Tf(ctx, "Bot: %v", segment.text)
}
finishAudioSegment(answer.segment)

// Read the ttsFile and response it as opus audio.
if strings.HasSuffix(answer.audioFile, ".wav") {
Expand Down Expand Up @@ -2041,6 +2108,11 @@ func handleAITalkService(ctx context.Context, handler *http.ServeMux) error {
return errors.Errorf("invalid spid %v of sid %v", spid, sid)
}

// If no audio file, we stat the time cost when remove the segment.
if answer := subscriber.queryAudioFile(asid); answer != nil {
finishAudioSegment(answer.segment)
}

// Keep alive the stage.
stage.KeepAlive()
subscriber.KeepAlive()
Expand Down
Loading

0 comments on commit f3cef62

Please sign in to comment.