Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(EAI-539): Run text-to-Node.js driver benchmark #537

Open
wants to merge 21 commits into
base: text_to_node_js_driver_benchmark
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions packages/benchmarks/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ MONGODB_QUIZ_QUESTIONS_CONNECTION_URI="mongodb+srv://<user>:<pass>@docs-chatbot-
MONGODB_QUIZ_QUESTIONS_DATABASE_NAME="datasets"
MONGODB_QUIZ_QUESTIONS_COLLECTION_NAME="quiz_questions"
MONGODB_TEXT_TO_DRIVER_CONNECTION_URI="<CONNECTION URI>"
BRAINTRUST_ENDPOINT="https://api.braintrust.dev/v1/proxy"
BRAINTRUST_API_KEY="<some api key>"
BRAINTRUST_TEXT_TO_DRIVER_PROJECT_NAME="<some project>"
# Note: We use the Azure-OpenAI style passthrough for the Radiant endpoint
RADIANT_ENDPOINT="<base_url>/azure"
RADIANT_API_KEY="rad-API_KEY-iant"
OPENAI_API_VERSION="2024-06-01"
OPENAI_ENDPOINT="https://<resource_name>.openai.azure.com/"
OPENAI_API_KEY="<api_key>"
# Cookie must be updated daily
MONGODB_AUTH_COOKIE="<Auth cookie from CorpSecure-protected site>"
2 changes: 1 addition & 1 deletion packages/benchmarks/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"@azure/openai": "^1.0.0-beta.11",
"@langchain/openai": "^0.2.6",
"@supercharge/promise-pool": "^3.2.0",
"braintrust": "^0.0.159",
"braintrust": "^0.0.164",
"dotenv": "^16",
"mongodb-chatbot-evaluation": "*",
"mongodb-chatbot-server": "*",
Expand Down
6 changes: 3 additions & 3 deletions packages/benchmarks/src/discovery.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { MongoClient, assertEnvVars } from "mongodb-rag-core";
import { envVars } from "./envVars";
import { makeChatLlmConversationEvalCommands } from "./makeChatLlmConversationEvalCommands";
import { makeRadiantChatLlm } from "./makeRadiantChatLlm";
import { radiantModels } from "./radiantModels";
import { models } from "./models";
import { makeBaseConfig } from "./baseConfig";

export default async () => {
Expand All @@ -31,13 +31,13 @@ export default async () => {
)
);
const chatLlmConfigs = await Promise.all(
radiantModels.map(async (model) => {
models.map(async (model) => {
return {
name: model.label,
chatLlm: await makeRadiantChatLlm({
apiKey: RADIANT_API_KEY,
endpoint: RADIANT_ENDPOINT,
deployment: model.radiantModelDeployment,
deployment: model.deployment,
mongoDbAuthCookie: MONGODB_AUTH_COOKIE,
lmmConfigOptions: {
temperature: 0,
Expand Down
5 changes: 5 additions & 0 deletions packages/benchmarks/src/envVars.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ export const RADIANT_ENV_VARS = {
OPENAI_API_VERSION: "",
};

// Env vars required to reach models through the Braintrust proxy.
// Empty strings are placeholders: `assertEnvVars` (see models.test.ts)
// fills/validates these from the environment at runtime.
export const BRAINTRUST_ENV_VARS = {
  BRAINTRUST_API_KEY: "",
  BRAINTRUST_ENDPOINT: "",
};

export const envVars = {
MONGODB_DATABASE_NAME: "",
MONGODB_CONNECTION_URI: "",
Expand Down
101 changes: 101 additions & 0 deletions packages/benchmarks/src/makeOpenAiClientFactory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import OpenAI, { AzureOpenAI } from "openai";
import { ModelConfig } from "./models";
import { strict as assert } from "assert";
import { APIPromise } from "openai/core.mjs";
import {
ChatCompletion,
ChatCompletionChunk,
} from "openai/resources/index.mjs";
import { Stream } from "stream";
// Connection settings shared by every OpenAI-compatible model provider.
interface BaseModelProviderConfig {
  apiKey: string;
  endpoint: string;
}

// Per-provider configuration for the client factory. Every provider is
// optional here; `makeOpenAiClient` asserts at runtime that the provider
// required by the requested model config was actually supplied.
interface MakeOpenAiClientFactoryParams {
  // Azure OpenAI additionally needs an API version (e.g. "2024-06-01").
  azure?: BaseModelProviderConfig & {
    apiVersion: string;
  };
  braintrust?: BaseModelProviderConfig;
  // Radiant sits behind a CorpSecure cookie-authenticated proxy.
  radiant?: BaseModelProviderConfig & {
    authCookie: string;
  };
}

/**
  Build a factory that turns a {@link ModelConfig} into an OpenAI SDK client
  wired to the matching provider (Azure OpenAI, the Braintrust proxy, or
  Radiant).

  @param azure - Azure OpenAI connection settings (required for "azure_openai" models).
  @param braintrust - Braintrust proxy settings (required for "braintrust" models).
  @param radiant - Radiant settings incl. auth cookie (required for "radiant" models).
  @throws If a model names a provider whose config was not supplied, or an
    unknown provider.
 */
export function makeOpenAiClientFactory({
  azure,
  braintrust,
  radiant,
}: MakeOpenAiClientFactoryParams) {
  return {
    makeOpenAiClient(modelConfig: ModelConfig) {
      let client: OpenAI;
      switch (modelConfig.provider) {
        case "azure_openai": {
          assert(azure, "Azure OpenAI config must be provided");
          client = new AzureOpenAI({
            apiKey: azure.apiKey,
            endpoint: azure.endpoint,
            apiVersion: azure.apiVersion,
          });
          break;
        }
        case "braintrust": {
          assert(braintrust, "Braintrust OpenAI config must be provided");
          client = new OpenAI({
            apiKey: braintrust.apiKey,
            baseURL: braintrust.endpoint,
          });
          break;
        }
        case "radiant": {
          assert(radiant, "Radiant OpenAI config must be provided");
          client = new OpenAI({
            apiKey: radiant.apiKey,
            baseURL: radiant.endpoint,
            defaultHeaders: {
              // Radiant sits behind cookie-based corp auth.
              Cookie: radiant.authCookie,
            },
          });
          break;
        }
        default:
          throw new Error(`Unsupported provider: ${modelConfig.provider}`);
      }
      // Some deployments reject the "system" role; rewrite system messages
      // as user messages when the model config requests it.
      return modelConfig.systemMessageAsUserMessage
        ? wrapOpenAiClientWithSystemMessage(client)
        : client;
    },
  };
}

function wrapOpenAiClientWithSystemMessage(openAiClient: OpenAI): OpenAI {
// Preserve the original `.create()` method with binding
const originalCreate = openAiClient.chat.completions.create.bind(
openAiClient.chat.completions
);

// Override the `.create()` method with minimal type casting
openAiClient.chat.completions.create = ((args, options) => {
const transformedMessages = transformSystemMessages(args.messages);

// Call the original method with type assertions to match expected types
return originalCreate(
{
...args,
messages: transformedMessages,
} satisfies OpenAI.Chat.Completions.ChatCompletionCreateParams,
options satisfies OpenAI.RequestOptions<any> | undefined
);
}) as typeof openAiClient.chat.completions.create;

return openAiClient;
}

/**
  Rewrite "system" messages as "user" messages whose content is wrapped in
  `<System_Message>` tags, leaving all other messages untouched. Used for
  models that do not accept the system role.

  @param messages - Chat messages from a completion request.
  @returns A new messages array; non-system messages are passed through by
    reference.
 */
function transformSystemMessages(
  messages: OpenAI.Chat.Completions.ChatCompletionCreateParams["messages"]
): OpenAI.Chat.Completions.ChatCompletionCreateParams["messages"] {
  return messages.map((message) => {
    if (message.role !== "system") {
      return message;
    }
    // System message `content` may be a plain string or an array of text
    // content parts. Flatten part arrays to text first — interpolating an
    // array directly would render as "[object Object],...".
    const text = Array.isArray(message.content)
      ? message.content.map((part) => part.text).join("")
      : message.content;
    return {
      role: "user",
      content: `<System_Message>\n${text}\n</System_Message>`,
    };
  });
}
84 changes: 84 additions & 0 deletions packages/benchmarks/src/models.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/**
@description Quick smoke tests to make sure all the benchmark models
(Radiant, Braintrust, and Azure OpenAI) are functional.
Useful to run before executing benchmark runs to ensure all models are working.
*/
import {
assertEnvVars,
CORE_OPENAI_CHAT_COMPLETION_ENV_VARS,
CORE_OPENAI_CONNECTION_ENV_VARS,
} from "mongodb-rag-core";
import { BRAINTRUST_ENV_VARS, envVars } from "./envVars";
import { models } from "./models";
import { makeOpenAiClientFactory } from "./makeOpenAiClientFactory";
import OpenAI from "openai";
import "dotenv/config";

jest.setTimeout(60000);
// NOTE: due to this issue https://github.com/nodejs/node/issues/39964,
// you must run the tests with a Node version >= 20.0.0
describe.skip("Radiant models", () => {
test.each(models.filter((m) => m.provider === "radiant"))(
"'$label' model should generate data",
async (model) => {
// Note: this is inside of the tests so that this doesn't throw with the skipped tests.
// THe assertion inside of the describe block will throw if the env vars are not set,
// even if the block is skipped.
const { RADIANT_API_KEY, RADIANT_ENDPOINT, MONGODB_AUTH_COOKIE } =
assertEnvVars(envVars);
const openAiClientFactory = makeOpenAiClientFactory({
radiant: {
apiKey: RADIANT_API_KEY,
endpoint: RADIANT_ENDPOINT,
authCookie: MONGODB_AUTH_COOKIE,
},
});
const openAiClient = openAiClientFactory.makeOpenAiClient(model);
await expectModelResponse(openAiClient, model.deployment);
}
);
});
describe.skip("Braintrust models", () => {
test.each(models.filter((m) => m.provider === "braintrust"))(
"'$label' model should generate data",
async (model) => {
const { BRAINTRUST_API_KEY, BRAINTRUST_ENDPOINT } =
assertEnvVars(BRAINTRUST_ENV_VARS);
const openAiClientFactory = makeOpenAiClientFactory({
braintrust: {
apiKey: BRAINTRUST_API_KEY,
endpoint: BRAINTRUST_ENDPOINT,
},
});
const openAiClient = openAiClientFactory.makeOpenAiClient(model);
await expectModelResponse(openAiClient, model.deployment);
}
);
});

describe.skip("Azure OpenAI models", () => {
test.each(models.filter((m) => m.provider === "azure_openai"))(
"'$label' model should generate data",
async (model) => {
const { OPENAI_API_KEY, OPENAI_ENDPOINT, OPENAI_API_VERSION } =
assertEnvVars(CORE_OPENAI_CONNECTION_ENV_VARS);
const openAiClientFactory = makeOpenAiClientFactory({
azure: {
apiKey: OPENAI_API_KEY,
endpoint: OPENAI_ENDPOINT,
apiVersion: OPENAI_API_VERSION,
},
});
const openAiClient = openAiClientFactory.makeOpenAiClient(model);
await expectModelResponse(openAiClient, model.deployment);
}
);
});

/**
  Send a trivial "Hello" completion to the given deployment and assert that
  a string response comes back — a minimal liveness check for the model.

  @param client - An OpenAI-compatible client from the factory.
  @param model - The deployment/model name to invoke.
 */
async function expectModelResponse(client: OpenAI, model: string) {
  const completion = await client.chat.completions.create({
    model,
    messages: [{ role: "user", content: "Hello" }],
    temperature: 0,
  });
  const content = completion.choices[0].message.content;
  expect(content).toEqual(expect.any(String));
}
Loading