1
1
package com.openai.helpers
2
2
3
- import com.openai.core.JsonField
4
3
import com.openai.core.JsonNull
5
4
import com.openai.core.JsonValue
6
5
import com.openai.errors.OpenAIInvalidDataException
@@ -68,6 +67,27 @@ class ChatCompletionAccumulator private constructor() {
68
67
*/
69
68
private val logprobsBuilders = mutableMapOf<Long , ChatCompletion .Choice .Logprobs .Builder >()
70
69
70
+ /* *
71
+ * The accumulated tool call builders for each message. The "outer" keys correspond to the
72
+ * indexes in [messageBuilders] (the choice index). The "inner" keys correspond to the position
73
+ * of each tool call in the message's list of tool calls (the tool call index).
74
+ */
75
+ private val toolCallBuilders =
76
+ mutableMapOf<Long , MutableMap <Long , ChatCompletionMessageToolCall .Builder >>()
77
+
78
+ /* *
79
+ * The accumulated tool call function builders for the tool call builders of each message. The
80
+ * entries correspond to those in [toolCallBuilders].
81
+ */
82
+ private val toolCallFunctionBuilders =
83
+ mutableMapOf<Long , MutableMap <Long , ChatCompletionMessageToolCall .Function .Builder >>()
84
+
85
+ /* *
86
+ * The accumulated tool call function arguments that will be set on the function builders when
87
+ * completed. The entries correspond to those in [toolCallFunctionBuilders].
88
+ */
89
+ private val toolCallFunctionArgs = mutableMapOf<Long , MutableMap <Long , String >>()
90
+
71
91
/* *
72
92
* The finished status of each of the `n` completions. When a chunk with a `finishReason` is
73
93
* encountered, its index is recorded against a `true` value. When a `true` has been recorded
@@ -80,27 +100,6 @@ class ChatCompletionAccumulator private constructor() {
80
100
companion object {
81
101
@JvmStatic fun create () = ChatCompletionAccumulator ()
82
102
83
- @JvmSynthetic
84
- internal fun convertToolCall (chunkToolCall : ChatCompletionChunk .Choice .Delta .ToolCall ) =
85
- ChatCompletionMessageToolCall .builder()
86
- .id(chunkToolCall._id ())
87
- .function(convertToolCallFunction(chunkToolCall._function ()))
88
- .additionalProperties(chunkToolCall._additionalProperties ())
89
- // Let the `type` default to "function".
90
- .build()
91
-
92
- @JvmSynthetic
93
- internal fun convertToolCallFunction (
94
- chunkToolCallFunction : JsonField <ChatCompletionChunk .Choice .Delta .ToolCall .Function >
95
- ): JsonField <ChatCompletionMessageToolCall .Function > =
96
- chunkToolCallFunction.map { function ->
97
- ChatCompletionMessageToolCall .Function .builder()
98
- .name(function._name ())
99
- .arguments(function._arguments ())
100
- .additionalProperties(function._additionalProperties ())
101
- .build()
102
- }
103
-
104
103
@JvmSynthetic
105
104
internal fun convertFunctionCall (
106
105
chunkFunctionCall : ChatCompletionChunk .Choice .Delta .FunctionCall
@@ -253,14 +252,48 @@ class ChatCompletionAccumulator private constructor() {
253
252
delta.role().ifPresent { messageBuilder.role(JsonValue .from(it.asString())) }
254
253
delta.functionCall().ifPresent { messageBuilder.functionCall(convertFunctionCall(it)) }
255
254
256
- // Add the `ToolCall` objects in the order in which they are encountered.
257
- // (`...Delta.ToolCall.index` is not documented, so it is ignored here.)
258
- delta.toolCalls().ifPresent { it.map { messageBuilder.addToolCall(convertToolCall(it)) } }
255
+ delta.toolCalls().ifPresent {
256
+ it.map { deltaToolCall ->
257
+ // The first chunk delta will carry the tool call ID and the function name. Later
258
+ // deltas will carry only fragments of the function arguments, but the tool call
259
+ // index will identify the function to which those argument fragments belong.
260
+ val messageToolCallBuilders = toolCallBuilders.getOrPut(index) { mutableMapOf () }
261
+
262
+ messageToolCallBuilders.getOrPut(deltaToolCall.index()) {
263
+ ChatCompletionMessageToolCall .builder()
264
+ .id(deltaToolCall._id ())
265
+ .additionalProperties(deltaToolCall._additionalProperties ())
266
+ // Must wait until the `function` is accumulated and built before adding it to
267
+ // the tool call later when `buildChoices` is called.
268
+ }
269
+
270
+ val messageToolCallFunctionBuilders =
271
+ toolCallFunctionBuilders.getOrPut(index) { mutableMapOf () }
272
+
273
+ messageToolCallFunctionBuilders.getOrPut(deltaToolCall.index()) {
274
+ ChatCompletionMessageToolCall .Function .builder()
275
+ .name(ensureFunction(deltaToolCall.function())._name ())
276
+ .additionalProperties(deltaToolCall._additionalProperties ())
277
+ }
278
+
279
+ val messageToolCallFunctionArgs =
280
+ toolCallFunctionArgs.getOrPut(index) { mutableMapOf () }
281
+
282
+ messageToolCallFunctionArgs[deltaToolCall.index()] =
283
+ (messageToolCallFunctionArgs[deltaToolCall.index()] ? : " " ) +
284
+ (ensureFunction(deltaToolCall.function()).arguments().getOrNull() ? : " " )
285
+ }
286
+ }
287
+
259
288
messageBuilder.putAllAdditionalProperties(delta._additionalProperties ())
260
289
}
261
290
262
- @JvmSynthetic
263
- internal fun buildChoices () =
291
+ private fun ensureFunction (
292
+ function : Optional <ChatCompletionChunk .Choice .Delta .ToolCall .Function >
293
+ ): ChatCompletionChunk .Choice .Delta .ToolCall .Function =
294
+ function.orElseThrow { OpenAIInvalidDataException (" Tool call chunk missing function." ) }
295
+
296
+ private fun buildChoices () =
264
297
choiceBuilders.entries
265
298
.sortedBy { it.key }
266
299
.map {
@@ -270,13 +303,41 @@ class ChatCompletionAccumulator private constructor() {
270
303
.build()
271
304
}
272
305
273
- @JvmSynthetic
274
- internal fun buildMessage (index : Long ) =
306
+ private fun buildMessage (index : Long ) =
275
307
messageBuilders
276
308
.getOrElse(index) {
277
309
throw OpenAIInvalidDataException (" Missing message for index $index ." )
278
310
}
279
311
.content(messageContents[index])
280
312
.refusal(messageRefusals[index])
313
+ .toolCalls(buildToolCalls(index))
281
314
.build()
315
+
316
+ private fun buildToolCalls (index : Long ): List <ChatCompletionMessageToolCall > =
317
+ // It is OK for a message not to have any tool calls; most will not and an empty list will
318
+ // be returned. An entry (if it exists) will be a collection of tool call builders and each
319
+ // has a function that needs to be set.
320
+ toolCallBuilders[index]
321
+ ?.entries
322
+ ?.sortedBy { it.key }
323
+ ?.map { messageToolCallBuilderEntry ->
324
+ messageToolCallBuilderEntry.value
325
+ .function(buildFunction(index, messageToolCallBuilderEntry.key))
326
+ .build()
327
+ } ? : listOf ()
328
+
329
+ private fun buildFunction (index : Long , toolCallIndex : Long ) =
330
+ // Every tool call is expected to have a function with arguments.
331
+ toolCallFunctionBuilders[index]
332
+ ?.get(toolCallIndex)
333
+ ?.arguments(
334
+ toolCallFunctionArgs[index]?.get(toolCallIndex)
335
+ ? : throw OpenAIInvalidDataException (
336
+ " Missing function arguments for index $index .$toolCallIndex ."
337
+ )
338
+ )
339
+ ?.build()
340
+ ? : throw OpenAIInvalidDataException (
341
+ " Missing function builder for index $index .$toolCallIndex ."
342
+ )
282
343
}
0 commit comments