diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 01ab8b96c02..b7b8234001a 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -20,6 +20,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -587,35 +588,27 @@ protected List responseCandidateToGeneration(Candidate candidate) { .finishReason(candidateFinishReason.name()) .build(); - boolean isFunctionCall = candidate.getContent().getPartsList().stream().allMatch(Part::hasFunctionCall); - - if (isFunctionCall) { - List assistantToolCalls = candidate.getContent() - .getPartsList() - .stream() - .filter(part -> part.hasFunctionCall()) - .map(part -> { - FunctionCall functionCall = part.getFunctionCall(); - var functionName = functionCall.getName(); - String functionArguments = structToJson(functionCall.getArgs()); - return new AssistantMessage.ToolCall("", "function", functionName, functionArguments); - }) - .toList(); + var assistantToolCalls = candidate.getContent() + .getPartsList() + .stream() + .filter(Part::hasFunctionCall) + .map(part -> { + FunctionCall functionCall = part.getFunctionCall(); + var functionName = functionCall.getName(); + String functionArguments = structToJson(functionCall.getArgs()); + return new AssistantMessage.ToolCall("", "function", functionName, functionArguments); + }) + .toList(); - AssistantMessage assistantMessage = new AssistantMessage("", messageMetadata, assistantToolCalls); + var text = candidate.getContent() + .getPartsList() + .stream() + .filter(Part::hasText) + .map(Part::getText) + .collect(Collectors.joining(System.lineSeparator())); - return List.of(new Generation(assistantMessage, chatGenerationMetadata)); - } - else { - List generations = candidate.getContent() - .getPartsList() - .stream() - .map(part -> new AssistantMessage(part.getText(), messageMetadata)) - .map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata)) - .toList(); - - return generations; - } + return List.of(new Generation(new AssistantMessage(text, messageMetadata, assistantToolCalls), + chatGenerationMetadata)); } private ChatResponseMetadata toChatResponseMetadata(Usage usage) { diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java index 2f79a8b948a..1a0b2942ad2 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java @@ -96,6 +96,30 @@ public void functionCallExplicitOpenApiSchema() { assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); } + @Test + public void functionCallModelReturnsMixedTextAndFunctionCallParts() { + UserMessage userMessage = new UserMessage( + "What can you tell me about the temperature in San Francisco, Paris and in Tokyo? Return the temperature in Celsius. Expose your thinking process."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = VertexAiGeminiChatOptions.builder() + .model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH) + .toolCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) + .description("Get the current weather in a given location.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse chatResponse = this.chatModel.call(new Prompt(messages, promptOptions)); + + assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15"); + + assertThat(chatResponse.getMetadata()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage()).isNotNull(); + assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(550); + } + @Test public void functionCallTestInferredOpenApiSchema() {