Skip to content

Commit

Permalink
Add responseMIMEType to GenerationConfig (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored May 6, 2024
1 parent fcc1084 commit 010fb95
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 8 deletions.
24 changes: 17 additions & 7 deletions Sources/GoogleAI/GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,26 @@ public struct GenerationConfig {
/// The stop sequence will not be included as part of the response.
public let stopSequences: [String]?

/// Output response MIME type of the generated candidate text.
///
/// Supported MIME types:
/// - `text/plain`: Text output; the default behavior if unspecified.
/// - `application/json`: JSON response in the candidates.
public let responseMIMEType: String?

/// Creates a new `GenerationConfig` value.
///
/// - Parameter temperature: See ``temperature``
/// - Parameter topP: See ``topP``
/// - Parameter topK: See ``topK``
/// - Parameter candidateCount: See ``candidateCount``
/// - Parameter maxOutputTokens: See ``maxOutputTokens``
/// - Parameter stopSequences: See ``stopSequences``
/// - Parameters:
/// - temperature: See ``temperature``.
/// - topP: See ``topP``.
/// - topK: See ``topK``.
/// - candidateCount: See ``candidateCount``.
/// - maxOutputTokens: See ``maxOutputTokens``.
/// - stopSequences: See ``stopSequences``.
/// - responseMIMEType: See ``responseMIMEType``.
public init(temperature: Float? = nil, topP: Float? = nil, topK: Int? = nil,
candidateCount: Int? = nil, maxOutputTokens: Int? = nil,
stopSequences: [String]? = nil) {
stopSequences: [String]? = nil, responseMIMEType: String? = nil) {
// Explicit init because otherwise if we re-arrange the above variables it changes the API
// surface.
self.temperature = temperature
Expand All @@ -82,6 +91,7 @@ public struct GenerationConfig {
self.candidateCount = candidateCount
self.maxOutputTokens = maxOutputTokens
self.stopSequences = stopSequences
self.responseMIMEType = responseMIMEType
}
}

Expand Down
94 changes: 94 additions & 0 deletions Tests/GoogleAITests/GenerationConfigTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import GoogleGenerativeAI
import XCTest

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
final class GenerationConfigTests: XCTestCase {
let encoder = JSONEncoder()

override func setUp() {
encoder.outputFormatting = .init(
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
)
}

// MARK: GenerationConfig Encoding

func testEncodeGenerationConfig_default() throws {
let generationConfig = GenerationConfig()

let jsonData = try encoder.encode(generationConfig)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
}
""")
}

func testEncodeGenerationConfig_allOptions() throws {
let temperature: Float = 0.5
let topP: Float = 0.75
let topK = 40
let candidateCount = 2
let maxOutputTokens = 256
let stopSequences = ["END", "DONE"]
let responseMIMEType = "text/plain"
let generationConfig = GenerationConfig(
temperature: temperature,
topP: topP,
topK: topK,
candidateCount: candidateCount,
maxOutputTokens: maxOutputTokens,
stopSequences: stopSequences,
responseMIMEType: responseMIMEType
)

let jsonData = try encoder.encode(generationConfig)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"candidateCount" : \(candidateCount),
"maxOutputTokens" : \(maxOutputTokens),
"responseMIMEType" : "\(responseMIMEType)",
"stopSequences" : [
"END",
"DONE"
],
"temperature" : \(temperature),
"topK" : \(topK),
"topP" : \(topP)
}
""")
}

func testEncodeGenerationConfig_responseMIMEType() throws {
let mimeType = "image/jpeg"
let generationConfig = GenerationConfig(responseMIMEType: mimeType)

let jsonData = try encoder.encode(generationConfig)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"responseMIMEType" : "\(mimeType)"
}
""")
}
}
3 changes: 2 additions & 1 deletion Tests/GoogleAITests/GoogleAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ final class GoogleGenerativeAITests: XCTestCase {
topK: 16,
candidateCount: 4,
maxOutputTokens: 256,
stopSequences: ["..."])
stopSequences: ["..."],
responseMIMEType: "text/plain")
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
let systemInstruction = ModelContent(role: "system", parts: [.text("Talk like a pirate.")])

Expand Down

0 comments on commit 010fb95

Please sign in to comment.