diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py
index 35aa66c3204b56..b5e73ae4c69a07 100644
--- a/airflow/providers/pinecone/hooks/pinecone.py
+++ b/airflow/providers/pinecone/hooks/pinecone.py
@@ -29,6 +29,7 @@ from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
+    from pinecone import Vector
     from pinecone.core.client.model.sparse_values import SparseValues
     from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse
 
 
@@ -137,7 +138,7 @@ def list_indexes(self) -> Any:
     def upsert(
         self,
         index_name: str,
-        vectors: list[Any],
+        vectors: list[Vector] | list[tuple] | list[dict],
         namespace: str = "",
         batch_size: int | None = None,
         show_progress: bool = True,
diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py
index bb3d44214d42b7..70711e062308d9 100644
--- a/airflow/providers/pinecone/operators/pinecone.py
+++ b/airflow/providers/pinecone/operators/pinecone.py
@@ -25,6 +25,8 @@ from airflow.utils.context import Context
 
 if TYPE_CHECKING:
+    from pinecone import Vector
+
     from airflow.utils.context import Context
 
 
@@ -38,8 +40,8 @@ class PineconeIngestOperator(BaseOperator):
 
     :param conn_id: The connection id to use when connecting to Pinecone.
     :param index_name: Name of the Pinecone index.
-    :param input_vectors: Data to be ingested, in the form of a list of tuples where each tuple
-        contains (id, vector_embedding, metadata).
+    :param input_vectors: Data to be ingested, in the form of a list of vectors, list of tuples,
+        or list of dictionaries.
     :param namespace: The namespace to write to. If not specified, the default namespace is used.
     :param batch_size: The number of vectors to upsert in each batch.
     :param upsert_kwargs: .. seealso:: https://docs.pinecone.io/reference/upsert
@@ -52,7 +54,7 @@ def __init__(
         *,
         conn_id: str = PineconeHook.default_conn_name,
         index_name: str,
-        input_vectors: list[tuple],
+        input_vectors: list[Vector] | list[tuple] | list[dict],
         namespace: str = "",
         batch_size: int | None = None,
         upsert_kwargs: dict | None = None,
diff --git a/tests/system/providers/pinecone/example_pinecone_cohere.py b/tests/system/providers/pinecone/example_pinecone_cohere.py
index c74a376f61406e..80e6766484d6bd 100644
--- a/tests/system/providers/pinecone/example_pinecone_cohere.py
+++ b/tests/system/providers/pinecone/example_pinecone_cohere.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 import os
-import time
 from datetime import datetime
 
 from airflow import DAG
@@ -46,19 +45,23 @@ def create_index():
         hook = PineconeHook()
         pod_spec = hook.get_pod_spec_obj()
         hook.create_index(index_name=index_name, dimension=768, spec=pod_spec)
-        time.sleep(60)
 
     embed_task = CohereEmbeddingOperator(
         task_id="embed_task",
        input_text=data,
     )
 
+    @task
+    def transform_output(embedding_output) -> list[dict]:
+        # Convert each embedding to a map with an ID and the embedding vector
+        return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output)]
+
+    transformed_output = transform_output(embed_task.output)
+
     perform_ingestion = PineconeIngestOperator(
         task_id="perform_ingestion",
         index_name=index_name,
-        input_vectors=[
-            ("id1", embed_task.output),
-        ],
+        input_vectors=transformed_output,
         namespace=namespace,
         batch_size=1,
     )
@@ -71,7 +74,7 @@ def delete_index():
         hook = PineconeHook()
         hook.delete_index(index_name=index_name)
 
-    create_index() >> embed_task >> perform_ingestion >> delete_index()
+    create_index() >> embed_task >> transformed_output >> perform_ingestion >> delete_index()
 
 
 from tests.system.utils import get_test_run  # noqa: E402
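
For reference, a minimal sketch of the three `input_vectors` formats the updated `PineconeIngestOperator` accepts: the tuple shape comes from the previous docstring, the dict shape mirrors `transform_output` in the example DAG above, and the `Vector(id=..., values=..., metadata=...)` constructor fields are assumed from the Pinecone client. Index and namespace names below are placeholders.

```python
from pinecone import Vector

from airflow.providers.pinecone.operators.pinecone import PineconeIngestOperator

# 1. Vector objects -- constructor fields (id, values, metadata) assumed from the Pinecone client.
as_vectors = [Vector(id="id1", values=[0.1, 0.2, 0.3], metadata={"source": "cohere"})]

# 2. (id, vector_embedding, metadata) tuples, as described in the previous docstring.
as_tuples = [("id1", [0.1, 0.2, 0.3], {"source": "cohere"})]

# 3. Dicts with "id" and "values", as produced by transform_output() in the example DAG.
as_dicts = [{"id": "id1", "values": [0.1, 0.2, 0.3]}]

# Any of the three lists above can be passed as input_vectors.
perform_ingestion = PineconeIngestOperator(
    task_id="perform_ingestion",
    index_name="example-pinecone-index",  # placeholder index name
    input_vectors=as_dicts,
    namespace="example-namespace",  # placeholder namespace
    batch_size=1,
)
```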