Skip to content

Commit

Permalink
[Vertex AI] Make text computed property handle mixed-parts responses (
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed May 13, 2024
1 parent 849680a commit 2ba8537
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 2 deletions.
10 changes: 8 additions & 2 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,17 @@ public struct GenerateContentResponse {
Logging.default.error("Could not get text from a response that had no candidates.")
return nil
}
guard let text = candidate.content.parts.first?.text else {
let textValues: [String] = candidate.content.parts.compactMap { part in
guard case let .text(text) = part else {
return nil
}
return text
}
guard textValues.count > 0 else {
Logging.default.error("Could not get a text part from the first candidate.")
return nil
}
return text
return textValues.joined(separator: " ")
}

/// Returns function calls found in any `Part`s of the first candidate of the response, if any.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"candidates": [
{
"content": {
"parts": [
{
"text": "The sum of [1, 2,"
},
{
"functionCall": {
"name": "sum",
"args": {
"y": 1,
"x": 2
}
}
},
{
"text": "3] is"
},
{
"functionCall": {
"name": "sum",
"args": {
"y": 3,
"x": 3
}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "sum",
"args": {
"y": 1,
"x": 2
}
}
},
{
"functionCall": {
"name": "sum",
"args": {
"y": 3,
"x": 4
}
}
},
{
"functionCall": {
"name": "sum",
"args": {
"y": 5,
"x": 6
}
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0
}
]
}
34 changes: 34 additions & 0 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,40 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.functionCalls, [functionCall])
}

func testGenerateContent_success_functionCall_parallelCalls() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-function-call-parallel-calls",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertEqual(candidate.content.parts.count, 3)
let functionCalls = response.functionCalls
XCTAssertEqual(functionCalls.count, 3)
}

func testGenerateContent_success_functionCall_mixedContent() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-success-function-call-mixed-content",
withExtension: "json"
)

let response = try await model.generateContent(testPrompt)

XCTAssertEqual(response.candidates.count, 1)
let candidate = try XCTUnwrap(response.candidates.first)
XCTAssertEqual(candidate.content.parts.count, 4)
let functionCalls = response.functionCalls
XCTAssertEqual(functionCalls.count, 2)
let text = try XCTUnwrap(response.text)
XCTAssertEqual(text, "The sum of [1, 2, 3] is")
}

func testGenerateContent_appCheck_validToken() async throws {
let appCheckToken = "test-valid-token"
model = GenerativeModel(
Expand Down

0 comments on commit 2ba8537

Please sign in to comment.