Skip to content

Commit

Permalink
[Vertex AI] Refactor BlockReason as a struct and add new values
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Oct 9, 2024
1 parent 492e488 commit aaa9ef9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 23 deletions.
55 changes: 33 additions & 22 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,43 @@ public enum FinishReason: String, Sendable {
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct PromptFeedback: Sendable {
/// A type describing possible reasons to block a prompt.
public enum BlockReason: String, Sendable {
/// The block reason is unknown.
case unknown = "UNKNOWN"
public struct BlockReason: DecodableProtoEnum, Hashable, Sendable {
enum Kind: String {
case safety = "SAFETY"
case other = "OTHER"
case blocklist = "BLOCKLIST"
case prohibitedContent = "PROHIBITED_CONTENT"
}

/// The prompt was blocked because it was deemed unsafe.
case safety = "SAFETY"
public static var safety: BlockReason {
return self.init(kind: .safety)
}

/// All other block reasons.
case other = "OTHER"
public static var other: BlockReason {
return self.init(kind: .other)
}

/// The prompt was blocked because it contained terms from the terminology blocklist.
public static var blocklist: BlockReason {
return self.init(kind: .blocklist)
}

/// The prompt was blocked due to prohibited content.
public static var prohibitedContent: BlockReason {
return self.init(kind: .prohibitedContent)
}

/// Returns the raw string representation of the `BlockReason` value.
///
/// > Note: This value directly corresponds to the values in the [REST
/// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#BlockedReason).
public let rawValue: String

var unrecognizedValueMessageCode: VertexLog.MessageCode {
.generateContentResponseUnrecognizedBlockReason
}
}

/// The reason a prompt was blocked, if it was blocked.
Expand Down Expand Up @@ -350,23 +378,6 @@ extension FinishReason: Decodable {
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension PromptFeedback.BlockReason: Decodable {
public init(from decoder: Decoder) throws {
let value = try decoder.singleValueContainer().decode(String.self)
guard let decodedBlockReason = PromptFeedback.BlockReason(rawValue: value) else {
VertexLog.error(
code: .generateContentResponseUnrecognizedBlockReason,
"Unrecognized BlockReason with value \"\(value)\"."
)
self = .unknown
return
}

self = decodedBlockReason
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension PromptFeedback: Decodable {
enum CodingKeys: CodingKey {
Expand Down
3 changes: 2 additions & 1 deletion FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -626,13 +626,14 @@ final class GenerativeModelTests: XCTestCase {
forResource: "unary-failure-unknown-enum-prompt-blocked",
withExtension: "json"
)
let unknownBlockReason = PromptFeedback.BlockReason(rawValue: "FAKE_NEW_BLOCK_REASON")

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GenerateContentError.promptBlocked(response) {
let promptFeedback = try XCTUnwrap(response.promptFeedback)
XCTAssertEqual(promptFeedback.blockReason, .unknown)
XCTAssertEqual(promptFeedback.blockReason, unknownBlockReason)
} catch {
XCTFail("Should throw a promptBlocked")
}
Expand Down

0 comments on commit aaa9ef9

Please sign in to comment.