Skip to content

Commit

Permalink
Merge pull request #414 from aurelio-labs/vittorio/fix-pinecone-tests
Browse files Browse the repository at this point in the history
fix: Fixed pinecone tests
  • Loading branch information
jamescalam authored Sep 19, 2024
2 parents 413f147 + 0196644 commit 3e6bd22
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 139 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ lint lint_diff:
poetry run mypy $(PYTHON_FILES)

test:
poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml
poetry run pytest -vv --cov=semantic_router --cov-report=term-missing --cov-report=xml

test_functional:
poetry run pytest -vv -n 20 tests/functional
Expand Down
73 changes: 58 additions & 15 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,11 @@ def _sync_index(
for route in all_routes:
local_utterances = local_dict.get(route, {}).get("utterances", set())
remote_utterances = remote_dict.get(route, {}).get("utterances", set())
local_function_schemas = local_dict.get(route, {}).get(
"function_schemas", {}
local_function_schemas = (
local_dict.get(route, {}).get("function_schemas", {}) or {}
)
remote_function_schemas = remote_dict.get(route, {}).get(
"function_schemas", {}
remote_function_schemas = (
remote_dict.get(route, {}).get("function_schemas", {}) or {}
)
local_metadata = local_dict.get(route, {}).get("metadata", {})
remote_metadata = remote_dict.get(route, {}).get("metadata", {})
Expand All @@ -295,15 +295,19 @@ def _sync_index(
if local_utterances:
layer_routes[route] = {
"utterances": list(local_utterances),
"function_schemas": local_function_schemas,
"function_schemas": (
local_function_schemas if local_function_schemas else None
),
"metadata": local_metadata,
}

elif self.sync == "remote":
if remote_utterances:
layer_routes[route] = {
"utterances": list(remote_utterances),
"function_schemas": remote_function_schemas,
"function_schemas": (
remote_function_schemas if remote_function_schemas else None
),
"metadata": remote_metadata,
}

Expand All @@ -319,7 +323,9 @@ def _sync_index(
if local_utterances:
layer_routes[route] = {
"utterances": list(local_utterances),
"function_schemas": local_function_schemas,
"function_schemas": (
local_function_schemas if local_function_schemas else None
),
"metadata": local_metadata,
}

Expand All @@ -329,14 +335,22 @@ def _sync_index(
if local_utterances:
layer_routes[route] = {
"utterances": list(local_utterances),
"function_schemas": local_function_schemas,
"function_schemas": (
local_function_schemas
if local_function_schemas
else None
),
"metadata": local_metadata,
}
else:
if remote_utterances:
layer_routes[route] = {
"utterances": list(remote_utterances),
"function_schemas": remote_function_schemas,
"function_schemas": (
remote_function_schemas
if remote_function_schemas
else None
),
"metadata": remote_metadata,
}

Expand All @@ -353,14 +367,22 @@ def _sync_index(
if local_utterances:
layer_routes[route] = {
"utterances": list(local_utterances),
"function_schemas": local_function_schemas,
"function_schemas": (
local_function_schemas
if local_function_schemas
else None
),
"metadata": local_metadata,
}
else:
if remote_utterances:
layer_routes[route] = {
"utterances": list(remote_utterances),
"function_schemas": remote_function_schemas,
"function_schemas": (
remote_function_schemas
if remote_function_schemas
else None
),
"metadata": remote_metadata,
}

Expand All @@ -375,7 +397,9 @@ def _sync_index(
}
layer_routes[route] = {
"utterances": list(remote_utterances.union(local_utterances)),
"function_schemas": merged_function_schemas,
"function_schemas": (
merged_function_schemas if merged_function_schemas else None
),
"metadata": merged_metadata,
}

Expand All @@ -389,17 +413,36 @@ def _sync_index(
]:
for utterance in local_utterances:
routes_to_add.append(
(route, utterance, local_function_schemas, local_metadata)
(
route,
utterance,
local_function_schemas if local_function_schemas else None,
local_metadata,
)
)
if (metadata_changed or function_schema_changed) and self.sync == "merge":
for utterance in local_utterances:
routes_to_add.append(
(route, utterance, merged_function_schemas, merged_metadata)
(
route,
utterance,
(
merged_function_schemas
if merged_function_schemas
else None
),
merged_metadata,
)
)
elif utterances_to_include:
for utterance in utterances_to_include:
routes_to_add.append(
(route, utterance, local_function_schemas, local_metadata)
(
route,
utterance,
local_function_schemas if local_function_schemas else None,
local_metadata,
)
)

return routes_to_add, routes_to_delete, layer_routes
Expand Down
12 changes: 5 additions & 7 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def __call__(
if text is None:
raise ValueError("Either text or vector must be provided")
vector = self._encode(text=text)

route, top_class_scores = self._retrieve_top_route(vector, route_filter)
passed = self._check_threshold(top_class_scores, route)
if passed and route is not None and not simulate_static:
Expand Down Expand Up @@ -448,7 +447,10 @@ def delete(self, route_name: str):
if route_name not in [route.name for route in self.routes]:
err_msg = f"Route `{route_name}` not found in RouteLayer"
logger.warning(err_msg)
self.index.delete(route_name=route_name)
try:
self.index.delete(route_name=route_name)
except Exception as e:
logger.error(f"Failed to delete route from the index: {e}")
else:
self.routes = [route for route in self.routes if route.name != route_name]
self.index.delete(route_name=route_name)
Expand Down Expand Up @@ -503,13 +505,9 @@ def _add_and_sync_routes(self, routes: List[Route]):
local_utterances,
local_function_schemas,
local_metadata,
dimensions=len(self.encoder(["dummy"])[0]),
dimensions=self.index.dimensions or len(self.encoder(["dummy"])[0]),
)

logger.info(f"Routes to add: {routes_to_add}")
logger.info(f"Routes to delete: {routes_to_delete}")
logger.info(f"Layer routes: {layer_routes_dict}")

data_to_delete = {} # type: ignore
for route, utterance in routes_to_delete:
data_to_delete.setdefault(route, []).append(utterance)
Expand Down
Loading

0 comments on commit 3e6bd22

Please sign in to comment.