Skip to content

Commit

Permalink
chore: use vector endpoint for MVI API calls (#373)
Browse files Browse the repository at this point in the history
* chore: use vector endpoint for MVI API calls
  • Loading branch information
pratik151192 authored Aug 28, 2023
1 parent 74d2bf2 commit 6f6e0f8
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 8 deletions.
16 changes: 12 additions & 4 deletions src/momento/auth/credential_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,24 @@ class CredentialProvider:
auth_token: str
control_endpoint: str
cache_endpoint: str
vector_endpoint: str

@staticmethod
def from_environment_variable(
env_var_name: str,
control_endpoint: Optional[str] = None,
cache_endpoint: Optional[str] = None,
vector_endpoint: Optional[str] = None,
) -> CredentialProvider:
"""Reads and parses a Momento auth token stored as an environment variable.
Args:
env_var_name (str): Name of the environment variable from which the auth token will be read
control_endpoint (Optional[str], optional): Optionally overrides the default control endpoint.
Defaults to None.
cache_endpoint (Optional[str], optional): Optionally overrides the default control endpoint.
cache_endpoint (Optional[str], optional): Optionally overrides the default cache endpoint.
Defaults to None.
vector_endpoint (Optional[str], optional): Optionally overrides the default vector endpoint.
Defaults to None.
Raises:
Expand All @@ -40,21 +44,24 @@ def from_environment_variable(
auth_token = os.getenv(env_var_name)
if not auth_token:
raise RuntimeError(f"Missing required environment variable {env_var_name}")
return CredentialProvider.from_string(auth_token, control_endpoint, cache_endpoint)
return CredentialProvider.from_string(auth_token, control_endpoint, cache_endpoint, vector_endpoint)

@staticmethod
def from_string(
auth_token: str,
control_endpoint: Optional[str] = None,
cache_endpoint: Optional[str] = None,
vector_endpoint: Optional[str] = None,
) -> CredentialProvider:
"""Reads and parses a Momento auth token.
Args:
auth_token (str): the Momento auth token
control_endpoint (Optional[str], optional): Optionally overrides the default control endpoint.
Defaults to None.
cache_endpoint (Optional[str], optional): Optionally overrides the default control endpoint.
cache_endpoint (Optional[str], optional): Optionally overrides the default cache endpoint.
Defaults to None.
vector_endpoint (Optional[str], optional): Optionally overrides the default vector endpoint.
Defaults to None.
Returns:
Expand All @@ -63,8 +70,9 @@ def from_string(
token_and_endpoints = momento_endpoint_resolver.resolve(auth_token)
control_endpoint = control_endpoint or token_and_endpoints.control_endpoint
cache_endpoint = cache_endpoint or token_and_endpoints.cache_endpoint
vector_endpoint = vector_endpoint or token_and_endpoints.vector_endpoint
auth_token = token_and_endpoints.auth_token
return CredentialProvider(auth_token, control_endpoint, cache_endpoint)
return CredentialProvider(auth_token, control_endpoint, cache_endpoint, vector_endpoint)

def __repr__(self) -> str:
attributes: Dict[str, str] = copy.copy(vars(self)) # type: ignore[misc]
Expand Down
5 changes: 5 additions & 0 deletions src/momento/auth/momento_endpoint_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@

_MOMENTO_CONTROL_ENDPOINT_PREFIX = "control."
_MOMENTO_CACHE_ENDPOINT_PREFIX = "cache."
_MOMENTO_VECTOR_ENDPOINT_PREFIX = "vector."
_CONTROL_ENDPOINT_CLAIM_ID = "cp"
_CACHE_ENDPOINT_CLAIM_ID = "c"
_VECTOR_ENDPOINT_CLAIM_ID = "c" # we don't have a new claim here so defaulting to c


@dataclass
class _TokenAndEndpoints:
control_endpoint: str
cache_endpoint: str
vector_endpoint: str
auth_token: str


Expand All @@ -38,6 +41,7 @@ def resolve(auth_token: str) -> _TokenAndEndpoints:
return _TokenAndEndpoints(
control_endpoint=_MOMENTO_CONTROL_ENDPOINT_PREFIX + info["endpoint"], # type: ignore[misc]
cache_endpoint=_MOMENTO_CACHE_ENDPOINT_PREFIX + info["endpoint"], # type: ignore[misc]
vector_endpoint=_MOMENTO_VECTOR_ENDPOINT_PREFIX + info["endpoint"], # type: ignore[misc]
auth_token=info["api_key"], # type: ignore[misc]
)
else:
Expand All @@ -50,6 +54,7 @@ def _get_endpoint_from_token(auth_token: str) -> _TokenAndEndpoints:
return _TokenAndEndpoints(
control_endpoint=claims[_CONTROL_ENDPOINT_CLAIM_ID], # type: ignore[misc]
cache_endpoint=claims[_CACHE_ENDPOINT_CLAIM_ID], # type: ignore[misc]
vector_endpoint=claims[_VECTOR_ENDPOINT_CLAIM_ID], # type: ignore[misc]
auth_token=auth_token,
)
except (DecodeError, KeyError) as e:
Expand Down
2 changes: 1 addition & 1 deletion src/momento/internal/aio/_vector_index_data_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class _VectorIndexDataClient:
"""Internal vector index data client."""

def __init__(self, configuration: VectorIndexConfiguration, credential_provider: CredentialProvider):
endpoint = credential_provider.cache_endpoint
endpoint = credential_provider.vector_endpoint
self._logger = logs.logger
self._logger.debug("Vector index data client instantiated with endpoint: %s", endpoint)
self._endpoint = endpoint
Expand Down
2 changes: 1 addition & 1 deletion src/momento/internal/aio/_vector_index_grpc_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class _VectorIndexDataGrpcManager:

def __init__(self, configuration: VectorIndexConfiguration, credential_provider: CredentialProvider):
self._secure_channel = grpc.aio.secure_channel(
target=credential_provider.cache_endpoint,
target=credential_provider.vector_endpoint,
credentials=grpc.ssl_channel_credentials(),
interceptors=_interceptors(credential_provider.auth_token),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class _VectorIndexDataClient:
"""Internal vector index data client."""

def __init__(self, configuration: VectorIndexConfiguration, credential_provider: CredentialProvider):
endpoint = credential_provider.cache_endpoint
endpoint = credential_provider.vector_endpoint
self._logger = logs.logger
self._logger.debug("Vector index data client instantiated with endpoint: %s", endpoint)
self._endpoint = endpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class _VectorIndexDataGrpcManager:

def __init__(self, configuration: VectorIndexConfiguration, credential_provider: CredentialProvider):
self._secure_channel = grpc.secure_channel(
target=credential_provider.cache_endpoint,
target=credential_provider.vector_endpoint,
credentials=grpc.ssl_channel_credentials(),
)
intercept_channel = grpc.intercept_channel(self._secure_channel, *_interceptors(credential_provider.auth_token))
Expand Down

0 comments on commit 6f6e0f8

Please sign in to comment.