Skip to content

Commit

Permalink
[Vertex AI] Simplify ModelContent initializers (firebase#13832)
Browse files Browse the repository at this point in the history
Co-authored-by: Ryan Wilson <[email protected]>
  • Loading branch information
2 people authored and MojtabaHs committed Oct 17, 2024
1 parent cd5ec2f commit 94afbfa
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 66 deletions.
2 changes: 2 additions & 0 deletions FirebaseVertexAI/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
generating content the types `TextPart`; additionally the types
`InlineDataPart`, `FileDataPart` and `FunctionResponsePart` may be provided
as input. (#13767)
- [changed] **Breaking Change**: All initializers for `ModelContent` now require
the label `parts: `. (#13832)
- [changed] The default request timeout is now 180 seconds instead of the
platform-default value of 60 seconds for a `URLRequest`; this timeout may
still be customized in `RequestOptions`. (#13722)
Expand Down
10 changes: 4 additions & 6 deletions FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,11 @@ struct ErrorDetailsView: View {
let error = GenerateContentError.responseStoppedEarly(
reason: .maxTokens,
response: GenerateContentResponse(candidates: [
CandidateResponse(content: ModelContent(role: "model", [
CandidateResponse(content: ModelContent(role: "model", parts:
"""
A _hypothetical_ model response.
Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
""",
]),
"""),
safetyRatings: [
SafetyRating(category: .dangerousContent, probability: .high),
SafetyRating(category: .harassment, probability: .low),
Expand All @@ -183,12 +182,11 @@ struct ErrorDetailsView: View {
#Preview("Prompt Blocked") {
let error = GenerateContentError.promptBlocked(
response: GenerateContentResponse(candidates: [
CandidateResponse(content: ModelContent(role: "model", [
CandidateResponse(content: ModelContent(role: "model", parts:
"""
A _hypothetical_ model response.
Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
""",
]),
"""),
safetyRatings: [
SafetyRating(category: .dangerousContent, probability: .high),
SafetyRating(category: .harassment, probability: .low),
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct ErrorView: View {
NavigationView {
let errorPromptBlocked = GenerateContentError.promptBlocked(
response: GenerateContentResponse(candidates: [
CandidateResponse(content: ModelContent(role: "model", [
CandidateResponse(content: ModelContent(role: "model", parts: [
"""
A _hypothetical_ model response.
Cillum ex aliqua amet aliquip labore amet eiusmod consectetur reprehenderit sit commodo.
Expand Down
35 changes: 1 addition & 34 deletions FirebaseVertexAI/Sources/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,32 +71,6 @@ public struct ModelContent: Equatable, Sendable {
// TODO: Refactor this
let internalParts: [InternalPart]

/// Creates a new value from any data or `Array` of data interpretable as a
/// ``Part``. See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
public init(role: String? = "user", parts: some PartsRepresentable) {
self.role = role
var convertedParts = [InternalPart]()
for part in parts.partsValue {
switch part {
case let textPart as TextPart:
convertedParts.append(.text(textPart.text))
case let inlineDataPart as InlineDataPart:
let inlineData = inlineDataPart.inlineData
convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data))
case let fileDataPart as FileDataPart:
let fileData = fileDataPart.fileData
convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI))
case let functionCallPart as FunctionCallPart:
convertedParts.append(.functionCall(functionCallPart.functionCall))
case let functionResponsePart as FunctionResponsePart:
convertedParts.append(.functionResponse(functionResponsePart.functionResponse))
default:
fatalError()
}
}
internalParts = convertedParts
}

/// Creates a new value from a list of ``Part``s.
public init(role: String? = "user", parts: [any Part]) {
self.role = role
Expand Down Expand Up @@ -124,14 +98,7 @@ public struct ModelContent: Equatable, Sendable {

/// Creates a new value from any data interpretable as a ``Part``.
/// See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
public init(role: String? = "user", _ parts: any PartsRepresentable...) {
let content = parts.flatMap { $0.partsValue }
self.init(role: role, parts: content)
}

/// Creates a new value from any data interpretable as a ``Part``.
/// See ``PartsRepresentable``for types that can be interpreted as `Part`s.
public init(role: String? = "user", _ parts: [PartsRepresentable]) {
public init(role: String? = "user", parts: any PartsRepresentable...) {
let content = parts.flatMap { $0.partsValue }
self.init(role: role, parts: content)
}
Expand Down
38 changes: 13 additions & 25 deletions FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ final class VertexAIAPITests: XCTestCase {
_ = try await genAI.generateContent([str, UIImage(), TextPart(str)])
_ = try await genAI.generateContent(str, UIImage(), "def", UIImage())
_ = try await genAI.generateContent([str, UIImage(), "def", UIImage()])
_ = try await genAI.generateContent([ModelContent("def", UIImage()),
ModelContent("def", UIImage())])
_ = try await genAI.generateContent([ModelContent(parts: "def", UIImage()),
ModelContent(parts: "def", UIImage())])
#elseif canImport(AppKit)
_ = try await genAI.generateContent(NSImage())
_ = try await genAI.generateContent([NSImage()])
Expand All @@ -121,37 +121,25 @@ final class VertexAIAPITests: XCTestCase {
let _ = ModelContent(parts: "Constant String")
let _ = ModelContent(parts: str)
let _ = ModelContent(parts: [str])
// Note: without `as [any PartsRepresentable]` this will fail to compile with "Cannot
// convert value of type 'String' to expected element type
// 'Array<Part>.ArrayLiteralElement'. Not sure if there's a way we can get it to
// work.
let _ = ModelContent(
parts: [str, InlineDataPart(data: Data(), mimeType: "foo")] as [any PartsRepresentable]
)
let _ = ModelContent(parts: [str, InlineDataPart(data: Data(), mimeType: "foo")])
#if canImport(UIKit)
_ = ModelContent(role: "user", parts: UIImage())
_ = ModelContent(role: "user", parts: [UIImage()])
// Note: without `as [any PartsRepresentable]` this will fail to compile with "Cannot convert
// value of type `[Any]` to expected type `[any PartsRepresentable]`. Not sure if there's a
// way we can get it to work.
_ = ModelContent(parts: [str, UIImage()] as [any PartsRepresentable])
// Alternatively, you can explicitly declare the type in a variable and pass it in.
_ = ModelContent(parts: [str, UIImage()])
// Note: without explicitly specifying`: [any PartsRepresentable]` this will fail to compile
// below with "Cannot convert value of type `[Any]` to expected type `[any Part]`.
let representable2: [any PartsRepresentable] = [str, UIImage()]
_ = ModelContent(parts: representable2)
_ =
ModelContent(parts: [str, UIImage(), TextPart(str)] as [any PartsRepresentable])
_ = ModelContent(parts: [str, UIImage(), TextPart(str)])
#elseif canImport(AppKit)
_ = ModelContent(role: "user", parts: NSImage())
_ = ModelContent(role: "user", parts: [NSImage()])
// Note: without `as [any PartsRepresentable]` this will fail to compile with "Cannot convert
// value of type `[Any]` to expected type `[any PartsRepresentable]`. Not sure if there's a
// way we can get it to work.
_ = ModelContent(parts: [str, NSImage()] as [any PartsRepresentable])
// Alternatively, you can explicitly declare the type in a variable and pass it in.
_ = ModelContent(parts: [str, NSImage()])
// Note: without explicitly specifying`: [any PartsRepresentable]` this will fail to compile
// below with "Cannot convert value of type `[Any]` to expected type `[any Part]`.
let representable2: [any PartsRepresentable] = [str, NSImage()]
_ = ModelContent(parts: representable2)
_ =
ModelContent(parts: [str, NSImage(), TextPart(str)] as [any PartsRepresentable])
_ = ModelContent(parts: [str, NSImage(), TextPart(str)])
#endif

// countTokens API
Expand All @@ -160,8 +148,8 @@ final class VertexAIAPITests: XCTestCase {
let _: CountTokensResponse = try await genAI.countTokens("What color is the Sky?",
UIImage())
let _: CountTokensResponse = try await genAI.countTokens([
ModelContent("What color is the Sky?", UIImage()),
ModelContent(UIImage(), "What color is the Sky?", UIImage()),
ModelContent(parts: "What color is the Sky?", UIImage()),
ModelContent(parts: UIImage(), "What color is the Sky?", UIImage()),
])
#endif

Expand Down

0 comments on commit 94afbfa

Please sign in to comment.