[Vertex AI] Refactor FinishReason as a struct and add new values #13860

Merged 3 commits on Oct 9, 2024

Changes from all commits
FirebaseVertexAI/CHANGELOG.md (8 changes: 4 additions & 4 deletions)

@@ -36,10 +36,10 @@
   as input. (#13767)
 - [changed] **Breaking Change**: All initializers for `ModelContent` now require
   the label `parts: `. (#13832)
-- [changed] **Breaking Change**: `HarmCategory` and `HarmProbability` are now
-  structs instead of enums types and the `unknown` cases have been removed; in a
-  `switch` statement, use the `default:` case to cover unknown or unhandled
-  categories or probabilities. (#13728, #13854)
+- [changed] **Breaking Change**: `HarmCategory`, `HarmProbability`, and
+  `FinishReason` are now structs instead of enums types and the `unknown` cases
+  have been removed; in a `switch` statement, use the `default:` case to cover
+  unknown or unhandled values. (#13728, #13854, #13860)
 - [changed] The default request timeout is now 180 seconds instead of the
   platform-default value of 60 seconds for a `URLRequest`; this timeout may
   still be customized in `RequestOptions`. (#13722)
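
As a migration aid for the breaking change described above, the following is a minimal sketch (not part of this PR's diff) of how calling code can switch over the struct-based `FinishReason`. The `describe` helper and its messages are hypothetical; only the `FinishReason` values and `rawValue` come from the SDK.

```swift
import FirebaseVertexAI

// Hypothetical helper: because `FinishReason` is now a struct, there is no
// `.unknown` case; a `default:` branch covers `.other` and any value the SDK
// does not yet model as a named constant.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
func describe(_ reason: FinishReason) -> String {
  switch reason {
  case .stop:
    return "Completed at a natural stop point."
  case .maxTokens:
    return "Hit the requested token limit."
  case .safety, .recitation, .blocklist, .prohibitedContent, .spii:
    return "Stopped by a content policy."
  case .malformedFunctionCall:
    return "The generated function call was invalid."
  default:
    // `rawValue` preserves the original REST API string for logging.
    return "Stopped for another reason: \(reason.rawValue)"
  }
}
```
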
FirebaseVertexAI/Sources/GenerateContentResponse.swift (87 changes: 60 additions & 27 deletions)

@@ -145,26 +145,76 @@ public struct Citation: Sendable {

 /// A value enumerating possible reasons for a model to terminate a content generation request.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-public enum FinishReason: String, Sendable {
-  /// The finish reason is unknown.
-  case unknown = "FINISH_REASON_UNKNOWN"
+public struct FinishReason: DecodableProtoEnum, Hashable, Sendable {
+  enum Kind: String {
+    case stop = "STOP"
+    case maxTokens = "MAX_TOKENS"
+    case safety = "SAFETY"
+    case recitation = "RECITATION"
+    case other = "OTHER"
+    case blocklist = "BLOCKLIST"
+    case prohibitedContent = "PROHIBITED_CONTENT"
+    case spii = "SPII"
+    case malformedFunctionCall = "MALFORMED_FUNCTION_CALL"
+  }

   /// Natural stop point of the model or provided stop sequence.
-  case stop = "STOP"
+  public static var stop: FinishReason {
+    return self.init(kind: .stop)
+  }

   /// The maximum number of tokens as specified in the request was reached.
-  case maxTokens = "MAX_TOKENS"
+  public static var maxTokens: FinishReason {
+    return self.init(kind: .maxTokens)
+  }

   /// The token generation was stopped because the response was flagged for safety reasons.
-  /// NOTE: When streaming, the Candidate.content will be empty if content filters blocked the
-  /// output.
-  case safety = "SAFETY"
+  ///
+  /// > NOTE: When streaming, the ``CandidateResponse/content`` will be empty if content filters
+  /// > blocked the output.
+  public static var safety: FinishReason {
+    return self.init(kind: .safety)
+  }

   /// The token generation was stopped because the response was flagged for unauthorized citations.
-  case recitation = "RECITATION"
+  public static var recitation: FinishReason {
+    return self.init(kind: .recitation)
+  }

   /// All other reasons that stopped token generation.
-  case other = "OTHER"
+  public static var other: FinishReason {
+    return self.init(kind: .other)
+  }
+
+  /// Token generation was stopped because the response contained forbidden terms.
+  public static var blocklist: FinishReason {
+    return self.init(kind: .blocklist)
+  }
+
+  /// Token generation was stopped because the response contained potentially prohibited content.
+  public static var prohibitedContent: FinishReason {
+    return self.init(kind: .prohibitedContent)
+  }
+
+  /// Token generation was stopped because of Sensitive Personally Identifiable Information (SPII).
+  public static var spii: FinishReason {
+    return self.init(kind: .spii)
+  }
+
+  /// Token generation was stopped because the function call generated by the model was invalid.
+  public static var malformedFunctionCall: FinishReason {
+    return self.init(kind: .malformedFunctionCall)
+  }
+
+  /// Returns the raw string representation of the `FinishReason` value.
+  ///
+  /// > Note: This value directly corresponds to the values in the [REST
+  /// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#FinishReason).
+  public let rawValue: String
+
+  var unrecognizedValueMessageCode: VertexLog.MessageCode {
+    .generateContentResponseUnrecognizedFinishReason
+  }
 }

 /// A metadata struct containing any feedback the model had on the prompt it was provided.
@@ -333,23 +383,6 @@ extension Citation: Decodable {
   }
 }

-@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-extension FinishReason: Decodable {
-  public init(from decoder: Decoder) throws {
-    let value = try decoder.singleValueContainer().decode(String.self)
-    guard let decodedFinishReason = FinishReason(rawValue: value) else {
-      VertexLog.error(
-        code: .generateContentResponseUnrecognizedFinishReason,
-        "Unrecognized FinishReason with value \"\(value)\"."
-      )
-      self = .unknown
-      return
-    }
-
-    self = decodedFinishReason
-  }
-}
-
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
 extension PromptFeedback.BlockReason: Decodable {
   public init(from decoder: Decoder) throws {
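
The removed `Decodable` extension above is what previously collapsed unrecognized values to `.unknown`. Its replacement, `DecodableProtoEnum`, is internal to the SDK and not shown in this diff; as a rough, self-contained illustration of the behavior the new struct exhibits (all names below are hypothetical stand-ins, not the SDK's), decoding now preserves unrecognized strings in `rawValue`:

```swift
import Foundation

// Hypothetical stand-in mirroring the shape of the new `FinishReason`; the
// real decoding is supplied by the SDK-internal `DecodableProtoEnum`.
struct DemoFinishReason: Decodable, Hashable {
  enum Kind: String {
    case stop = "STOP"
    case maxTokens = "MAX_TOKENS"
  }

  static var stop: DemoFinishReason { DemoFinishReason(kind: .stop) }
  static var maxTokens: DemoFinishReason { DemoFinishReason(kind: .maxTokens) }

  let rawValue: String

  init(kind: Kind) { rawValue = kind.rawValue }
  init(rawValue: String) { self.rawValue = rawValue }

  init(from decoder: Decoder) throws {
    let value = try decoder.singleValueContainer().decode(String.self)
    if Kind(rawValue: value) == nil {
      // The real SDK logs through `VertexLog` using the type's
      // `unrecognizedValueMessageCode`; a print stands in here.
      print("Unrecognized value \"\(value)\".")
    }
    self.init(rawValue: value)
  }
}

do {
  // An unrecognized server string round-trips as a comparable value instead
  // of collapsing to a catch-all `.unknown` case.
  let reasons = try JSONDecoder().decode(
    [DemoFinishReason].self,
    from: Data(#"["STOP", "FAKE_NEW_FINISH_REASON"]"#.utf8)
  )
  print(reasons[0] == .stop)  // true
  print(reasons[1].rawValue)  // "FAKE_NEW_FINISH_REASON"
} catch {
  print("Decoding failed: \(error)")
}
```
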
FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift (6 changes: 4 additions & 2 deletions)

@@ -608,12 +608,13 @@ final class GenerativeModelTests: XCTestCase {
forResource: "unary-failure-unknown-enum-finish-reason",
withExtension: "json"
)
let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GenerateContentError.responseStoppedEarly(reason, response) {
XCTAssertEqual(reason, .unknown)
XCTAssertEqual(reason, unknownFinishReason)
XCTAssertEqual(response.text, "Some text")
} catch {
XCTFail("Should throw a responseStoppedEarly")
Expand Down Expand Up @@ -921,14 +922,15 @@ final class GenerativeModelTests: XCTestCase {
forResource: "streaming-failure-unknown-finish-enum",
withExtension: "txt"
)
let unknownFinishReason = FinishReason(rawValue: "FAKE_ENUM")

let stream = try model.generateContentStream("Hi")
do {
for try await content in stream {
XCTAssertNotNil(content.text)
}
} catch let GenerateContentError.responseStoppedEarly(reason, _) {
XCTAssertEqual(reason, .unknown)
XCTAssertEqual(reason, unknownFinishReason)
return
}

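
The updated assertions above compare against a `FinishReason` built from its raw string. As a brief, hypothetical illustration (not part of this PR) of the value semantics that makes those comparisons work, assuming the initializer is reachable from the test target the way the SDK's own tests reach it (e.g. via `@testable import`):

```swift
import XCTest
@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class FinishReasonSemanticsTests: XCTestCase {
  func testEqualityIsDrivenByRawValue() {
    // A value built from a known raw string equals the named constant.
    XCTAssertEqual(FinishReason(rawValue: "STOP"), .stop)

    // An unrecognized server value can still be constructed and compared,
    // which is what replaces the old `.unknown` catch-all case.
    let fake = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")
    XCTAssertEqual(fake.rawValue, "FAKE_NEW_FINISH_REASON")
    XCTAssertNotEqual(fake, .stop)
  }
}
```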