Skip to content

Commit

Permalink
Fix the argument type of input_vectors in pinecone upsert (apache#39688)
Browse files Browse the repository at this point in the history
* Fix the argument type of input_vectors

* Fix typing and docstring
  • Loading branch information
sunank200 authored and romsharon98 committed Jul 26, 2024
1 parent 45df809 commit eb53295
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
3 changes: 2 additions & 1 deletion airflow/providers/pinecone/hooks/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from pinecone import Vector
from pinecone.core.client.model.sparse_values import SparseValues
from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse

Expand Down Expand Up @@ -137,7 +138,7 @@ def list_indexes(self) -> Any:
def upsert(
self,
index_name: str,
vectors: list[Any],
vectors: list[Vector] | list[tuple] | list[dict],
namespace: str = "",
batch_size: int | None = None,
show_progress: bool = True,
Expand Down
8 changes: 5 additions & 3 deletions airflow/providers/pinecone/operators/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from airflow.utils.context import Context

if TYPE_CHECKING:
from pinecone import Vector

from airflow.utils.context import Context


Expand All @@ -38,8 +40,8 @@ class PineconeIngestOperator(BaseOperator):
:param conn_id: The connection id to use when connecting to Pinecone.
:param index_name: Name of the Pinecone index.
:param input_vectors: Data to be ingested, in the form of a list of tuples where each tuple
contains (id, vector_embedding, metadata).
:param input_vectors: Data to be ingested, in the form of a list of vectors, list of tuples,
or list of dictionaries.
:param namespace: The namespace to write to. If not specified, the default namespace is used.
:param batch_size: The number of vectors to upsert in each batch.
:param upsert_kwargs: .. seealso:: https://docs.pinecone.io/reference/upsert
Expand All @@ -52,7 +54,7 @@ def __init__(
*,
conn_id: str = PineconeHook.default_conn_name,
index_name: str,
input_vectors: list[tuple],
input_vectors: list[Vector] | list[tuple] | list[dict],
namespace: str = "",
batch_size: int | None = None,
upsert_kwargs: dict | None = None,
Expand Down
15 changes: 9 additions & 6 deletions tests/system/providers/pinecone/example_pinecone_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import os
import time
from datetime import datetime

from airflow import DAG
Expand Down Expand Up @@ -46,19 +45,23 @@ def create_index():
hook = PineconeHook()
pod_spec = hook.get_pod_spec_obj()
hook.create_index(index_name=index_name, dimension=768, spec=pod_spec)
time.sleep(60)

embed_task = CohereEmbeddingOperator(
task_id="embed_task",
input_text=data,
)

@task
def transform_output(embedding_output) -> list[dict]:
# Convert each embedding to a map with an ID and the embedding vector
return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output)]

transformed_output = transform_output(embed_task.output)

perform_ingestion = PineconeIngestOperator(
task_id="perform_ingestion",
index_name=index_name,
input_vectors=[
("id1", embed_task.output),
],
input_vectors=transformed_output,
namespace=namespace,
batch_size=1,
)
Expand All @@ -71,7 +74,7 @@ def delete_index():
hook = PineconeHook()
hook.delete_index(index_name=index_name)

create_index() >> embed_task >> perform_ingestion >> delete_index()
create_index() >> embed_task >> transformed_output >> perform_ingestion >> delete_index()

from tests.system.utils import get_test_run # noqa: E402

Expand Down

0 comments on commit eb53295

Please sign in to comment.