Skip to content

Commit

Permalink
[Vertex AI] Refactor FinishReason as a struct and add new values (f…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored and MojtabaHs committed Oct 17, 2024
1 parent 83dc2cc commit 10ea66e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 33 deletions.
8 changes: 4 additions & 4 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
as input. (#13767)
- [changed] **Breaking Change**: All initializers for `ModelContent` now require
the label `parts: `. (#13832)
- [changed] **Breaking Change**: `HarmCategory` and `HarmProbability` are now
structs instead of enums types and the `unknown` cases have been removed; in a
`switch` statement, use the `default:` case to cover unknown or unhandled
categories or probabilities. (#13728, #13854)
- [changed] **Breaking Change**: `HarmCategory`, `HarmProbability`, and
`FinishReason` are now structs instead of enums types and the `unknown` cases
have been removed; in a `switch` statement, use the `default:` case to cover
unknown or unhandled values. (#13728, #13854, #13860)
- [changed] The default request timeout is now 180 seconds instead of the
platform-default value of 60 seconds for a `URLRequest`; this timeout may
still be customized in `RequestOptions`. (#13722)
Expand Down
87 changes: 60 additions & 27 deletions FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,26 +145,76 @@ public struct Citation: Sendable {

/// A value enumerating possible reasons for a model to terminate a content generation request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public enum FinishReason: String, Sendable {
/// The finish reason is unknown.
case unknown = "FINISH_REASON_UNKNOWN"
public struct FinishReason: DecodableProtoEnum, Hashable, Sendable {
enum Kind: String {
case stop = "STOP"
case maxTokens = "MAX_TOKENS"
case safety = "SAFETY"
case recitation = "RECITATION"
case other = "OTHER"
case blocklist = "BLOCKLIST"
case prohibitedContent = "PROHIBITED_CONTENT"
case spii = "SPII"
case malformedFunctionCall = "MALFORMED_FUNCTION_CALL"
}

/// Natural stop point of the model or provided stop sequence.
case stop = "STOP"
public static var stop: FinishReason {
return self.init(kind: .stop)
}

/// The maximum number of tokens as specified in the request was reached.
case maxTokens = "MAX_TOKENS"
public static var maxTokens: FinishReason {
return self.init(kind: .maxTokens)
}

/// The token generation was stopped because the response was flagged for safety reasons.
/// NOTE: When streaming, the Candidate.content will be empty if content filters blocked the
/// output.
case safety = "SAFETY"
///
/// > NOTE: When streaming, the ``CandidateResponse/content`` will be empty if content filters
/// > blocked the output.
public static var safety: FinishReason {
return self.init(kind: .safety)
}

/// The token generation was stopped because the response was flagged for unauthorized citations.
case recitation = "RECITATION"
public static var recitation: FinishReason {
return self.init(kind: .recitation)
}

/// All other reasons that stopped token generation.
case other = "OTHER"
public static var other: FinishReason {
return self.init(kind: .other)
}

/// Token generation was stopped because the response contained forbidden terms.
public static var blocklist: FinishReason {
return self.init(kind: .blocklist)
}

/// Token generation was stopped because the response contained potentially prohibited content.
public static var prohibitedContent: FinishReason {
return self.init(kind: .prohibitedContent)
}

/// Token generation was stopped because of Sensitive Personally Identifiable Information (SPII).
public static var spii: FinishReason {
return self.init(kind: .spii)
}

/// Token generation was stopped because the function call generated by the model was invalid.
public static var malformedFunctionCall: FinishReason {
return self.init(kind: .malformedFunctionCall)
}

/// Returns the raw string representation of the `FinishReason` value.
///
/// > Note: This value directly corresponds to the values in the [REST
/// > API](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/GenerateContentResponse#FinishReason).
public let rawValue: String

var unrecognizedValueMessageCode: VertexLog.MessageCode {
.generateContentResponseUnrecognizedFinishReason
}
}

/// A metadata struct containing any feedback the model had on the prompt it was provided.
Expand Down Expand Up @@ -333,23 +383,6 @@ extension Citation: Decodable {
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension FinishReason: Decodable {
public init(from decoder: Decoder) throws {
let value = try decoder.singleValueContainer().decode(String.self)
guard let decodedFinishReason = FinishReason(rawValue: value) else {
VertexLog.error(
code: .generateContentResponseUnrecognizedFinishReason,
"Unrecognized FinishReason with value \"\(value)\"."
)
self = .unknown
return
}

self = decodedFinishReason
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension PromptFeedback.BlockReason: Decodable {
public init(from decoder: Decoder) throws {
Expand Down
6 changes: 4 additions & 2 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -608,12 +608,13 @@ final class GenerativeModelTests: XCTestCase {
forResource: "unary-failure-unknown-enum-finish-reason",
withExtension: "json"
)
let unknownFinishReason = FinishReason(rawValue: "FAKE_NEW_FINISH_REASON")

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw")
} catch let GenerateContentError.responseStoppedEarly(reason, response) {
XCTAssertEqual(reason, .unknown)
XCTAssertEqual(reason, unknownFinishReason)
XCTAssertEqual(response.text, "Some text")
} catch {
XCTFail("Should throw a responseStoppedEarly")
Expand Down Expand Up @@ -921,14 +922,15 @@ final class GenerativeModelTests: XCTestCase {
forResource: "streaming-failure-unknown-finish-enum",
withExtension: "txt"
)
let unknownFinishReason = FinishReason(rawValue: "FAKE_ENUM")

let stream = try model.generateContentStream("Hi")
do {
for try await content in stream {
XCTAssertNotNil(content.text)
}
} catch let GenerateContentError.responseStoppedEarly(reason, _) {
XCTAssertEqual(reason, .unknown)
XCTAssertEqual(reason, unknownFinishReason)
return
}

Expand Down

0 comments on commit 10ea66e

Please sign in to comment.