Skip to content

Commit

Permalink
feat(chat): allow specifying JSON schema for chat completions (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewAcomb authored Nov 18, 2024
1 parent 593e327 commit 1f4cce1
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 2 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Unreleased

### Added
- **chat**: Add support for structured outputs (#397)

## 4.0.0-beta01
> Published 27 Oct 2024
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.aallam.openai.client

import com.aallam.openai.api.chat.*
import com.aallam.openai.api.chat.ChatResponseFormat.Companion.jsonSchema
import com.aallam.openai.api.model.ModelId
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.launchIn
Expand All @@ -9,6 +10,9 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlin.coroutines.cancellation.CancellationException
import kotlin.test.*

Expand Down Expand Up @@ -131,6 +135,62 @@ class TestChatCompletions : TestOpenAI() {
assertNotNull(answer.response)
}

@Test
fun jsonSchema() = test {
val schemaJson = JsonObject(mapOf(
"type" to JsonPrimitive("object"),
"properties" to JsonObject(mapOf(
"question" to JsonObject(mapOf(
"type" to JsonPrimitive("string"),
"description" to JsonPrimitive("The question that was asked")
)),
"response" to JsonObject(mapOf(
"type" to JsonPrimitive("string"),
"description" to JsonPrimitive("The answer to the question")
))
)),
"required" to JsonArray(listOf(
JsonPrimitive("question"),
JsonPrimitive("response")
))
))

val jsonSchema = JsonSchema(
name = "AnswerSchema",
schema = schemaJson,
strict = true
)

val request = chatCompletionRequest {
model = ModelId("gpt-4o-mini-2024-07-18")
responseFormat = jsonSchema(jsonSchema)
messages {
message {
role = ChatRole.System
content = "You are a helpful assistant.!"
}
message {
role = ChatRole.System
content = """All your answers should be a valid JSON
""".trimMargin()
}
message {
role = ChatRole.User
content = "Who won the world cup in 1998?"
}
}
}
val response = openAI.chatCompletion(request)
val content = response.choices.first().message.content.orEmpty()

@Serializable
data class Answer(val question: String? = null, val response: String? = null)

val answer = Json.decodeFromString<Answer>(content)
assertNotNull(answer.question)
assertNotNull(answer.response)
}

@Test
fun logprobs() = test {
val request = chatCompletionRequest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.aallam.openai.api.chat

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

/**
* An object specifying the format that the model must output.
Expand All @@ -11,9 +12,13 @@ public data class ChatResponseFormat(
/**
* Response format type.
*/
@SerialName("type") val type: String
) {
@SerialName("type") val type: String,

/**
* Optional JSON schema specification when type is "json_schema"
*/
@SerialName("json_schema") val jsonSchema: JsonSchema? = null
) {
public companion object {
/**
* JSON mode, which guarantees the message the model generates, is valid JSON.
Expand All @@ -24,5 +29,32 @@ public data class ChatResponseFormat(
* Default text mode.
*/
public val Text: ChatResponseFormat = ChatResponseFormat(type = "text")

/**
* Creates a JSON schema response format with the specified schema
*/
public fun jsonSchema(schema: JsonSchema): ChatResponseFormat =
ChatResponseFormat(type = "json_schema", jsonSchema = schema)
}
}

/**
* Specification for JSON schema response format
*/
@Serializable
public data class JsonSchema(
/**
* Optional name for the schema
*/
@SerialName("name") val name: String? = null,

/**
* The JSON schema specification
*/
@SerialName("schema") val schema: JsonObject,

/**
* Whether to enforce strict schema validation
*/
@SerialName("strict") val strict: Boolean = true
)

0 comments on commit 1f4cce1

Please sign in to comment.