Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pinecone): Add support for Pinecone /embed endpoint #7203

Merged
merged 17 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
344 changes: 344 additions & 0 deletions docs/core_docs/docs/integrations/text_embedding/pinecone.ipynb

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions examples/src/embeddings/pinecone.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { PineconeEmbeddings } from "@langchain/pinecone";

export const run = async () => {
const model = new PineconeEmbeddings();
console.log({ model }); // Prints out model metadata
const res = await model.embedQuery(
"What would be a good company name a company that makes colorful socks?"
);
console.log({ res });
};

await run();
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/pinecone/delete_docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env vars: PINECONE_API_KEY which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
3 changes: 2 additions & 1 deletion examples/src/indexes/vector_stores/pinecone/index_docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ import { Pinecone } from "@pinecone-database/pinecone";
import { Document } from "@langchain/core/documents";
import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";
// import { Index } from "@upstash/vector";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env vars: PINECONE_API_KEY which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/pinecone/mmr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env vars: PINECONE_API_KEY which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
2 changes: 1 addition & 1 deletion examples/src/indexes/vector_stores/pinecone/query_docs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { OpenAIEmbeddings } from "@langchain/openai";
import { PineconeStore } from "@langchain/pinecone";

// Instantiate a new Pinecone client, which will automatically read the
// env vars: PINECONE_API_KEY and PINECONE_ENVIRONMENT which come from
// env vars: PINECONE_API_KEY which comes from
// the Pinecone dashboard at https://app.pinecone.io

const pinecone = new Pinecone();
Expand Down
10 changes: 2 additions & 8 deletions examples/src/retrievers/pinecone_self_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,8 @@ const attributeInfo: AttributeInfo[] = [
* Next, we instantiate a vector store. This is where we store the embeddings of the documents.
* We also need to provide an embeddings object. This is used to embed the documents.
*/
if (
!process.env.PINECONE_API_KEY ||
!process.env.PINECONE_ENVIRONMENT ||
!process.env.PINECONE_INDEX
) {
throw new Error(
"PINECONE_ENVIRONMENT and PINECONE_API_KEY and PINECONE_INDEX must be set"
);
if (!process.env.PINECONE_API_KEY || !process.env.PINECONE_INDEX) {
throw new Error("PINECONE_API_KEY and PINECONE_INDEX must be set");
}

const pinecone = new Pinecone();
Expand Down
1 change: 0 additions & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
"@neondatabase/serverless": "^0.9.1",
"@notionhq/client": "^2.2.10",
"@opensearch-project/opensearch": "^2.2.0",
"@pinecone-database/pinecone": "^1.1.0",
"@planetscale/database": "^1.8.0",
"@premai/prem-sdk": "^0.3.25",
"@qdrant/js-client-rest": "^1.8.2",
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain-pinecone/jest.config.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ module.exports = {
setupFiles: ["dotenv/config"],
testTimeout: 20_000,
passWithNoTests: true,
collectCoverageFrom: ["src/**/*.ts"]
};
collectCoverageFrom: ["src/**/*.ts"],
};
2 changes: 1 addition & 1 deletion libs/langchain-pinecone/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"author": "Pinecone, Inc",
"license": "MIT",
"dependencies": {
"@pinecone-database/pinecone": "^3.0.0 || ^4.0.0",
"@pinecone-database/pinecone": "^4.0.0",
"flat": "^5.0.2",
"uuid": "^10.0.0"
},
Expand Down
16 changes: 16 additions & 0 deletions libs/langchain-pinecone/src/client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { Pinecone, PineconeConfiguration } from "@pinecone-database/pinecone";
import { getEnvironmentVariable } from "@langchain/core/utils/env";

export function getPineconeClient(config?: PineconeConfiguration): Pinecone {
if (
getEnvironmentVariable("PINECONE_API_KEY") === undefined ||
getEnvironmentVariable("PINECONE_API_KEY") === ""
) {
throw new Error("PINECONE_API_KEY must be set in environment");
}
if (!config) {
return new Pinecone();
} else {
return new Pinecone(config);
}
}
139 changes: 139 additions & 0 deletions libs/langchain-pinecone/src/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* eslint-disable arrow-body-style */

import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import {
EmbeddingsList,
Pinecone,
PineconeConfiguration,
} from "@pinecone-database/pinecone";
import { getPineconeClient } from "./client.js";

/* PineconeEmbeddingsParams holds the optional fields a user can pass to a Pinecone embedding model.
* @param model - Model to use to generate embeddings. Default is "multilingual-e5-large".
* @param params - Additional parameters to pass to the embedding model. Note: parameters are model-specific. Read
* more about model-specific parameters in the [Pinecone
* documentation](https://docs.pinecone.io/guides/inference/understanding-inference#model-specific-parameters).
* */
export interface PineconeEmbeddingsParams extends EmbeddingsParams {
model?: string; // Model to use to generate embeddings
params?: Record<string, string>; // Additional parameters to pass to the embedding model
}

/* PineconeEmbeddings generates embeddings using the Pinecone Inference API. */
export class PineconeEmbeddings
extends Embeddings
implements PineconeEmbeddingsParams
{
client: Pinecone;

model: string;

params: Record<string, string>;

constructor(
fields?: Partial<PineconeEmbeddingsParams> & Partial<PineconeConfiguration>
) {
const defaultFields = { maxRetries: 3, ...fields };
super(defaultFields);

if (defaultFields.apiKey) {
const config = {
apiKey: defaultFields.apiKey,
controllerHostUrl: defaultFields.controllerHostUrl,
fetchApi: defaultFields.fetchApi,
additionalHeaders: defaultFields.additionalHeaders,
sourceTag: defaultFields.sourceTag,
} as PineconeConfiguration;
this.client = getPineconeClient(config);
} else {
this.client = getPineconeClient();
}

if (!defaultFields.model) {
this.model = "multilingual-e5-large";
} else {
this.model = defaultFields.model;
}

const defaultParams = { inputType: "passage" };

if (defaultFields.params) {
this.params = { ...defaultFields.params, ...defaultParams };
} else {
this.params = defaultParams;
}
}

/* Generate embeddings for a list of input strings using a specified embedding model.
*
* @param texts - List of input strings for which to generate embeddings.
* */
async embedDocuments(texts: string[]): Promise<number[][]> {
if (texts.length === 0) {
throw new Error(
"At least one document is required to generate embeddings"
);
}

let embeddings;
if (this.params) {
embeddings = await this.caller.call(async () => {
const result: EmbeddingsList = await this.client.inference.embed(
this.model,
texts,
this.params
);
return result;
});
} else {
embeddings = await this.caller.call(async () => {
const result: EmbeddingsList = await this.client.inference.embed(
this.model,
texts,
{}
);
return result;
});
}

const embeddingsList: number[][] = [];

for (let i = 0; i < embeddings.length; i += 1) {
if (embeddings[i].values) {
embeddingsList.push(embeddings[i].values as number[]);
}
}
return embeddingsList;
}

/* Generate embeddings for a given query string using a specified embedding model.
* @param text - Query string for which to generate embeddings.
* */
async embedQuery(text: string): Promise<number[]> {
// Change inputType to query-specific param for multilingual-e5-large embedding model
this.params.inputType = "query";

if (!text) {
throw new Error("No query passed for which to generate embeddings");
}
let embeddings: EmbeddingsList;
if (this.params) {
embeddings = await this.caller.call(async () => {
return await this.client.inference.embed(
this.model,
[text],
this.params
);
});
} else {
embeddings = await this.caller.call(async () => {
return await this.client.inference.embed(this.model, [text], {});
});
}
if (embeddings[0].values) {
return embeddings[0].values as number[];
} else {
return [];
}
}
}
1 change: 1 addition & 0 deletions libs/langchain-pinecone/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from "./vectorstores.js";
export * from "./translator.js";
export * from "./embeddings.js";
39 changes: 39 additions & 0 deletions libs/langchain-pinecone/src/tests/client.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { Pinecone } from "@pinecone-database/pinecone";
import { getPineconeClient } from "../client.js";

describe("Tests for getPineconeClient", () => {
test("Happy path for getPineconeClient with and without `config` obj passed", async () => {
const client = getPineconeClient();
expect(client).toBeInstanceOf(Pinecone);
expect(client).toHaveProperty("config"); // Config is always set to *at least* the user's api key

const clientWithConfig = getPineconeClient({
// eslint-disable-next-line no-process-env
apiKey: process.env.PINECONE_API_KEY!,
additionalHeaders: { header: "value" },
});
expect(clientWithConfig).toBeInstanceOf(Pinecone);
expect(client).toHaveProperty("config"); // Unfortunately cannot assert on contents of config b/c it's a private
// attribute of the Pinecone class
});

test("Unhappy path: expect getPineconeClient to throw error if reset PINECONE_API_KEY to empty string", async () => {
// eslint-disable-next-line no-process-env
const originalApiKey = process.env.PINECONE_API_KEY;
try {
// eslint-disable-next-line no-process-env
process.env.PINECONE_API_KEY = "";
const errorThrown = async () => {
getPineconeClient();
};
await expect(errorThrown).rejects.toThrow(Error);
await expect(errorThrown).rejects.toThrow(
"PINECONE_API_KEY must be set in environment"
);
} finally {
// Restore the original value of PINECONE_API_KEY
// eslint-disable-next-line no-process-env
process.env.PINECONE_API_KEY = originalApiKey;
}
});
});
15 changes: 15 additions & 0 deletions libs/langchain-pinecone/src/tests/client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { getPineconeClient } from "../client.js";

describe("Tests for getPineconeClient", () => {
test("Confirm getPineconeClient throws error when PINECONE_API_KEY is not set", async () => {
/* eslint-disable-next-line no-process-env */
process.env.PINECONE_API_KEY = "";
const errorThrown = async () => {
getPineconeClient();
};
await expect(errorThrown).rejects.toThrow(Error);
await expect(errorThrown).rejects.toThrow(
"PINECONE_API_KEY must be set in environment"
);
});
});
59 changes: 59 additions & 0 deletions libs/langchain-pinecone/src/tests/embeddings.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { PineconeEmbeddings } from "../embeddings.js";

describe("Integration tests for Pinecone embeddings", () => {
test("Happy path: defaults for both embedDocuments and embedQuery", async () => {
const model = new PineconeEmbeddings();
expect(model.model).toBe("multilingual-e5-large");
expect(model.params).toEqual({ inputType: "passage" });

const docs = ["hello", "world"];
const embeddings = await model.embedDocuments(docs);
expect(embeddings.length).toBe(docs.length);

const query = "hello";
const queryEmbedding = await model.embedQuery(query);
expect(queryEmbedding.length).toBeGreaterThan(0);
});

test("Happy path: custom `params` obj passed to embedDocuments and embedQuery", async () => {
const model = new PineconeEmbeddings({
params: { customParam: "value" },
});
expect(model.model).toBe("multilingual-e5-large");
expect(model.params).toEqual({
inputType: "passage",
customParam: "value",
});

const docs = ["hello", "world"];
const embeddings = await model.embedDocuments(docs);
expect(embeddings.length).toBe(docs.length);
expect(embeddings[0].length).toBe(1024); // Assert correct dims on random doc
expect(model.model).toBe("multilingual-e5-large");
expect(model.params).toEqual({
inputType: "passage", // Maintain default inputType for docs
customParam: "value",
});

const query = "hello";
const queryEmbedding = await model.embedQuery(query);
expect(model.model).toBe("multilingual-e5-large");
expect(queryEmbedding.length).toBe(1024);
expect(model.params).toEqual({
inputType: "query", // Change inputType for query
customParam: "value",
});
});

test("Unhappy path: embedDocuments and embedQuery throw when empty objs are passed", async () => {
const model = new PineconeEmbeddings();
await expect(model.embedDocuments([])).rejects.toThrow();
await expect(model.embedQuery("")).rejects.toThrow();
});

test("Unhappy path: PineconeEmbeddings throws when invalid model is passed", async () => {
const model = new PineconeEmbeddings({ model: "invalid-model" });
await expect(model.embedDocuments([])).rejects.toThrow();
await expect(model.embedQuery("")).rejects.toThrow();
});
});
Loading
Loading