Skip to content

Commit

Permalink
[Refactor] Extract GrpcChannelFactory from GRPCIndexBase (#394)
Browse files Browse the repository at this point in the history
## Problem

I'm preparing to implement asyncio for the data plane, and I had a need
to extract some of this grpc channel configuration into a spot where it
could be reused more easily across both sync and async implementations.

## Solution

- Extract `GrpcChannelFactory` from `GRPCIndexBase`
- Add some unit tests for this new class

## Type of Change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [x] None of the above: Refactoring only, should be no functional
change
  • Loading branch information
jhamon authored Oct 11, 2024
1 parent 1d0f046 commit 4c18899
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 64 deletions.
71 changes: 7 additions & 64 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
from functools import wraps
from typing import Dict, Optional

import certifi
import grpc
from grpc._channel import _InactiveRpcError, Channel
import json

from .retry import RetryConfig
from .channel_factory import GrpcChannelFactory

from pinecone import Config
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import MAX_MSG_SIZE, REQUEST_ID, CLIENT_VERSION
from pinecone.utils.user_agent import get_user_agent_grpc
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
from pinecone.exceptions.exceptions import PineconeException

_logger = logging.getLogger(__name__)
Expand All @@ -35,8 +33,6 @@ def __init__(
grpc_config: Optional[GRPCClientConfig] = None,
_endpoint_override: Optional[str] = None,
):
self.name = index_name

self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()
Expand All @@ -51,35 +47,10 @@ def __init__(

self._endpoint_override = _endpoint_override

self.method_config = json.dumps(
{
"methodConfig": [
{
"name": [{"service": "VectorService.Upsert"}],
"retryPolicy": {
"maxAttempts": 5,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": ["UNAVAILABLE"],
},
},
{
"name": [{"service": "VectorService"}],
"retryPolicy": {
"maxAttempts": 5,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": ["UNAVAILABLE"],
},
},
]
}
self.channel_factory = GrpcChannelFactory(
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
)

options = {"grpc.primary_user_agent": get_user_agent_grpc(config)}
self._channel = channel or self._gen_channel(options=options)
self._channel = channel or self._gen_channel()
self.stub = self.stub_class(self._channel)

@property
Expand All @@ -93,36 +64,8 @@ def _endpoint(self):
grpc_host = f"{grpc_host}:443"
return self._endpoint_override if self._endpoint_override else grpc_host

def _gen_channel(self, options=None):
target = self._endpoint()
default_options = {
"grpc.max_send_message_length": MAX_MSG_SIZE,
"grpc.max_receive_message_length": MAX_MSG_SIZE,
"grpc.service_config": self.method_config,
"grpc.enable_retries": True,
"grpc.per_rpc_retry_buffer_size": MAX_MSG_SIZE,
}
if self.grpc_client_config.secure:
default_options["grpc.ssl_target_name_override"] = target.split(":")[0]
if self.config.proxy_url:
default_options["grpc.http_proxy"] = self.config.proxy_url
user_provided_options = options or {}
_options = tuple((k, v) for k, v in {**default_options, **user_provided_options}.items())
_logger.debug(
"creating new channel with endpoint %s options %s and config %s",
target,
_options,
self.grpc_client_config,
)
if not self.grpc_client_config.secure:
channel = grpc.insecure_channel(target, options=_options)
else:
ca_certs = self.config.ssl_ca_certs if self.config.ssl_ca_certs else certifi.where()
root_cas = open(ca_certs, "rb").read()
tls = grpc.ssl_channel_credentials(root_certificates=root_cas)
channel = grpc.secure_channel(target, tls, options=_options)

return channel
def _gen_channel(self):
return self.channel_factory.create_channel(self._endpoint())

@property
def channel(self):
Expand Down
100 changes: 100 additions & 0 deletions pinecone/grpc/channel_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
from typing import Optional

import certifi
import grpc
import json

from pinecone import Config
from .config import GRPCClientConfig
from pinecone.utils.constants import MAX_MSG_SIZE
from pinecone.utils.user_agent import get_user_agent_grpc

_logger = logging.getLogger(__name__)


class GrpcChannelFactory:
def __init__(
self,
config: Config,
grpc_client_config: GRPCClientConfig,
use_asyncio: Optional[bool] = False,
):
self.config = config
self.grpc_client_config = grpc_client_config
self.use_asyncio = use_asyncio

def _get_service_config(self):
# https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto
return json.dumps(
{
"methodConfig": [
{
"name": [{"service": "VectorService.Upsert"}],
"retryPolicy": {
"maxAttempts": 5,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": ["UNAVAILABLE"],
},
},
{
"name": [{"service": "VectorService"}],
"retryPolicy": {
"maxAttempts": 5,
"initialBackoff": "0.1s",
"maxBackoff": "1s",
"backoffMultiplier": 2,
"retryableStatusCodes": ["UNAVAILABLE"],
},
},
]
}
)

def _build_options(self, target):
# For property definitions, see https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h
options = {
"grpc.max_send_message_length": MAX_MSG_SIZE,
"grpc.max_receive_message_length": MAX_MSG_SIZE,
"grpc.service_config": self._get_service_config(),
"grpc.enable_retries": True,
"grpc.per_rpc_retry_buffer_size": MAX_MSG_SIZE,
"grpc.primary_user_agent": get_user_agent_grpc(self.config),
}
if self.grpc_client_config.secure:
options["grpc.ssl_target_name_override"] = target.split(":")[0]
if self.config.proxy_url:
options["grpc.http_proxy"] = self.config.proxy_url

options_tuple = tuple((k, v) for k, v in options.items())
return options_tuple

def _build_channel_credentials(self):
ca_certs = self.config.ssl_ca_certs if self.config.ssl_ca_certs else certifi.where()
root_cas = open(ca_certs, "rb").read()
channel_creds = grpc.ssl_channel_credentials(root_certificates=root_cas)
return channel_creds

def create_channel(self, endpoint):
options_tuple = self._build_options(endpoint)

_logger.debug(
"Creating new channel with endpoint %s options %s and config %s",
endpoint,
options_tuple,
self.grpc_client_config,
)

if not self.grpc_client_config.secure:
create_channel_fn = (
grpc.aio.insecure_channel if self.use_asyncio else grpc.insecure_channel
)
channel = create_channel_fn(endpoint, options=options_tuple)
else:
channel_creds = self._build_channel_credentials()
create_channel_fn = grpc.aio.secure_channel if self.use_asyncio else grpc.secure_channel
channel = create_channel_fn(endpoint, credentials=channel_creds, options=options_tuple)

return channel
141 changes: 141 additions & 0 deletions tests/unit_grpc/test_channel_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import grpc
import re
import pytest
from unittest.mock import patch, MagicMock, ANY

from pinecone import Config
from pinecone.grpc.channel_factory import GrpcChannelFactory, GRPCClientConfig
from pinecone.utils.constants import MAX_MSG_SIZE


@pytest.fixture
def config():
return Config(ssl_ca_certs=None, proxy_url=None)


@pytest.fixture
def grpc_client_config():
return GRPCClientConfig(secure=True)


class TestGrpcChannelFactory:
def test_create_secure_channel_with_default_settings(self, config, grpc_client_config):
factory = GrpcChannelFactory(
config=config, grpc_client_config=grpc_client_config, use_asyncio=False
)
endpoint = "test.endpoint:443"

with patch("grpc.secure_channel") as mock_secure_channel, patch(
"certifi.where", return_value="/path/to/certifi/cacert.pem"
), patch("builtins.open", new_callable=MagicMock) as mock_open:
# Mock the file object to return bytes when read() is called
mock_file = MagicMock()
mock_file.read.return_value = b"mocked_cert_data"
mock_open.return_value = mock_file
channel = factory.create_channel(endpoint)

mock_secure_channel.assert_called_once()
assert mock_secure_channel.call_args[0][0] == endpoint
assert isinstance(mock_secure_channel.call_args[1]["options"], tuple)

options = dict(mock_secure_channel.call_args[1]["options"])
assert options["grpc.ssl_target_name_override"] == "test.endpoint"
assert options["grpc.max_send_message_length"] == MAX_MSG_SIZE
assert options["grpc.per_rpc_retry_buffer_size"] == MAX_MSG_SIZE
assert options["grpc.max_receive_message_length"] == MAX_MSG_SIZE
assert "grpc.service_config" in options
assert options["grpc.enable_retries"] is True
assert (
re.search(
r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"]
)
is not None
)

assert isinstance(channel, MagicMock)

def test_create_secure_channel_with_proxy(self):
grpc_client_config = GRPCClientConfig(secure=True)
config = Config(proxy_url="http://test.proxy:8080")
factory = GrpcChannelFactory(
config=config, grpc_client_config=grpc_client_config, use_asyncio=False
)
endpoint = "test.endpoint:443"

with patch("grpc.secure_channel") as mock_secure_channel:
channel = factory.create_channel(endpoint)

mock_secure_channel.assert_called_once()
assert "grpc.http_proxy" in dict(mock_secure_channel.call_args[1]["options"])
assert (
"http://test.proxy:8080"
== dict(mock_secure_channel.call_args[1]["options"])["grpc.http_proxy"]
)
assert isinstance(channel, MagicMock)

def test_create_insecure_channel(self, config):
grpc_client_config = GRPCClientConfig(secure=False)
factory = GrpcChannelFactory(
config=config, grpc_client_config=grpc_client_config, use_asyncio=False
)
endpoint = "test.endpoint:50051"

with patch("grpc.insecure_channel") as mock_insecure_channel:
channel = factory.create_channel(endpoint)

mock_insecure_channel.assert_called_once_with(endpoint, options=ANY)
assert isinstance(channel, MagicMock)


class TestGrpcChannelFactoryAsyncio:
def test_create_secure_channel_with_default_settings(self, config, grpc_client_config):
factory = GrpcChannelFactory(
config=config, grpc_client_config=grpc_client_config, use_asyncio=True
)
endpoint = "test.endpoint:443"

with patch("grpc.aio.secure_channel") as mock_secure_aio_channel, patch(
"certifi.where", return_value="/path/to/certifi/cacert.pem"
), patch("builtins.open", new_callable=MagicMock) as mock_open:
# Mock the file object to return bytes when read() is called
mock_file = MagicMock()
mock_file.read.return_value = b"mocked_cert_data"
mock_open.return_value = mock_file
channel = factory.create_channel(endpoint)

mock_secure_aio_channel.assert_called_once()
assert mock_secure_aio_channel.call_args[0][0] == endpoint
assert isinstance(mock_secure_aio_channel.call_args[1]["options"], tuple)

options = dict(mock_secure_aio_channel.call_args[1]["options"])
assert options["grpc.ssl_target_name_override"] == "test.endpoint"
assert options["grpc.max_send_message_length"] == MAX_MSG_SIZE
assert options["grpc.per_rpc_retry_buffer_size"] == MAX_MSG_SIZE
assert options["grpc.max_receive_message_length"] == MAX_MSG_SIZE
assert "grpc.service_config" in options
assert options["grpc.enable_retries"] is True
assert (
re.search(
r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"]
)
is not None
)

security_credentials = mock_secure_aio_channel.call_args[1]["credentials"]
assert security_credentials is not None
assert isinstance(security_credentials, grpc.ChannelCredentials)

assert isinstance(channel, MagicMock)

def test_create_insecure_channel_asyncio(self, config):
grpc_client_config = GRPCClientConfig(secure=False)
factory = GrpcChannelFactory(
config=config, grpc_client_config=grpc_client_config, use_asyncio=True
)
endpoint = "test.endpoint:50051"

with patch("grpc.aio.insecure_channel") as mock_aio_insecure_channel:
channel = factory.create_channel(endpoint)

mock_aio_insecure_channel.assert_called_once_with(endpoint, options=ANY)
assert isinstance(channel, MagicMock)

0 comments on commit 4c18899

Please sign in to comment.