Skip to content

Commit

Permalink
feat: add extractObject
Browse files Browse the repository at this point in the history
  • Loading branch information
transitive-bullshit committed Jul 27, 2024
1 parent 9a8994b commit 6384239
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
#!/usr/bin/env node
import 'dotenv/config'

import { createAIChain, Msg } from '@agentic/stdlib'
import { extractObject, Msg } from '@agentic/stdlib'
import { ChatModel } from '@dexaai/dexter'
import { z } from 'zod'

async function main() {
const chatModel = new ChatModel({
params: { model: 'gpt-4o', temperature: 0 },
params: { model: 'gpt-4o-mini', temperature: 0 },
debug: true
})

const chain = createAIChain({
const result = await extractObject({
chatFn: chatModel.run.bind(chatModel),
params: {
messages: [Msg.system('Extract a JSON user object from the given text.')]
messages: [
Msg.system('Extract a JSON user object from the given text.'),
Msg.user(
'Bob Vance is 42 years old and lives in Brooklyn, NY. He is a software engineer.'
)
]
},
schema: z.object({
name: z.string(),
Expand All @@ -23,9 +28,6 @@ async function main() {
})
})

const result = await chain(
'Bob Vance is 42 years old and lives in Brooklyn, NY. He is a software engineer.'
)
console.log(result)
}

Expand Down
16 changes: 2 additions & 14 deletions src/create-ai-chain.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import type { SetOptional } from 'type-fest'
import type { ZodType } from 'zod'
import pMap from 'p-map'

import type * as types from './types.js'
Expand Down Expand Up @@ -33,18 +32,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
maxRetries = 2,
toolCallConcurrency = 8,
injectSchemaIntoSystemMessage = true
}: {
chatFn: types.ChatFn
params?: types.Simplify<
Partial<Omit<types.ChatParams, 'tools' | 'functions'>>
>
tools?: types.AIFunctionLike[]
schema?: ZodType<Result> | types.Schema<Result>
maxCalls?: number
maxRetries?: number
toolCallConcurrency?: number
injectSchemaIntoSystemMessage?: boolean
}): types.AIChain<Result> {
}: types.AIChainParams<Result>): types.AIChain<Result> {
const functionSet = new AIFunctionSet(tools)
const defaultParams: Partial<types.ChatParams> | undefined =
rawSchema && !functionSet.size
Expand All @@ -67,7 +55,7 @@ export function createAIChain<Result extends types.AIChainResult = string>({
...chatParams,
messages: [
...(params?.messages ?? []),
...(chatParams.messages ?? [])
...(chatParams?.messages ?? [])
]
}

Expand Down
8 changes: 4 additions & 4 deletions src/create-ai-function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { zodToJsonSchema } from './zod-to-json-schema.js'
* The `spec` property of the returned function is the spec for adding the
* function to the OpenAI API `functions` property.
*/
export function createAIFunction<InputSchema extends z.ZodObject<any>, Return>(
export function createAIFunction<InputSchema extends z.ZodObject<any>, Output>(
spec: {
/** Name of the function. */
name: string
Expand All @@ -24,8 +24,8 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Return>(
inputSchema: InputSchema
},
/** Implementation of the function to call with the parsed arguments. */
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Return>
): types.AIFunction<InputSchema, Return> {
implementation: (params: z.infer<InputSchema>) => types.MaybePromise<Output>
): types.AIFunction<InputSchema, Output> {
assert(spec.name, 'createAIFunction missing required "spec.name"')
assert(
spec.inputSchema,
Expand All @@ -52,7 +52,7 @@ export function createAIFunction<InputSchema extends z.ZodObject<any>, Return>(
}

// Call the implementation function with the parsed arguments.
const aiFunction: types.AIFunction<InputSchema, Return> = (
const aiFunction: types.AIFunction<InputSchema, Output> = (
input: string | types.Msg
) => {
const parsedInput = parseInput(input)
Expand Down
9 changes: 9 additions & 0 deletions src/extract-object.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import type * as types from './types.js'
import { createAIChain } from './create-ai-chain.js'

export function extractObject<Result extends types.AIChainResult = string>(
args: types.ExtractObjectParams<Result>
): Promise<Result> {
const chain = createAIChain(args)
return chain()
}
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export * from './ai-function-set.js'
export * from './create-ai-chain.js'
export * from './create-ai-function.js'
export * from './errors.js'
export * from './extract-object.js'
export * from './fns.js'
export * from './message.js'
export * from './parse-structured-output.js'
Expand Down
32 changes: 27 additions & 5 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import type { Jsonifiable, SetOptional, Simplify } from 'type-fest'
import type { Jsonifiable, SetOptional, SetRequired, Simplify } from 'type-fest'
import type { z } from 'zod'

import type { AIFunctionSet } from './ai-function-set.js'
import type { AIFunctionsProvider } from './fns.js'
import type { Msg } from './message.js'
import type { Schema } from './schema.js'

export type { Msg } from './message.js'
export type { Schema } from './schema.js'
Expand Down Expand Up @@ -65,14 +66,14 @@ export type AIFunctionLike = AIFunctionsProvider | AIFunction | AIFunctionSet
*/
export interface AIFunction<
InputSchema extends z.ZodObject<any> = z.ZodObject<any>,
Return = any
Output = any
> {
/**
* Invokes the underlying AI function `impl` but first validates the input
* against this function's `inputSchema`. This method is callable and is
* meant to be passed the raw LLM JSON string or an OpenAI-compatible Message.
*/
(input: string | Msg): MaybePromise<Return>
(input: string | Msg): MaybePromise<Output>

/** The Zod schema for the input object. */
inputSchema: InputSchema
Expand All @@ -87,7 +88,7 @@ export interface AIFunction<
* The underlying function implementation without any arg parsing or validation.
*/
// TODO: this `any` shouldn't be necessary, but it is for `createAIFunction` results to be assignable to `AIFunctionLike`
impl: (params: z.infer<InputSchema> | any) => MaybePromise<Return>
impl: (params: z.infer<InputSchema> | any) => MaybePromise<Output>
}

export interface ChatParams {
Expand Down Expand Up @@ -124,7 +125,7 @@ export type ChatFn = (
export type AIChainResult = string | Record<string, any>

export type AIChain<Result extends AIChainResult = string> = (
params:
params?:
| string
| Simplify<SetOptional<Omit<ChatParams, 'tools' | 'functions'>, 'model'>>
) => Promise<Result>
Expand All @@ -140,3 +141,24 @@ export type SafeParseResult<TData> =
}

export type ValidatorFn<TData> = (value: unknown) => SafeParseResult<TData>

export type AIChainParams<Result extends AIChainResult = string> = {
chatFn: ChatFn
params?: Simplify<Partial<Omit<ChatParams, 'tools' | 'functions'>>>
tools?: AIFunctionLike[]
schema?: z.ZodType<Result> | Schema<Result>
maxCalls?: number
maxRetries?: number
toolCallConcurrency?: number
injectSchemaIntoSystemMessage?: boolean
}

export type ExtractObjectParams<Result extends AIChainResult = string> =
Simplify<
SetRequired<
Omit<AIChainParams<Result>, 'tools' | 'toolCallConcurrency' | 'params'>,
'schema'
> & {
params: SetRequired<Partial<ChatParams>, 'messages'>
}
>

0 comments on commit 6384239

Please sign in to comment.