Skip to content

Commit

Permalink
Fix the argument type of input_vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
sunank200 committed May 17, 2024
1 parent 8b19b78 commit 9ef1b86
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/pinecone/hooks/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any

from pinecone import Pinecone, PodSpec, ServerlessSpec
from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector

from airflow.hooks.base import BaseHook

Expand Down Expand Up @@ -137,7 +137,7 @@ def list_indexes(self) -> Any:
def upsert(
self,
index_name: str,
vectors: list[Any],
vectors: list[Vector] | list[tuple] | list[dict],
namespace: str = "",
batch_size: int | None = None,
show_progress: bool = True,
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/pinecone/operators/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Any, Sequence

from airflow.models import BaseOperator
from airflow.providers.pinecone.hooks.pinecone import PineconeHook
from airflow.providers.pinecone.hooks.pinecone import PineconeHook, Vector
from airflow.utils.context import Context

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
*,
conn_id: str = PineconeHook.default_conn_name,
index_name: str,
input_vectors: list[tuple],
input_vectors: list[Vector] | list[tuple] | list[dict],
namespace: str = "",
batch_size: int | None = None,
upsert_kwargs: dict | None = None,
Expand Down
15 changes: 9 additions & 6 deletions tests/system/providers/pinecone/example_pinecone_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import os
import time
from datetime import datetime

from airflow import DAG
Expand Down Expand Up @@ -46,19 +45,23 @@ def create_index():
hook = PineconeHook()
pod_spec = hook.get_pod_spec_obj()
hook.create_index(index_name=index_name, dimension=768, spec=pod_spec)
time.sleep(60)

embed_task = CohereEmbeddingOperator(
task_id="embed_task",
input_text=data,
)

@task
def transform_output(embedding_output) -> list[dict]:
    # Pair every embedding with a string ID derived from its position in the
    # input, yielding one {"id": ..., "values": ...} record per embedding.
    records = []
    for position, vector in enumerate(embedding_output):
        records.append({"id": str(position), "values": vector})
    return records

transformed_output = transform_output(embed_task.output)

perform_ingestion = PineconeIngestOperator(
task_id="perform_ingestion",
index_name=index_name,
input_vectors=[
("id1", embed_task.output),
],
input_vectors=transformed_output,
namespace=namespace,
batch_size=1,
)
Expand All @@ -71,7 +74,7 @@ def delete_index():
hook = PineconeHook()
hook.delete_index(index_name=index_name)

create_index() >> embed_task >> perform_ingestion >> delete_index()
create_index() >> embed_task >> transformed_output >> perform_ingestion >> delete_index()

from tests.system.utils import get_test_run # noqa: E402

Expand Down

0 comments on commit 9ef1b86

Please sign in to comment.