Skip to content

Commit

Permalink
[Vertex AI] Refactor StringFormat and IntegerFormat as structs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Oct 10, 2024
1 parent ae484d3 commit fde2baf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 48 deletions.
76 changes: 28 additions & 48 deletions FirebaseVertexAI/Sources/Types/Public/Schema.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,41 @@ import Foundation
/// [OpenAPI 3.0 schema object](https://spec.openapis.org/oas/v3.0.3#schema).
public class Schema {
/// Modifiers describing the expected format of a string `Schema`.
public enum StringFormat {
public struct StringFormat: EncodableProtoEnum {
// This enum is currently only used to conform `StringFormat` to `ProtoEnum`, which requires
// `associatedtype Kind: RawRepresentable<String>`.
enum Kind: String {
// Providing a case resolves the error "An enum with no cases cannot declare a raw type".
case unused // TODO: Remove `unused` case when we have at least one specific string format.
}

/// A custom string format.
case custom(String)
public static func custom(_ format: String) -> StringFormat {
return self.init(rawValue: format)
}

let rawValue: String
}

/// Modifiers describing the expected format of an integer `Schema`.
public enum IntegerFormat {
public struct IntegerFormat: EncodableProtoEnum {
enum Kind: String {
case int32
case int64
}

/// A 32-bit signed integer.
case int32
public static let int32 = IntegerFormat(kind: .int32)

/// A 64-bit signed integer.
case int64
public static let int64 = IntegerFormat(kind: .int64)

/// A custom integer format.
case custom(String)
public static func custom(_ format: String) -> IntegerFormat {
return self.init(rawValue: format)
}

let rawValue: String
}

let dataType: DataType
Expand Down Expand Up @@ -311,45 +333,3 @@ extension Schema: Encodable {
case requiredProperties = "required"
}
}

// MARK: - RawRepresentable Conformance

extension Schema.IntegerFormat: RawRepresentable {
public init?(rawValue: String) {
switch rawValue {
case "int32":
self = .int32
case "int64":
self = .int64
default:
self = .custom(rawValue)
}
}

public var rawValue: String {
switch self {
case .int32:
return "int32"
case .int64:
return "int64"
case let .custom(format):
return format
}
}
}

extension Schema.StringFormat: RawRepresentable {
public init?(rawValue: String) {
switch rawValue {
default:
self = .custom(rawValue)
}
}

public var rawValue: String {
switch self {
case let .custom(format):
return format
}
}
}
21 changes: 21 additions & 0 deletions FirebaseVertexAI/Tests/Integration/IntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,25 @@ final class IntegrationTests: XCTestCase {
XCTAssertEqual(response.totalTokens, 24)
XCTAssertEqual(response.totalBillableCharacters, 71)
}

func testCountTokens_jsonSchema() async throws {
model = vertex.generativeModel(
modelName: "gemini-1.5-flash",
generationConfig: GenerationConfig(
responseMIMEType: "application/json",
responseSchema: Schema.object(properties: [
"startDate": .string(format: .custom("date")),
"yearsSince": .integer(format: .custom("int16")),
"hoursSince": .integer(format: .int32),
"minutesSince": .integer(format: .int64),
])
)
)
let prompt = "It is 2050-01-01, how many years, hours and minutes since 2000-01-01?"

let response = try await model.countTokens(prompt)

XCTAssertEqual(response.totalTokens, 34)
XCTAssertEqual(response.totalBillableCharacters, 59)
}
}

0 comments on commit fde2baf

Please sign in to comment.