[Vertex AI] Add HarmSeverity enum and SafetyRating properties
andrewheard committed Oct 11, 2024
1 parent 3eaa04d commit e1e9796
Showing 5 changed files with 243 additions and 37 deletions.
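
In short, this commit adds `probabilityScore`, `severity`, `severityScore`, and `blocked` properties to `SafetyRating`, introduces a `SafetyRating.HarmSeverity` proto enum, and gives `SafetyRating` a hand-written `Decodable` conformance that tolerates omitted fields. As a rough usage sketch — not part of this commit; the helper name, prompt text, and the single-string `generateContent` call are illustrative assumptions — client code could surface the new fields like this:

import FirebaseVertexAI

// Hypothetical helper: logs every safety rating on every candidate of a response.
func logSafetyRatings(using model: GenerativeModel) async throws {
  let response = try await model.generateContent("Tell me a story about a friendly robot.")
  for candidate in response.candidates {
    for rating in candidate.safetyRatings {
      // probabilityScore, severity, severityScore, and blocked are new in this commit.
      print("\(rating.category): probability \(rating.probability) (\(rating.probabilityScore)), " +
        "severity \(rating.severity.rawValue) (\(rating.severityScore)), blocked: \(rating.blocked)")
    }
  }
}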
72 changes: 64 additions & 8 deletions FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
@@ -168,10 +168,38 @@ struct ErrorDetailsView: View {
Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
"""),
safetyRatings: [
SafetyRating(category: .dangerousContent, probability: .high),
SafetyRating(category: .harassment, probability: .low),
SafetyRating(category: .hateSpeech, probability: .low),
SafetyRating(category: .sexuallyExplicit, probability: .low),
SafetyRating(
category: .dangerousContent,
probability: .medium,
probabilityScore: 0.8,
severity: .medium,
severityScore: 0.9,
blocked: false
),
SafetyRating(
category: .harassment,
probability: .low,
probabilityScore: 0.5,
severity: .low,
severityScore: 0.6,
blocked: false
),
SafetyRating(
category: .hateSpeech,
probability: .low,
probabilityScore: 0.3,
severity: .medium,
severityScore: 0.2,
blocked: false
),
SafetyRating(
category: .sexuallyExplicit,
probability: .low,
probabilityScore: 0.2,
severity: .negligible,
severityScore: 0.5,
blocked: false
),
],
finishReason: FinishReason.maxTokens,
citationMetadata: nil),
@@ -190,10 +218,38 @@ struct ErrorDetailsView: View {
Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
"""),
safetyRatings: [
SafetyRating(category: .dangerousContent, probability: .high),
SafetyRating(category: .harassment, probability: .low),
SafetyRating(category: .hateSpeech, probability: .low),
SafetyRating(category: .sexuallyExplicit, probability: .low),
SafetyRating(
category: .dangerousContent,
probability: .low,
probabilityScore: 0.8,
severity: .medium,
severityScore: 0.9,
blocked: false
),
SafetyRating(
category: .harassment,
probability: .low,
probabilityScore: 0.5,
severity: .low,
severityScore: 0.6,
blocked: false
),
SafetyRating(
category: .hateSpeech,
probability: .low,
probabilityScore: 0.3,
severity: .medium,
severityScore: 0.2,
blocked: false
),
SafetyRating(
category: .sexuallyExplicit,
probability: .low,
probabilityScore: 0.2,
severity: .negligible,
severityScore: 0.5,
blocked: false
),
],
finishReason: FinishReason.other,
citationMetadata: nil),
64 changes: 48 additions & 16 deletions FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
@@ -36,22 +36,54 @@ struct ErrorView: View {
#Preview {
NavigationView {
let errorPromptBlocked = GenerateContentError.promptBlocked(
response: GenerateContentResponse(candidates: [
CandidateResponse(content: ModelContent(role: "model", parts: [
"""
A _hypothetical_ model response.
Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
""",
]),
safetyRatings: [
SafetyRating(category: .dangerousContent, probability: .high),
SafetyRating(category: .harassment, probability: .low),
SafetyRating(category: .hateSpeech, probability: .low),
SafetyRating(category: .sexuallyExplicit, probability: .low),
],
finishReason: FinishReason.other,
citationMetadata: nil),
])
response: GenerateContentResponse(
candidates: [
CandidateResponse(
content: ModelContent(role: "model", parts: [
"""
A _hypothetical_ model response.
Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
""",
]),
safetyRatings: [
SafetyRating(
category: .dangerousContent,
probability: .high,
probabilityScore: 0.8,
severity: .medium,
severityScore: 0.9,
blocked: true
),
SafetyRating(
category: .harassment,
probability: .low,
probabilityScore: 0.5,
severity: .low,
severityScore: 0.6,
blocked: false
),
SafetyRating(
category: .hateSpeech,
probability: .low,
probabilityScore: 0.3,
severity: .medium,
severityScore: 0.2,
blocked: false
),
SafetyRating(
category: .sexuallyExplicit,
probability: .low,
probabilityScore: 0.2,
severity: .negligible,
severityScore: 0.5,
blocked: false
),
],
finishReason: FinishReason.other,
citationMetadata: nil
),
]
)
)
List {
MessageView(message: ChatMessage.samples[0])
71 changes: 69 additions & 2 deletions FirebaseVertexAI/Sources/Safety.swift
@@ -31,11 +31,28 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
/// > Important: This does not indicate the severity of harm for a piece of content.
public let probability: HarmProbability

public let probabilityScore: Float

public let severity: HarmSeverity

public let severityScore: Float

public let blocked: Bool

/// Initializes a new `SafetyRating` instance with the given category, probability, severity,
/// scores, and blocked status. Use this initializer for SwiftUI previews or tests.
public init(category: HarmCategory, probability: HarmProbability) {
public init(category: HarmCategory,
probability: HarmProbability,
probabilityScore: Float,
severity: HarmSeverity,
severityScore: Float,
blocked: Bool) {
self.category = category
self.probability = probability
self.probabilityScore = probabilityScore
self.severity = severity
self.severityScore = severityScore
self.blocked = blocked
}

/// The probability that a given model output falls under a harmful content category.
@@ -74,6 +91,32 @@ public struct SafetyRating: Equatable, Hashable, Sendable {
static let unrecognizedValueMessageCode =
VertexLog.MessageCode.generateContentResponseUnrecognizedHarmProbability
}

public struct HarmSeverity: DecodableProtoEnum, Hashable, Sendable {
enum Kind: String {
case negligible = "HARM_SEVERITY_NEGLIGIBLE"
case low = "HARM_SEVERITY_LOW"
case medium = "HARM_SEVERITY_MEDIUM"
case high = "HARM_SEVERITY_HIGH"
}

public static let negligible = HarmSeverity(kind: .negligible)

public static let low = HarmSeverity(kind: .low)

public static let medium = HarmSeverity(kind: .medium)

public static let high = HarmSeverity(kind: .high)

/// Returns the raw string representation of the `HarmSeverity` value.
///
/// > Note: This value directly corresponds to the values in the [REST
/// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#HarmSeverity).
public let rawValue: String

static let unrecognizedValueMessageCode =
VertexLog.MessageCode.generateContentResponseUnrecognizedHarmSeverity
}
}

/// A type used to specify a threshold for harmful content, beyond which the model will return a
@@ -164,7 +207,31 @@ public struct HarmCategory: CodableProtoEnum, Hashable, Sendable {
// MARK: - Codable Conformances

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension SafetyRating: Decodable {}
extension SafetyRating: Decodable {
enum CodingKeys: CodingKey {
case category
case probability
case probabilityScore
case severity
case severityScore
case blocked
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
category = try container.decode(HarmCategory.self, forKey: .category)
probability = try container.decode(HarmProbability.self, forKey: .probability)

// The following 3 fields are only omitted in our test data.
probabilityScore = try container.decodeIfPresent(Float.self, forKey: .probabilityScore) ?? 0.0
severity = try container.decodeIfPresent(HarmSeverity.self, forKey: .severity) ??
HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED")
severityScore = try container.decodeIfPresent(Float.self, forKey: .severityScore) ?? 0.0

// The blocked field is only included when true.
blocked = try container.decodeIfPresent(Bool.self, forKey: .blocked) ?? false
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension SafetySetting.HarmBlockThreshold: Encodable {}
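
The custom `Decodable` conformance above defaults `probabilityScore` and `severityScore` to 0, `severity` to `HARM_SEVERITY_UNSPECIFIED`, and `blocked` to `false` whenever the backend omits those fields. A minimal decoding sketch, assuming a payload shaped like the REST response (the JSON below and its values are made up for illustration):

import FirebaseVertexAI
import Foundation

// Illustrative payload; `blocked` is intentionally omitted to show the default.
let json = """
{
  "category": "HARM_CATEGORY_HARASSMENT",
  "probability": "LOW",
  "probabilityScore": 0.5,
  "severity": "HARM_SEVERITY_LOW",
  "severityScore": 0.6
}
""".data(using: .utf8)!

do {
  let rating = try JSONDecoder().decode(SafetyRating.self, from: json)
  // The missing `blocked` field decodes to false.
  print(rating.severity.rawValue, rating.blocked)
} catch {
  print("Decoding failed: \(error)")
}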
1 change: 1 addition & 0 deletions FirebaseVertexAI/Sources/VertexLog.swift
@@ -49,6 +49,7 @@ enum VertexLog {
case generateContentResponseUnrecognizedBlockThreshold = 3004
case generateContentResponseUnrecognizedHarmProbability = 3005
case generateContentResponseUnrecognizedHarmCategory = 3006
case generateContentResponseUnrecognizedHarmSeverity = 3007

// SDK State Errors
case generateContentResponseNoCandidates = 4000
72 changes: 61 additions & 11 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
@@ -23,10 +23,38 @@ import XCTest
final class GenerativeModelTests: XCTestCase {
let testPrompt = "What sorts of questions can I ask you?"
let safetyRatingsNegligible: [SafetyRating] = [
.init(category: .sexuallyExplicit, probability: .negligible),
.init(category: .hateSpeech, probability: .negligible),
.init(category: .harassment, probability: .negligible),
.init(category: .dangerousContent, probability: .negligible),
.init(
category: .sexuallyExplicit,
probability: .negligible,
probabilityScore: 0.1431877,
severity: .negligible,
severityScore: 0.11027937,
blocked: false
),
.init(
category: .hateSpeech,
probability: .negligible,
probabilityScore: 0.029035643,
severity: .negligible,
severityScore: 0.05613278,
blocked: false
),
.init(
category: .harassment,
probability: .negligible,
probabilityScore: 0.087252244,
severity: .negligible,
severityScore: 0.04509957,
blocked: false
),
.init(
category: .dangerousContent,
probability: .negligible,
probabilityScore: 0.2641685,
severity: .negligible,
severityScore: 0.082253955,
blocked: false
),
].sorted()
let testModelResourceName =
"projects/test-project-id/locations/test-location/publishers/google/models/test-model"
@@ -69,7 +97,7 @@ final class GenerativeModelTests: XCTestCase {
let candidate = try XCTUnwrap(response.candidates.first)
let finishReason = try XCTUnwrap(candidate.finishReason)
XCTAssertEqual(finishReason, .stop)
XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible)
XCTAssertEqual(candidate.safetyRatings.count, 4)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
let partText = try XCTUnwrap(part as? TextPart).text
@@ -148,25 +176,43 @@ final class GenerativeModelTests: XCTestCase {
let candidate = try XCTUnwrap(response.candidates.first)
let finishReason = try XCTUnwrap(candidate.finishReason)
XCTAssertEqual(finishReason, .stop)
XCTAssertEqual(candidate.safetyRatings.sorted(), safetyRatingsNegligible)
XCTAssertEqual(candidate.safetyRatings.count, 4)
XCTAssertEqual(candidate.content.parts.count, 1)
let part = try XCTUnwrap(candidate.content.parts.first)
let textPart = try XCTUnwrap(part as? TextPart)
XCTAssertTrue(textPart.text.hasPrefix("Google"))
XCTAssertEqual(response.text, textPart.text)
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertNil(promptFeedback.blockReason)
XCTAssertEqual(promptFeedback.safetyRatings.sorted(), safetyRatingsNegligible)
XCTAssertEqual(promptFeedback.safetyRatings.count, 4)
}

func testGenerateContent_success_unknownEnum_safetyRatings() async throws {
let expectedSafetyRatings = [
SafetyRating(category: .harassment, probability: .medium),
SafetyRating(
category: .harassment,
probability: .medium,
probabilityScore: 0.0,
severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
severityScore: 0.0,
blocked: false
),
SafetyRating(
category: .dangerousContent,
probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY")
probability: SafetyRating.HarmProbability(rawValue: "FAKE_NEW_HARM_PROBABILITY"),
probabilityScore: 0.0,
severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
severityScore: 0.0,
blocked: false
),
SafetyRating(
category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"),
probability: .high,
probabilityScore: 0.0,
severity: .init(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
severityScore: 0.0,
blocked: false
),
SafetyRating(category: HarmCategory(rawValue: "FAKE_NEW_HARM_CATEGORY"), probability: .high),
]
MockURLProtocol
.requestHandler = try httpRequestHandler(
@@ -930,7 +976,11 @@ final class GenerativeModelTests: XCTestCase {
)
let unknownSafetyRating = SafetyRating(
category: HarmCategory(rawValue: "HARM_CATEGORY_DANGEROUS_CONTENT_NEW_ENUM"),
probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM")
probability: SafetyRating.HarmProbability(rawValue: "NEGLIGIBLE_UNKNOWN_ENUM"),
probabilityScore: 0.0,
severity: SafetyRating.HarmSeverity(rawValue: "HARM_SEVERITY_UNSPECIFIED"),
severityScore: 0.0,
blocked: false
)

var foundUnknownSafetyRating = false
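
Because `HarmSeverity` — like `HarmCategory` and `HarmProbability` — is a raw-value proto struct rather than a frozen enum, unrecognized backend values decode into a plain `rawValue` instead of failing, which is exactly what the unknown-enum tests above exercise. A hedged sketch of how calling code might branch on the new severity data (the `describe` helper and its thresholds are assumptions, not SDK API):

import FirebaseVertexAI

// Hypothetical triage helper for a single SafetyRating from a response candidate.
func describe(_ rating: SafetyRating) -> String {
  if rating.blocked || rating.severity == .high {
    return "Flagged: \(rating.category) (severity score \(rating.severityScore))"
  } else if rating.severity == .medium {
    return "Borderline: \(rating.category) (severity score \(rating.severityScore))"
  } else {
    // Covers .low, .negligible, and any unrecognized severity values.
    return "OK: \(rating.category) [\(rating.severity.rawValue)]"
  }
}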
