Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] asyncio for grpc data interactions #398

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions .github/workflows/alpha-release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@ on:
default: 'rc1'

jobs:
unit-tests:
uses: './.github/workflows/testing-unit.yaml'
secrets: inherit
integration-tests:
uses: './.github/workflows/testing-integration.yaml'
secrets: inherit
dependency-tests:
uses: './.github/workflows/testing-dependency.yaml'
secrets: inherit
# unit-tests:
# uses: './.github/workflows/testing-unit.yaml'
# secrets: inherit
# integration-tests:
# uses: './.github/workflows/testing-integration.yaml'
# secrets: inherit
# dependency-tests:
# uses: './.github/workflows/testing-dependency.yaml'
# secrets: inherit

pypi:
uses: './.github/workflows/publish-to-pypi.yaml'
needs:
- unit-tests
- integration-tests
- dependency-tests
# needs:
# - unit-tests
# - integration-tests
# - dependency-tests
with:
isPrerelease: true
ref: ${{ inputs.ref }}
Expand All @@ -49,4 +49,3 @@ jobs:
secrets:
PYPI_USERNAME: __token__
PYPI_PASSWORD: ${{ secrets.PROD_PYPI_PUBLISH_TOKEN }}

6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ venv.bak/
.ropeproject

# pdocs documentation
# We want to exclude any locally generated artifacts, but we rely on
# We want to exclude any locally generated artifacts, but we rely on
# keeping documentation assets in the docs/ folder.
docs/*
!docs/pinecone-python-client-fork.png
Expand All @@ -155,4 +155,6 @@ dmypy.json
*.hdf5
*~

tests/integration/proxy_config/logs
tests/integration/proxy_config/logs
*.parquet
app*.py
16 changes: 9 additions & 7 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
import logging
from typing import Optional, Dict, Any, Union, List, Tuple, Literal
from typing import Optional, Dict, Any, Union, Literal

from .index_host_store import IndexHostStore

Expand All @@ -10,7 +10,12 @@
from pinecone.core.openapi.shared.api_client import ApiClient


from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.utils import (
normalize_host,
setup_openapi_client,
build_plugin_setup_client,
parse_non_empty_args,
)
from pinecone.core.openapi.control.models import (
CreateCollectionRequest,
CreateIndexRequest,
Expand Down Expand Up @@ -317,9 +322,6 @@ def create_index(

api_instance = self.index_api

def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
return {arg_name: val for arg_name, val in args if val is not None}

if deletion_protection in ["enabled", "disabled"]:
dp = DeletionProtection(deletion_protection)
else:
Expand All @@ -329,7 +331,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
if "serverless" in spec:
index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"]))
elif "pod" in spec:
args_dict = _parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("environment", spec["pod"].get("environment")),
("metadata_config", spec["pod"].get("metadata_config")),
Expand All @@ -351,7 +353,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
serverless=ServerlessSpecModel(cloud=spec.cloud, region=spec.region)
)
elif isinstance(spec, PodSpec):
args_dict = _parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("replicas", spec.replicas),
("shards", spec.shards),
Expand Down
1 change: 1 addition & 0 deletions pinecone/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"""

from .index_grpc import GRPCIndex
from .index_grpc_asyncio import GRPCIndexAsyncio
from .pinecone import PineconeGRPC
from .config import GRPCClientConfig

Expand Down
18 changes: 11 additions & 7 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,32 @@
from pinecone import Config
from .config import GRPCClientConfig
from .grpc_runner import GrpcRunner
from .utils import normalize_endpoint


class GRPCIndexBase(ABC):
"""
Base class for grpc-based interaction with Pinecone indexes
"""

_pool = None

def __init__(
self,
index_name: str,
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
_endpoint_override: Optional[str] = None,
use_asyncio: Optional[bool] = False,
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()

self._endpoint_override = _endpoint_override

self.runner = GrpcRunner(
index_name=index_name, config=config, grpc_config=self.grpc_client_config
)
self.channel_factory = GrpcChannelFactory(
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=use_asyncio
)
self._channel = channel or self._gen_channel()
self.stub = self.stub_class(self._channel)
Expand All @@ -46,9 +45,7 @@ def stub_class(self):
pass

def _endpoint(self):
grpc_host = self.config.host.replace("https://", "")
if ":" not in grpc_host:
grpc_host = f"{grpc_host}:443"
grpc_host = normalize_endpoint(self.config.host)
return self._endpoint_override if self._endpoint_override else grpc_host

def _gen_channel(self):
Expand Down Expand Up @@ -83,3 +80,10 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
self.close()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, traceback):
self.close()
return True
29 changes: 21 additions & 8 deletions pinecone/grpc/grpc_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from functools import wraps
from typing import Dict, Tuple, Optional

Expand Down Expand Up @@ -62,20 +63,32 @@ async def run_asyncio(
credentials: Optional[CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[Compression] = None,
semaphore: Optional[asyncio.Semaphore] = None,
):
@wraps(func)
async def wrapped():
user_provided_metadata = metadata or {}
_metadata = self._prepare_metadata(user_provided_metadata)
try:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
if semaphore is not None:
async with semaphore:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
else:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
except _InactiveRpcError as e:
raise PineconeException(e._state.debug_error_string) from e

Expand Down
81 changes: 37 additions & 44 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast
from typing import Optional, Dict, Union, List, cast

from google.protobuf import json_format

from tqdm.autonotebook import tqdm

from pinecone.utils import parse_non_empty_args
from .utils import (
dict_to_proto_struct,
parse_fetch_response,
parse_query_response,
parse_stats_response,
parse_sparse_values_arg,
)
from .vector_factory_grpc import VectorFactoryGRPC
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture
from .sparse_vector import SparseVectorTypedDict
from .config import GRPCClientConfig

from pinecone.core.openapi.data.models import (
FetchResponse,
Expand All @@ -36,23 +42,36 @@
)
from pinecone import Vector as NonGRPCVector
from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture

from pinecone.config import Config
from grpc._channel import Channel


__all__ = ["GRPCIndex", "GRPCVector", "GRPCQueryVector", "GRPCSparseValues"]

_logger = logging.getLogger(__name__)


class SparseVectorTypedDict(TypedDict):
indices: List[int]
values: List[float]


class GRPCIndex(GRPCIndexBase):
"""A client for interacting with a Pinecone index via GRPC API."""

def __init__(
self,
index_name: str,
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
_endpoint_override: Optional[str] = None,
):
super().__init__(
index_name=index_name,
config=config,
channel=channel,
grpc_config=grpc_config,
_endpoint_override=_endpoint_override,
use_asyncio=False,
)

@property
def stub_class(self):
return VectorServiceStub
Expand Down Expand Up @@ -131,7 +150,7 @@ def upsert(

vectors = list(map(VectorFactoryGRPC.build, vectors))
if async_req:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
args_dict = parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict, **kwargs)
future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
Expand All @@ -157,7 +176,7 @@ def upsert(
def _upsert_batch(
self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs
) -> UpsertResponse:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
args_dict = parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict)
return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs)

Expand Down Expand Up @@ -264,7 +283,7 @@ def delete(
else:
filter_struct = None

args_dict = self._parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("ids", ids),
("delete_all", delete_all),
Expand Down Expand Up @@ -301,7 +320,7 @@ def fetch(
"""
timeout = kwargs.pop("timeout", None)

args_dict = self._parse_non_empty_args([("namespace", namespace)])
args_dict = parse_non_empty_args([("namespace", namespace)])

request = FetchRequest(ids=ids, **args_dict, **kwargs)
response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
Expand Down Expand Up @@ -367,8 +386,8 @@ def query(
else:
filter_struct = None

sparse_vector = self._parse_sparse_values_arg(sparse_vector)
args_dict = self._parse_non_empty_args(
sparse_vector = parse_sparse_values_arg(sparse_vector)
args_dict = parse_non_empty_args(
[
("vector", vector),
("id", id),
Expand Down Expand Up @@ -435,8 +454,8 @@ def update(
set_metadata_struct = None

timeout = kwargs.pop("timeout", None)
sparse_values = self._parse_sparse_values_arg(sparse_values)
args_dict = self._parse_non_empty_args(
sparse_values = parse_sparse_values_arg(sparse_values)
args_dict = parse_non_empty_args(
[
("values", values),
("set_metadata", set_metadata_struct),
Expand Down Expand Up @@ -485,7 +504,7 @@ def list_paginated(

Returns: SimpleListResponse object which contains the list of ids, the namespace name, pagination information, and usage showing the number of read_units consumed.
"""
args_dict = self._parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("prefix", prefix),
("limit", limit),
Expand Down Expand Up @@ -564,36 +583,10 @@ def describe_index_stats(
filter_struct = dict_to_proto_struct(filter)
else:
filter_struct = None
args_dict = self._parse_non_empty_args([("filter", filter_struct)])
args_dict = parse_non_empty_args([("filter", filter_struct)])
timeout = kwargs.pop("timeout", None)

request = DescribeIndexStatsRequest(**args_dict)
response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_stats_response(json_response)

@staticmethod
def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
return {arg_name: val for arg_name, val in args if val is not None}

@staticmethod
def _parse_sparse_values_arg(
sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]],
) -> Optional[GRPCSparseValues]:
if sparse_values is None:
return None

if isinstance(sparse_values, GRPCSparseValues):
return sparse_values

if (
not isinstance(sparse_values, dict)
or "indices" not in sparse_values
or "values" not in sparse_values
):
raise ValueError(
"Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
f"Received: {sparse_values}"
)

return GRPCSparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
Loading
Loading