Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: defer the database connection to when it's needed #804

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,35 +66,45 @@ def __init__(
"""
self.api_key = api_key
spec = spec or DEFAULT_STARTER_PLAN_SPEC
self.namespace = namespace
self.batch_size = batch_size
self.metric = metric
self.spec = spec
self.dimension = dimension
self.index_name = index

self._index = None
self._dummy_vector = [-10.0] * self.dimension

@property
def index(self):
if self._index is not None:
return self._index

client = Pinecone(api_key=api_key.resolve_value(), source_tag="haystack")
client = Pinecone(api_key=self.api_key.resolve_value(), source_tag="haystack")

if index not in client.list_indexes().names():
logger.info(f"Index {index} does not exist. Creating a new index.")
pinecone_spec = self._convert_dict_spec_to_pinecone_object(spec)
client.create_index(name=index, dimension=dimension, spec=pinecone_spec, metric=metric)
if self.index_name not in client.list_indexes().names():
logger.info(f"Index {self.index_name} does not exist. Creating a new index.")
pinecone_spec = self._convert_dict_spec_to_pinecone_object(self.spec)
client.create_index(name=self.index_name, dimension=self.dimension, spec=pinecone_spec, metric=self.metric)
else:
logger.info(
f"Index {index} already exists. Connecting to it. `dimension`, `spec`, and `metric` will be ignored."
f"Connecting to existing index {self.index_name}. `dimension`, `spec`, and `metric` will be ignored."
)

self._index = client.Index(name=index)
self._index = client.Index(name=self.index_name)

actual_dimension = self._index.describe_index_stats().get("dimension")
if actual_dimension and actual_dimension != dimension:
if actual_dimension and actual_dimension != self.dimension:
logger.warning(
f"Dimension of index {index} is {actual_dimension}, but {dimension} was specified. "
f"Dimension of index {self.index_name} is {actual_dimension}, but {self.dimension} was specified. "
"The specified dimension will be ignored."
"If you need an index with a different dimension, please create a new one."
)
self.dimension = actual_dimension or dimension

self.dimension = actual_dimension or self.dimension
self._dummy_vector = [-10.0] * self.dimension
self.index = index
self.namespace = namespace
self.batch_size = batch_size
self.metric = metric
self.spec = spec

return self._index

@staticmethod
def _convert_dict_spec_to_pinecone_object(spec: Dict[str, Any]):
Expand Down Expand Up @@ -135,7 +145,7 @@ def to_dict(self) -> Dict[str, Any]:
self,
api_key=self.api_key.to_dict(),
spec=self.spec,
index=self.index,
index=self.index_name,
dimension=self.dimension,
namespace=self.namespace,
batch_size=self.batch_size,
Expand All @@ -147,7 +157,7 @@ def count_documents(self) -> int:
Returns how many documents are present in the document store.
"""
try:
count = self._index.describe_index_stats()["namespaces"][self.namespace]["vector_count"]
count = self.index.describe_index_stats()["namespaces"][self.namespace]["vector_count"]
except KeyError:
count = 0
return count
Expand All @@ -174,9 +184,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

documents_for_pinecone = self._convert_documents_to_pinecone_format(documents)

result = self._index.upsert(
vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size
)
result = self.index.upsert(vectors=documents_for_pinecone, namespace=self.namespace, batch_size=self.batch_size)

written_docs = result["upserted_count"]
return written_docs
Expand Down Expand Up @@ -214,7 +222,7 @@ def delete_documents(self, document_ids: List[str]) -> None:

:param document_ids: the document ids to delete
"""
self._index.delete(ids=document_ids, namespace=self.namespace)
self.index.delete(ids=document_ids, namespace=self.namespace)

def _embedding_retrieval(
self,
Expand Down Expand Up @@ -247,7 +255,7 @@ def _embedding_retrieval(
filters = convert(filters)
filters = _normalize_filters(filters) if filters else None

result = self._index.query(
result = self.index.query(
vector=query_embedding,
top_k=top_k,
namespace=namespace or self.namespace,
Expand Down
2 changes: 1 addition & 1 deletion integrations/pinecone/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def delete_documents_and_wait(filters):

yield store
try:
store._index.delete(delete_all=True, namespace=namespace)
store.index.delete(delete_all=True, namespace=namespace)
except NotFoundException:
pass
30 changes: 23 additions & 7 deletions integrations/pinecone/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from haystack_integrations.document_stores.pinecone import PineconeDocumentStore


@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone")
def test_init_is_lazy(_mock_client):
_ = PineconeDocumentStore(api_key=Secret.from_token("fake-api-key"))
_mock_client.assert_not_called()


@patch("haystack_integrations.document_stores.pinecone.document_store.Pinecone")
def test_init(mock_pinecone):
mock_pinecone.return_value.Index.return_value.describe_index_stats.return_value = {"dimension": 60}
Expand All @@ -25,9 +31,12 @@ def test_init(mock_pinecone):
metric="euclidean",
)

# Trigger an actual connection
_ = document_store.index

mock_pinecone.assert_called_with(api_key="fake-api-key", source_tag="haystack")

assert document_store.index == "my_index"
assert document_store.index_name == "my_index"
assert document_store.namespace == "test"
assert document_store.batch_size == 50
assert document_store.dimension == 60
Expand All @@ -38,14 +47,17 @@ def test_init(mock_pinecone):
def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch):
monkeypatch.setenv("PINECONE_API_KEY", "env-api-key")

PineconeDocumentStore(
ds = PineconeDocumentStore(
index="my_index",
namespace="test",
batch_size=50,
dimension=30,
metric="euclidean",
)

# Trigger an actual connection
_ = ds.index

mock_pinecone.assert_called_with(api_key="env-api-key", source_tag="haystack")


Expand All @@ -61,6 +73,9 @@ def test_to_from_dict(mock_pinecone, monkeypatch):
metric="euclidean",
)

# Trigger an actual connection
_ = document_store.index

dict_output = {
"type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore",
"init_parameters": {
Expand All @@ -83,7 +98,7 @@ def test_to_from_dict(mock_pinecone, monkeypatch):

document_store = PineconeDocumentStore.from_dict(dict_output)
assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
assert document_store.index == "my_index"
assert document_store.index_name == "my_index"
assert document_store.namespace == "test"
assert document_store.batch_size == 50
assert document_store.dimension == 60
Expand All @@ -94,9 +109,9 @@ def test_to_from_dict(mock_pinecone, monkeypatch):
def test_init_fails_wo_api_key(monkeypatch):
monkeypatch.delenv("PINECONE_API_KEY", raising=False)
with pytest.raises(ValueError):
PineconeDocumentStore(
_ = PineconeDocumentStore(
index="my_index",
)
).index


def test_convert_dict_spec_to_pinecone_object_serverless():
Expand All @@ -108,7 +123,6 @@ def test_convert_dict_spec_to_pinecone_object_serverless():


def test_convert_dict_spec_to_pinecone_object_pod():

dict_spec = {"pod": {"replicas": 1, "shards": 1, "pods": 1, "pod_type": "p1.x1", "environment": "us-west1-gcp"}}
pinecone_object = PineconeDocumentStore._convert_dict_spec_to_pinecone_object(dict_spec)

Expand Down Expand Up @@ -141,14 +155,16 @@ def test_serverless_index_creation_from_scratch(sleep_time):

time.sleep(sleep_time)

PineconeDocumentStore(
ds = PineconeDocumentStore(
index=index_name,
namespace="test",
batch_size=50,
dimension=30,
metric="euclidean",
spec={"serverless": {"region": "us-east-1", "cloud": "aws"}},
)
# Trigger the connection
_ = ds.index

index_description = client.describe_index(name=index_name)
assert index_description["name"] == index_name
Expand Down
2 changes: 1 addition & 1 deletion integrations/pinecone/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_from_dict(mock_pinecone, monkeypatch):

document_store = retriever.document_store
assert document_store.api_key == Secret.from_env_var("PINECONE_API_KEY", strict=True)
assert document_store.index == "default"
assert document_store.index_name == "default"
assert document_store.namespace == "test-namespace"
assert document_store.batch_size == 50
assert document_store.dimension == 512
Expand Down