diff --git a/FirebaseVertexAI/Sources/Types/Public/Schema.swift b/FirebaseVertexAI/Sources/Types/Public/Schema.swift index 0ff5b32ba47..9496113a071 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Schema.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Schema.swift @@ -20,19 +20,41 @@ import Foundation /// [OpenAPI 3.0 schema object](https://spec.openapis.org/oas/v3.0.3#schema). public class Schema { /// Modifiers describing the expected format of a string `Schema`. - public enum StringFormat { + public struct StringFormat: EncodableProtoEnum { + // This enum is currently only used to conform `StringFormat` to `ProtoEnum`, which requires + // `associatedtype Kind: RawRepresentable`. + enum Kind: String { + // Providing a case resolves the error "An enum with no cases cannot declare a raw type". + case unused // TODO: Remove `unused` case when we have at least one specific string format. + } + /// A custom string format. - case custom(String) + public static func custom(_ format: String) -> StringFormat { + return self.init(rawValue: format) + } + + let rawValue: String } /// Modifiers describing the expected format of an integer `Schema`. - public enum IntegerFormat { + public struct IntegerFormat: EncodableProtoEnum { + enum Kind: String { + case int32 + case int64 + } + /// A 32-bit signed integer. - case int32 + public static let int32 = IntegerFormat(kind: .int32) + /// A 64-bit signed integer. - case int64 + public static let int64 = IntegerFormat(kind: .int64) + /// A custom integer format. - case custom(String) + public static func custom(_ format: String) -> IntegerFormat { + return self.init(rawValue: format) + } + + let rawValue: String } let dataType: DataType @@ -311,45 +333,3 @@ extension Schema: Encodable { case requiredProperties = "required" } } - -// MARK: - RawRepresentable Conformance - -extension Schema.IntegerFormat: RawRepresentable { - public init?(rawValue: String) { - switch rawValue { - case "int32": - self = .int32 - case "int64": - self = .int64 - default: - self = .custom(rawValue) - } - } - - public var rawValue: String { - switch self { - case .int32: - return "int32" - case .int64: - return "int64" - case let .custom(format): - return format - } - } -} - -extension Schema.StringFormat: RawRepresentable { - public init?(rawValue: String) { - switch rawValue { - default: - self = .custom(rawValue) - } - } - - public var rawValue: String { - switch self { - case let .custom(format): - return format - } - } -} diff --git a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift index 80b6f1ab528..d8b25cd8ec4 100644 --- a/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/Integration/IntegrationTests.swift @@ -153,4 +153,25 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(response.totalTokens, 24) XCTAssertEqual(response.totalBillableCharacters, 71) } + + func testCountTokens_jsonSchema() async throws { + model = vertex.generativeModel( + modelName: "gemini-1.5-flash", + generationConfig: GenerationConfig( + responseMIMEType: "application/json", + responseSchema: Schema.object(properties: [ + "startDate": .string(format: .custom("date")), + "yearsSince": .integer(format: .custom("int16")), + "hoursSince": .integer(format: .int32), + "minutesSince": .integer(format: .int64), + ]) + ) + ) + let prompt = "It is 2050-01-01, how many years, hours and minutes since 2000-01-01?" + + let response = try await model.countTokens(prompt) + + XCTAssertEqual(response.totalTokens, 34) + XCTAssertEqual(response.totalBillableCharacters, 59) + } }