Skip to content

Add structured response for Azure AudioTranscription #1278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
import org.springframework.ai.audio.transcription.AudioTranscription;
import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt;
import org.springframework.ai.audio.transcription.AudioTranscriptionResponse;
import org.springframework.ai.audio.transcription.metadata.StructuredResponse;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.GranularityType;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse.Segment;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse.Word;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.TranscriptResponseFormat;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiAudioTranscriptionResponseMetadata;
import org.springframework.ai.model.Model;
Expand All @@ -38,6 +36,8 @@
import java.io.IOException;
import java.util.List;

import static org.springframework.ai.audio.transcription.metadata.StructuredResponse.*;

/**
* AzureOpenAI audio transcription client implementation backed by
* {@link OpenAIClient}. You provide as input the audio file you want to transcribe and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,99 +242,24 @@ public String getValue() {

}

/**
* Structured transcription payload. The {@code JSON} and {@code VERBOSE_JSON} entries of
* {@code TranscriptResponseFormat} declare this record as their response type.
*
* @param language The language of the transcribed text.
* @param duration The duration of the audio in seconds.
* @param text The transcribed text.
* @param words The extracted words and their timestamps.
* @param segments The segments of the transcribed text and their corresponding
* details.
*/
@JsonInclude(Include.NON_NULL)
public record StructuredResponse(
// @formatter:off
@JsonProperty("language") String language,
@JsonProperty("duration") Float duration,
@JsonProperty("text") String text,
@JsonProperty("words") List<Word> words,
@JsonProperty("segments") List<Segment> segments) {
// @formatter:on

/**
* Extracted word and its corresponding timestamps.
*
* @param word The text content of the word.
* @param start The start time of the word in seconds.
* @param end The end time of the word in seconds.
*/
@JsonInclude(Include.NON_NULL)
public record Word(
// @formatter:off
@JsonProperty("word") String word,
@JsonProperty("start") Float start,
@JsonProperty("end") Float end) {
// @formatter:on
}

/**
* Segment of the transcribed text and its corresponding details.
*
* @param id Unique identifier of the segment.
* @param seek Seek offset of the segment.
* @param start Start time of the segment in seconds.
* @param end End time of the segment in seconds.
* @param text The text content of the segment.
* @param tokens Array of token IDs for the text content.
* @param temperature Temperature parameter used for generating the segment.
* @param avgLogprob Average logprob of the segment. If the value is lower than
* -1, consider the logprobs failed.
* @param compressionRatio Compression ratio of the segment. If the value is
* greater than 2.4, consider the compression failed.
* @param noSpeechProb Probability of no speech in the segment. If the value is
* higher than 1.0 and the avg_logprob is below -1, consider this segment silent.
*/
@JsonInclude(Include.NON_NULL)
public record Segment(
// @formatter:off
@JsonProperty("id") Integer id,
@JsonProperty("seek") Integer seek,
@JsonProperty("start") Float start,
@JsonProperty("end") Float end,
@JsonProperty("text") String text,
@JsonProperty("tokens") List<Integer> tokens,
@JsonProperty("temperature") Float temperature,
@JsonProperty("avg_logprob") Float avgLogprob,
@JsonProperty("compression_ratio") Float compressionRatio,
@JsonProperty("no_speech_prob") Float noSpeechProb) {
// @formatter:on
}
}

public enum TranscriptResponseFormat {

// @formatter:off
@JsonProperty("json") JSON(AudioTranscriptionFormat.JSON, StructuredResponse.class),
@JsonProperty("text") TEXT(AudioTranscriptionFormat.TEXT, String.class),
@JsonProperty("srt") SRT(AudioTranscriptionFormat.SRT, String.class),
@JsonProperty("verbose_json") VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON, StructuredResponse.class),
@JsonProperty("vtt") VTT(AudioTranscriptionFormat.VTT, String.class);
@JsonProperty("json") JSON(AudioTranscriptionFormat.JSON),
@JsonProperty("text") TEXT(AudioTranscriptionFormat.TEXT),
@JsonProperty("srt") SRT(AudioTranscriptionFormat.SRT),
@JsonProperty("verbose_json") VERBOSE_JSON(AudioTranscriptionFormat.VERBOSE_JSON),
@JsonProperty("vtt") VTT(AudioTranscriptionFormat.VTT);

public final AudioTranscriptionFormat value;

public final Class<?> responseType;

TranscriptResponseFormat(AudioTranscriptionFormat value, Class<?> responseType) {
TranscriptResponseFormat(AudioTranscriptionFormat value) {
this.value = value;
this.responseType = responseType;
}

public AudioTranscriptionFormat getValue() {
return this.value;
}

public Class<?> getResponseType() {
return this.responseType;
}
}

public enum GranularityType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
package org.springframework.ai.azure.openai.metadata;

import org.springframework.ai.audio.transcription.AudioTranscriptionResponseMetadata;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions;
import org.springframework.ai.audio.transcription.metadata.StructuredResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -28,21 +29,29 @@ public class AzureOpenAiAudioTranscriptionResponseMetadata extends AudioTranscri

protected static final String AI_METADATA_STRING = "{ @type: %1$s }";

public static final AzureOpenAiAudioTranscriptionResponseMetadata NULL = new AzureOpenAiAudioTranscriptionResponseMetadata() {
};

public static AzureOpenAiAudioTranscriptionResponseMetadata from(
AzureOpenAiAudioTranscriptionOptions.StructuredResponse result) {
Assert.notNull(result, "AzureOpenAI Transcription must not be null");
return new AzureOpenAiAudioTranscriptionResponseMetadata();
/**
* Creates response metadata that carries the given structured transcription.
* @param structuredResponse the structured transcription to expose through
* {@link #getStructuredResponse()}; rejected by {@code Assert.notNull} when
* {@code null}
* @return a new metadata instance wrapping {@code structuredResponse}
*/
public static AzureOpenAiAudioTranscriptionResponseMetadata from(StructuredResponse structuredResponse) {
Assert.notNull(structuredResponse, "AzureOpenAI Transcription must not be null");
return new AzureOpenAiAudioTranscriptionResponseMetadata(structuredResponse);
}

/**
* Creates response metadata for a transcription that was returned as plain text.
* <p>
* The text is only checked for {@code null}; it is not stored. The returned instance
* is built with the no-argument constructor.
* @param result the transcription text; rejected by {@code Assert.notNull} when
* {@code null}
* @return a new metadata instance
*/
public static AzureOpenAiAudioTranscriptionResponseMetadata from(String result) {
Assert.notNull(result, "AzureOpenAI Transcription must not be null");
return new AzureOpenAiAudioTranscriptionResponseMetadata();
}

private final StructuredResponse structuredResponse;

/**
* Creates metadata without a structured response by delegating to the one-argument
* constructor with {@code null}.
*/
protected AzureOpenAiAudioTranscriptionResponseMetadata() {
this(null);
}

/**
* Creates metadata that holds the given structured response.
* @param structuredResponse the structured transcription to hold; stored as given
* without a null check, so it may be {@code null}
*/
public AzureOpenAiAudioTranscriptionResponseMetadata(StructuredResponse structuredResponse) {
this.structuredResponse = structuredResponse;
}

/**
 * Returns the structured transcription held by this metadata.
 * @return the structured response, or {@code null} when none was supplied at
 * construction time
 */
@Nullable
public StructuredResponse getStructuredResponse() {
	return this.structuredResponse;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.springframework.ai.audio.transcription.metadata;

import org.springframework.lang.Nullable;

import java.util.List;

/**
 * Structured result of an audio transcription: the transcribed text together with
 * optional word-level and segment-level detail.
 *
 * <p>
 * The {@code words} and {@code segments} lists are copied into unmodifiable lists on
 * construction, so an instance cannot be altered through the list that was passed in
 * or through the list returned by an accessor. A {@code null} list is kept as
 * {@code null}; a list containing a {@code null} element is rejected with a
 * {@link NullPointerException}.
 *
 * @param language The language of the transcribed text.
 * @param duration The duration of the audio in seconds.
 * @param text The transcribed text.
 * @param words The extracted words and their timestamps, or {@code null} if absent.
 * @param segments The segments of the transcribed text and their corresponding
 * details, or {@code null} if absent.
 * @author Piotr Olaszewski
 */
public record StructuredResponse(String language, Float duration, String text, @Nullable List<Word> words,
		@Nullable List<Segment> segments) {

	/**
	 * Takes unmodifiable snapshots of the list components so later changes to the
	 * caller's lists cannot leak into this record.
	 */
	public StructuredResponse {
		words = (words != null) ? List.copyOf(words) : null;
		segments = (segments != null) ? List.copyOf(segments) : null;
	}

	/**
	 * Extracted word and its corresponding timestamps.
	 *
	 * @param word The text content of the word.
	 * @param start The start time of the word in seconds.
	 * @param end The end time of the word in seconds.
	 */
	public record Word(String word, Float start, Float end) {
	}

	/**
	 * Segment of the transcribed text and its corresponding details.
	 *
	 * <p>
	 * The {@code tokens} list is copied into an unmodifiable list on construction; a
	 * {@code null} list is kept as {@code null}.
	 *
	 * @param id Unique identifier of the segment.
	 * @param seek Seek offset of the segment.
	 * @param start Start time of the segment in seconds.
	 * @param end End time of the segment in seconds.
	 * @param text The text content of the segment.
	 * @param tokens Array of token IDs for the text content.
	 * @param temperature Temperature parameter used for generating the segment.
	 * @param avgLogprob Average logprob of the segment. If the value is lower than -1,
	 * consider the logprobs failed.
	 * @param compressionRatio Compression ratio of the segment. If the value is greater
	 * than 2.4, consider the compression failed.
	 * @param noSpeechProb Probability of no speech in the segment. If the value is higher
	 * than 1.0 and the avg_logprob is below -1, consider this segment silent.
	 */
	public record Segment(Integer id, Integer seek, Float start, Float end, String text, List<Integer> tokens,
			Float temperature, Float avgLogprob, Float compressionRatio, Float noSpeechProb) {

		/**
		 * Takes an unmodifiable snapshot of {@code tokens} so the segment stays
		 * immutable.
		 */
		public Segment {
			tokens = (tokens != null) ? List.copyOf(tokens) : null;
		}
	}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.springframework.ai.model;

import io.micrometer.common.lang.NonNull;
import io.micrometer.common.lang.Nullable;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

import java.util.Collections;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.springframework.ai.model;

import io.micrometer.common.lang.NonNull;
import io.micrometer.common.lang.Nullable;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

import java.util.Collections;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/
package org.springframework.ai.model;

import io.micrometer.common.lang.NonNull;
import io.micrometer.common.lang.Nullable;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

import java.util.Map;
import java.util.Set;
Expand Down