This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Embeddings #881

Draft · wants to merge 6 commits into base: development

Changes from all commits
38 changes: 38 additions & 0 deletions src/adapters/supabase/helpers/client.ts
@@ -573,3 +573,41 @@ export const _approveLabelChange = async (changeId: number) => {

return;
};

/**
* Upserts an embedding vector into the `embeddings` table
* @param org - The organization name
* @param repo - The repository name
* @param issue - The issue number
* @param embedding - The vector of floating point numbers
*/
export const upsertEmbeddings = async (org: string, repo: string, issue: number, embedding: number[]) => {
const { supabase } = getAdapters();
const logger = getLogger();
const { data } = await supabase.from("embeddings").select("*").eq("org", org).eq("repo", repo).eq("issue", issue).single();
if (data) {
// Update the existing record with the new embedding.
const id = data["id"] as number;
const { error } = await supabase.from("embeddings").upsert({
id: id,
embedding,
updated_at: new Date().toISOString(),
});
if (error) {
logger.error(`Updating the embeddings table failed: ${error.message}. id: ${id}, org: ${org}, repo: ${repo}, issue: ${issue}`);
}
} else {
// Insert a new record to the `embeddings` table.
const { error } = await supabase.from("embeddings").upsert({
org,
repo,
issue,
embedding,
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
});
if (error) {
logger.error(`Inserting into the embeddings table failed: ${error.message}. org: ${org}, repo: ${repo}, issue: ${issue}`);
}
}
};
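
For context, a minimal usage sketch of this helper (not part of the diff). The org, repo, and issue values are placeholders, and the import path mirrors the one used by the new issue handler below:

```ts
import { upsertEmbeddings } from "../../adapters/supabase";

// Hypothetical caller: store a 1536-dimension vector for issue #1 of some-org/some-repo.
async function storeIssueEmbedding(embedding: number[]): Promise<void> {
  await upsertEmbeddings("some-org", "some-repo", 1, embedding);
}
```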
6 changes: 2 additions & 4 deletions src/handlers/comment/action.ts
@@ -51,11 +51,9 @@ export const handleComment = async (): Promise<void> => {
const callbackComment = response ?? successComment ?? "";
if (callbackComment) await callback(issue.number, callbackComment, payload.action, payload.comment);
} catch (err: unknown) {
- // Use failureComment for failed command if it is available
- if (failureComment) {
- await callback(issue.number, failureComment, payload.action, payload.comment);
+ for (const comment of [failureComment, ErrorDiff(err)]) {
+ if (comment) await callback(issue.number, comment, payload.action, payload.comment);
}
- await callback(issue.number, ErrorDiff(err), payload.action, payload.comment);
}
} else {
logger.info(`Skipping for a command: ${command}`);
29 changes: 29 additions & 0 deletions src/handlers/issue/embeddings.ts
@@ -0,0 +1,29 @@
import { getBotContext, getLogger } from "../../bindings";
import { Payload } from "../../types";
import { generateEmbeddings } from "../../helpers";
import { upsertEmbeddings } from "../../adapters/supabase";

/**
* Generates an embedding vector for the current issue
*/
export const embeddings = async () => {
const { payload: _payload } = getBotContext();
const logger = getLogger();
const payload = _payload as Payload;
const issue = payload.issue;

if (!issue) {
logger.info("Skipping embedding generation because the payload contains no issue");
return;
}

if (!issue.body) {
logger.info("Skip to generate embeddings because of empty body");
return;
}

const embeddings = await generateEmbeddings(issue.body);
if (embeddings.length > 0) {
await upsertEmbeddings(payload.repository.owner.login, payload.repository.name, issue.number, embeddings);
}
};
1 change: 1 addition & 0 deletions src/handlers/issue/index.ts
@@ -1 +1,2 @@
export * from "./pre";
export * from "./embeddings";
8 changes: 4 additions & 4 deletions src/handlers/processors.ts
@@ -7,14 +7,14 @@ import { handleComment, issueClosedCallback, issueCreatedCallback, issueReopened
import { checkPullRequests } from "./assign/auto";
import { createDevPoolPR } from "./pull-request";
import { runOnPush, validateConfigChange } from "./push";
- import { findDuplicateOne } from "./issue";
+ import { findDuplicateOne, embeddings } from "./issue";
import { watchLabelChange } from "./label";

export const processors: Record<string, Handler> = {
[GithubEvent.ISSUES_OPENED]: {
pre: [nullHandler],
action: [findDuplicateOne, issueCreatedCallback],
- post: [nullHandler],
+ post: [embeddings],
},
[GithubEvent.ISSUES_REOPENED]: {
pre: [nullHandler],
@@ -44,12 +44,12 @@ export const processors: Record<string, Handler> = {
[GithubEvent.ISSUE_COMMENT_CREATED]: {
pre: [nullHandler],
action: [handleComment],
- post: [nullHandler],
+ post: [embeddings],
},
[GithubEvent.ISSUE_COMMENT_EDITED]: {
pre: [nullHandler],
action: [handleComment],
- post: [nullHandler],
+ post: [embeddings],
},
[GithubEvent.ISSUES_CLOSED]: {
pre: [nullHandler],
66 changes: 66 additions & 0 deletions src/helpers/gpt.ts
@@ -178,3 +178,69 @@ export const askGPT = async (question: string, chatHistory: CreateChatCompletion

return { answer, tokenUsage };
};

/**
* What is an embedding?
* An embedding is a vector of floating point numbers that measures the relatedness of text strings.
* How can I get an embedding from OpenAI?
* Send your text string to the embeddings API endpoint along with an embedding model ID (e.g., text-embedding-ada-002).
* The response will contain an embedding, which you can extract, save, and use.
*
* Example Request:
*
* curl https://api.openai.com/v1/embeddings \
* -H "Content-Type: application/json" \
* -H "Authorization: Bearer $OPENAI_API_KEY" \
* -d '{
* "input": "Your text string goes here",
* "model": "text-embedding-ada-002"
* }'
*
* Example Response:
*
* {
* "data": [
* {
* "embedding": [
* -0.006929283495992422,
* -0.005336422007530928,
* ...
* -4.547132266452536e-05,
* -0.024047505110502243
* ],
* "index": 0,
* "object": "embedding"
* }
* ],
* "model": "text-embedding-ada-002",
* "object": "list",
* "usage": {
* "prompt_tokens": 5,
* "total_tokens": 5
* }
* }
* @param words - The input data to generate the embedding for
*/
export const generateEmbeddings = async (words: string): Promise<number[]> => {
const logger = getLogger();
const config = getBotConfig();

if (!config.ask.apiKey) {
logger.info(`No OpenAI API Key provided`);
throw new Error("You must configure the `openai-api-key` property in the bot configuration in order to use AI powered features.");
}

const openai = new OpenAI({
apiKey: config.ask.apiKey,
});

const embedding = await openai.embeddings.create({
// TODO: Several embedding models exist; `text-embedding-ada-002` is the one recommended by OpenAI
// because it is better, cheaper, and simpler to use.
// The hardcoded model name could be moved to the bot configuration for better extensibility.
model: "text-embedding-ada-002",
input: words,
});

return embedding.data[0].embedding;
};
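
Since the note above describes embeddings as a way to measure how related two text strings are, here is a hedged sketch (not part of this PR) of comparing two generated vectors with cosine similarity; the `./gpt` import path assumes a sibling file under `src/helpers`:

```ts
import { generateEmbeddings } from "./gpt";

// Cosine similarity of two equal-length vectors: close to 1 means very related, close to 0 means unrelated.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Hypothetical comparison of two issue bodies for relatedness.
export async function issueRelatedness(bodyA: string, bodyB: string): Promise<number> {
  const [a, b] = await Promise.all([generateEmbeddings(bodyA), generateEmbeddings(bodyB)]);
  return cosineSimilarity(a, b);
}
```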
1 change: 1 addition & 0 deletions src/helpers/index.ts
@@ -10,3 +10,4 @@ export * from "./payout";
export * from "./file";
export * from "./similarity";
export * from "./commit";
export * from "./gpt";
11 changes: 11 additions & 0 deletions supabase/migrations/20231030085814_create_embedding_table.sql
@@ -0,0 +1,11 @@
-- Creates the `embeddings` table to store an embedding for every issue
-- We start with the `text-embedding-ada-002` model, which has 1536 output dimensions
CREATE TABLE IF NOT EXISTS embeddings (
id serial PRIMARY KEY,
org character varying(255) NOT NULL,
repo character varying(255) NOT NULL,
issue integer NOT NULL,
embedding vector(1536) NOT NULL,
created_at timestamptz NOT NULL,
updated_at timestamptz NOT NULL
);
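
Note that the `vector(1536)` column type comes from the pgvector extension, which must be enabled in the database. The migration only stores vectors; querying them for similar issues would typically go through a SQL function that orders rows by a pgvector distance operator, called from the bot via the Supabase client. A hedged sketch, where the `match_embeddings` function and the environment variables are hypothetical and not defined by this PR:

```ts
import { createClient } from "@supabase/supabase-js";

// Hypothetical client setup; the bot itself obtains its client through getAdapters().
const supabase = createClient(process.env.SUPABASE_URL as string, process.env.SUPABASE_KEY as string);

// match_embeddings would be a Postgres function that orders rows of `embeddings`
// by a pgvector distance operator (e.g. cosine distance) and returns the closest matches.
async function findSimilarIssues(queryEmbedding: number[], matchCount = 5) {
  const { data, error } = await supabase.rpc("match_embeddings", {
    query_embedding: queryEmbedding,
    match_count: matchCount,
  });
  if (error) throw error;
  return data;
}
```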