From 2be246b8bcc7d1c75ad52d586f57ba96a13cc217 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 8 May 2024 14:37:37 -0400 Subject: [PATCH] [Vertex AI] Add responseMIMEType to GenerationConfig --- .../Sources/GenerationConfig.swift | 10 +- .../Tests/Unit/GenerationConfigTests.swift | 94 +++++++++++++++++++ .../Tests/Unit/VertexAIAPITests.swift | 3 +- 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift diff --git a/FirebaseVertexAI/Sources/GenerationConfig.swift b/FirebaseVertexAI/Sources/GenerationConfig.swift index 18bf2bd1a52..ee527700d9a 100644 --- a/FirebaseVertexAI/Sources/GenerationConfig.swift +++ b/FirebaseVertexAI/Sources/GenerationConfig.swift @@ -63,6 +63,13 @@ public struct GenerationConfig { /// The stop sequence will not be included as part of the response. public let stopSequences: [String]? + /// Output response MIME type of the generated candidate text. + /// + /// Supported MIME types: + /// - `text/plain`: Text output; the default behavior if unspecified. + /// - `application/json`: JSON response in the candidates. + public let responseMIMEType: String? + /// Creates a new `GenerationConfig` value. /// /// - Parameter temperature: See ``temperature`` @@ -73,7 +80,7 @@ public struct GenerationConfig { /// - Parameter stopSequences: See ``stopSequences`` public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil, candidateCount: Int? = nil, maxOutputTokens: Int? = nil, - stopSequences: [String]? = nil) { + stopSequences: [String]? = nil, responseMIMEType: String? = nil) { // Explicit init because otherwise if we re-arrange the above variables it changes the API // surface. self.temperature = temperature @@ -82,6 +89,7 @@ public struct GenerationConfig { self.candidateCount = candidateCount self.maxOutputTokens = maxOutputTokens self.stopSequences = stopSequences + self.responseMIMEType = responseMIMEType } } diff --git a/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift b/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift new file mode 100644 index 00000000000..f925a70effa --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift @@ -0,0 +1,94 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import FirebaseVertexAI +import Foundation +import XCTest + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +final class GenerationConfigTests: XCTestCase { + let encoder = JSONEncoder() + + override func setUp() { + encoder.outputFormatting = .init( + arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes + ) + } + + // MARK: GenerationConfig Encoding + + func testEncodeGenerationConfig_default() throws { + let generationConfig = GenerationConfig() + + let jsonData = try encoder.encode(generationConfig) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + + } + """) + } + + func testEncodeGenerationConfig_allOptions() throws { + let temperature: Float = 0.5 + let topP: Float = 0.75 + let topK = 40 + let candidateCount = 2 + let maxOutputTokens = 256 + let stopSequences = ["END", "DONE"] + let responseMIMEType = "text/plain" + let generationConfig = GenerationConfig( + temperature: temperature, + topP: topP, + topK: topK, + candidateCount: candidateCount, + maxOutputTokens: maxOutputTokens, + stopSequences: stopSequences, + responseMIMEType: responseMIMEType + ) + + let jsonData = try encoder.encode(generationConfig) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "candidateCount" : \(candidateCount), + "maxOutputTokens" : \(maxOutputTokens), + "responseMIMEType" : "\(responseMIMEType)", + "stopSequences" : [ + "END", + "DONE" + ], + "temperature" : \(temperature), + "topK" : \(topK), + "topP" : \(topP) + } + """) + } + + func testEncodeGenerationConfig_responseMIMEType() throws { + let mimeType = "image/jpeg" + let generationConfig = GenerationConfig(responseMIMEType: mimeType) + + let jsonData = try encoder.encode(generationConfig) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "responseMIMEType" : "\(mimeType)" + } + """) + } +} diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift index 438247961b8..76a649d964d 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift @@ -30,7 +30,8 @@ final class VertexAIAPITests: XCTestCase { topK: 16, candidateCount: 4, maxOutputTokens: 256, - stopSequences: ["..."]) + stopSequences: ["..."], + responseMIMEType: "text/plain") let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] let systemInstruction = ModelContent(role: "system", parts: [.text("Talk like a pirate.")])