Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add New Structured Response Feature On Assistant Creation #391

Merged
merged 8 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,19 @@ import com.aallam.openai.api.assistant.AssistantResponseFormat
import com.aallam.openai.api.assistant.AssistantTool
import com.aallam.openai.api.assistant.assistantRequest
import com.aallam.openai.api.chat.ToolCall
import com.aallam.openai.api.core.RequestOptions
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.api.run.RequiredAction
import com.aallam.openai.api.run.Run
import com.aallam.openai.client.internal.JsonLenient
import kotlin.test.*
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertIs
import kotlin.test.assertNull
import kotlin.test.assertTrue

class TestAssistants : TestOpenAI() {

Expand Down Expand Up @@ -144,4 +151,64 @@ class TestAssistants : TestOpenAI() {
val action = decoded.requiredAction as RequiredAction.SubmitToolOutputs
assertIs<ToolCall.Function>(action.toolOutputs.toolCalls.first())
}

@Test
fun jsonSchemaAssistant() = test {
val jsonSchema = AssistantResponseFormat.JSON_SCHEMA(
name = "TestSchema",
description = "A test schema",
schema = buildJsonObject {
put("type", "object")
put("properties", buildJsonObject {
put("name", buildJsonObject {
put("type", "string")
})
})
put("required", JsonArray(listOf(JsonPrimitive("name"))))
put("additionalProperties", false)
},
strict = true
)

val request = assistantRequest {
name = "Schema Assistant"
model = ModelId("gpt-4o-mini")
responseFormat = jsonSchema
}

val assistant = openAI.assistant(
request = request,
)
assertEquals(request.name, assistant.name)
assertEquals(request.model, assistant.model)
assertEquals(request.responseFormat, assistant.responseFormat)

val getAssistant = openAI.assistant(
assistant.id,
)
assertEquals(getAssistant, assistant)

val assistants = openAI.assistants()
assertTrue { assistants.isNotEmpty() }

val updated = assistantRequest {
name = "Updated Schema Assistant"
responseFormat = AssistantResponseFormat.AUTO
}
val updatedAssistant = openAI.assistant(
assistant.id,
updated,
)
assertEquals(updated.name, updatedAssistant.name)
assertEquals(updated.responseFormat, updatedAssistant.responseFormat)

openAI.delete(
updatedAssistant.id,
)

val fileGetAfterDelete = openAI.assistant(
updatedAssistant.id,
)
assertNull(fileGetAfterDelete)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,23 @@ public data class AssistantRequest(
* Specifies the format that the model must output. Compatible with GPT-4o, GPT-4 Turbo, and all GPT-3.5 Turbo
* models since gpt-3.5-turbo-1106.
*
* Setting to [AssistantResponseFormat.JsonObject] enables JSON mode, which guarantees the message the model
* Setting to [AssistantResponseFormat.JSON_SCHEMA] enables Structured Outputs which ensures the model will match your supplied JSON schema.
*
* Structured Outputs ([AssistantResponseFormat.JSON_SCHEMA]) are available in our latest large language models, starting with GPT-4o:
* 1. gpt-4o-mini-2024-07-18 and later
* 2. gpt-4o-2024-08-06 and later
*
* Older models like gpt-4-turbo and earlier may use JSON mode ([AssistantResponseFormat.JSON_OBJECT]) instead.
*
* Setting to [AssistantResponseFormat.JSON_OBJECT] enables JSON mode, which guarantees the message the model
* generates is valid JSON.
*
* important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user
* message. Without this, the model may generate an unending stream of whitespace until the generation reaches the
* token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be
* partially cut off if finish_reason="length", which indicates the generation exceeded max_tokens or
* the conversation exceeded the max context length.
*
*/
@SerialName("response_format") val responseFormat: AssistantResponseFormat? = null,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,51 +9,73 @@ import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.descriptors.element
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonObjectBuilder
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.booleanOrNull
import kotlinx.serialization.json.contentOrNull
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive

/**
* string: auto is the default value
* Represents the format of the response from the assistant.
*
* object: An object describing the expected output of the model. If json_object only function type tools are allowed to be passed to the Run.
* If text, the model can return text or any value needed.
* type: string Must be one of text or json_object.
* @property type The type of the response format.
* @property jsonSchema The JSON schema associated with the response format, if type is "json_schema" otherwise null.
*/
@BetaOpenAI
@Serializable(with = AssistantResponseFormat.ResponseFormatSerializer::class)
public data class AssistantResponseFormat(
val format: String? = null,
val objectType: AssistantResponseType? = null,
val type: String,
val jsonSchema: JsonSchema? = null
) {

/**
* Represents a JSON schema.
*
* @property name The name of the schema.
* @property description The description of the schema.
* @property schema The actual JSON schema.
* @property strict Indicates if the schema is strict.
*/
@Serializable
public data class AssistantResponseType(
val type: String
public data class JsonSchema(
val name: String,
val description: String? = null,
val schema: JsonObject,
val strict: Boolean? = null
)

public companion object {
public val AUTO: AssistantResponseFormat = AssistantResponseFormat(format = "auto")
public val TEXT: AssistantResponseFormat = AssistantResponseFormat(objectType = AssistantResponseType(type = "text"))
public val JSON_OBJECT: AssistantResponseFormat = AssistantResponseFormat(objectType = AssistantResponseType(type = "json_object"))
public val AUTO: AssistantResponseFormat = AssistantResponseFormat("auto")
public val TEXT: AssistantResponseFormat = AssistantResponseFormat("text")
public val JSON_OBJECT: AssistantResponseFormat = AssistantResponseFormat("json_object")

/**
* Creates an instance of `AssistantResponseFormat` with type `json_schema`.
*
* @param name The name of the schema.
* @param description The description of the schema.
* @param schema The actual JSON schema.
* @param strict Indicates if the schema is strict.
* @return An instance of `AssistantResponseFormat` with the specified JSON schema.
*/
public fun JSON_SCHEMA(
name: String,
description: String? = null,
schema: JsonObject,
strict: Boolean? = null
): AssistantResponseFormat = AssistantResponseFormat(
"json_schema",
JsonSchema(name, description, schema, strict)
)
}


public object ResponseFormatSerializer : KSerializer<AssistantResponseFormat> {
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("AssistantResponseFormat") {
element<String>("format", isOptional = true)
element<AssistantResponseType>("type", isOptional = true)
}

override fun serialize(encoder: Encoder, value: AssistantResponseFormat) {
val jsonEncoder = encoder as? kotlinx.serialization.json.JsonEncoder
?: throw SerializationException("This class can be saved only by Json")

if (value.format != null) {
jsonEncoder.encodeJsonElement(JsonPrimitive(value.format))
} else if (value.objectType != null) {
val jsonElement: JsonElement = JsonObject(mapOf("type" to JsonPrimitive(value.objectType.type)))
jsonEncoder.encodeJsonElement(jsonElement)
}
element<String>("type")
element<JsonSchema>("json_schema", isOptional = true) // Only for "json_schema" type
}

override fun deserialize(decoder: Decoder): AssistantResponseFormat {
Expand All @@ -63,14 +85,63 @@ public data class AssistantResponseFormat(
val jsonElement = jsonDecoder.decodeJsonElement()
return when {
jsonElement is JsonPrimitive && jsonElement.isString -> {
AssistantResponseFormat(format = jsonElement.content)
AssistantResponseFormat(type = jsonElement.content)
}
jsonElement is JsonObject && "type" in jsonElement -> {
val type = jsonElement["type"]!!.jsonPrimitive.content
AssistantResponseFormat(objectType = AssistantResponseType(type))
when (type) {
"json_schema" -> {
val schemaObject = jsonElement["json_schema"]?.jsonObject
val name = schemaObject?.get("name")?.jsonPrimitive?.content ?: ""
val description = schemaObject?.get("description")?.jsonPrimitive?.contentOrNull
val schema = schemaObject?.get("schema")?.jsonObject ?: JsonObject(emptyMap())
val strict = schemaObject?.get("strict")?.jsonPrimitive?.booleanOrNull
AssistantResponseFormat(
type = "json_schema",
jsonSchema = JsonSchema(name = name, description = description, schema = schema, strict = strict)
)
}
"json_object" -> AssistantResponseFormat(type = "json_object")
"auto" -> AssistantResponseFormat(type = "auto")
"text" -> AssistantResponseFormat(type = "text")
else -> throw SerializationException("Unknown response format type: $type")
}
}
else -> throw SerializationException("Unknown response format: $jsonElement")
}
}

override fun serialize(encoder: Encoder, value: AssistantResponseFormat) {
val jsonEncoder = encoder as? kotlinx.serialization.json.JsonEncoder
?: throw SerializationException("This class can be saved only by Json")

val jsonElement = when (value.type) {
"json_schema" -> {
JsonObject(
mapOf(
"type" to JsonPrimitive("json_schema"),
"json_schema" to JsonObject(
mapOf(
"name" to JsonPrimitive(value.jsonSchema?.name ?: ""),
"description" to JsonPrimitive(value.jsonSchema?.description ?: ""),
"schema" to (value.jsonSchema?.schema ?: JsonObject(emptyMap())),
"strict" to JsonPrimitive(value.jsonSchema?.strict ?: false)
)
)
)
)
}
"json_object" -> JsonObject(mapOf("type" to JsonPrimitive("json_object")))
"auto" -> JsonPrimitive("auto")
"text" -> JsonObject(mapOf("type" to JsonPrimitive("text")))
else -> throw SerializationException("Unsupported response format type: ${value.type}")
}
jsonEncoder.encodeJsonElement(jsonElement)
}

}
}

public fun JsonObject.Companion.buildJsonObject(block: JsonObjectBuilder.() -> Unit): JsonObject {
return kotlinx.serialization.json.buildJsonObject(block)
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public data class Run(
@SerialName("usage") public val usage: Usage? = null,

/**
* The Unix timestamp (in seconds) for when the run was completed.
* The sampling temperature used for this run. If not set, defaults to 1.
*/
@SerialName("temperature") val temperature: Double? = null,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.aallam.openai.sample.jvm

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.assistant.AssistantRequest
import com.aallam.openai.api.assistant.AssistantResponseFormat
import com.aallam.openai.api.assistant.AssistantTool
import com.aallam.openai.api.assistant.Function
import com.aallam.openai.api.chat.ToolCall
Expand All @@ -17,7 +18,10 @@ import com.aallam.openai.api.run.RunRequest
import com.aallam.openai.api.run.ToolOutput
import com.aallam.openai.client.OpenAI
import kotlinx.coroutines.delay
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.add
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlinx.serialization.json.putJsonArray
import kotlinx.serialization.json.putJsonObject
Expand All @@ -29,6 +33,36 @@ suspend fun assistantsFunctions(openAI: OpenAI) {
request = AssistantRequest(
name = "Math Tutor",
instructions = "You are a weather bot. Use the provided functions to answer questions.",
responseFormat = AssistantResponseFormat.JSON_SCHEMA(
name = "math_response",
strict = true,
schema = buildJsonObject {
put("type", "object")
putJsonObject("properties") {
putJsonObject("steps") {
put("type", "array")
putJsonObject("items") {
put("type", "object")
putJsonObject("properties") {
putJsonObject("explanation") {
put("type", "string")
}
putJsonObject("output") {
put("type", "string")
}
}
put("required", JsonArray(listOf(JsonPrimitive("explanation"), JsonPrimitive("output"))))
put("additionalProperties", false)
}
}
putJsonObject("final_answer") {
put("type", "string")
}
}
put("additionalProperties", false)
put("required", JsonArray(listOf(JsonPrimitive("steps"), JsonPrimitive("final_answer"))))
},
),
tools = listOf(
AssistantTool.FunctionTool(
function = Function(
Expand Down Expand Up @@ -74,7 +108,7 @@ suspend fun assistantsFunctions(openAI: OpenAI) {
)
)
),
model = ModelId("gpt-4-1106-preview")
model = ModelId("gpt-4o-mini")
)
)

Expand Down
Loading