diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index 6c93dc253e8a5..901e9604f6242 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -411,7 +411,9 @@ def get_structured_schema(self) -> Dict[str, Any]: return self.structured_schema def query( - self, query: str, params: dict = {}, retry_on_session_expired: bool = True + self, + query: str, + params: dict = {}, ) -> List[Dict[str, Any]]: """Query Neo4j database. @@ -423,26 +425,44 @@ def query( List[Dict[str, Any]]: The list of dictionaries containing the query results. """ from neo4j import Query - from neo4j.exceptions import CypherSyntaxError, SessionExpired + from neo4j.exceptions import Neo4jError - with self._driver.session(database=self._database) as session: - try: - data = session.run(Query(text=query, timeout=self.timeout), params) - json_data = [r.data() for r in data] - if self.sanitize: - json_data = [value_sanitize(el) for el in json_data] - return json_data - except CypherSyntaxError as e: - raise ValueError(f"Generated Cypher Statement is not valid\n{e}") - except ( - SessionExpired - ) as e: # Session expired is a transient error that can be retried - if retry_on_session_expired: - return self.query( - query, params=params, retry_on_session_expired=False + try: + data, _, _ = self._driver.execute_query( + Query(text=query, timeout=self.timeout), + database=self._database, + parameters_=params, + ) + json_data = [r.data() for r in data] + if self.sanitize: + json_data = [value_sanitize(el) for el in json_data] + return json_data + except Neo4jError as e: + if not ( + ( + ( # isCallInTransactionError + e.code == "Neo.DatabaseError.Statement.ExecutionFailed" + or e.code + == "Neo.DatabaseError.Transaction.TransactionStartFailed" ) - else: - raise e + and "in an implicit transaction" in e.message + ) + or ( # isPeriodicCommitError + e.code == "Neo.ClientError.Statement.SemanticError" + and ( + "in an open transaction is not possible" in e.message + or "tried to execute in an explicit transaction" in e.message + ) + ) + ): + raise + # fallback to allow implicit transactions + with self._driver.session() as session: + data = session.run(Query(text=query, timeout=self.timeout), params) + json_data = [r.data() for r in data] + if self.sanitize: + json_data = [value_sanitize(el) for el in json_data] + return json_data def refresh_schema(self) -> None: """ diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py index 2d3eff317ad69..fb13a6257c907 100644 --- a/libs/community/langchain_community/vectorstores/neo4j_vector.py +++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py @@ -595,11 +595,8 @@ def query( query: str, *, params: Optional[dict] = None, - retry_on_session_expired: bool = True, ) -> List[Dict[str, Any]]: - """ - This method sends a Cypher query to the connected Neo4j database - and returns the results as a list of dictionaries. + """Query Neo4j database with retries and exponential backoff. Args: query (str): The Cypher query to execute. @@ -608,24 +605,38 @@ def query( Returns: List[Dict[str, Any]]: List of dictionaries containing the query results. """ - from neo4j.exceptions import CypherSyntaxError, SessionExpired + from neo4j import Query + from neo4j.exceptions import Neo4jError params = params or {} - with self._driver.session(database=self._database) as session: - try: - data = session.run(query, params) - return [r.data() for r in data] - except CypherSyntaxError as e: - raise ValueError(f"Cypher Statement is not valid\n{e}") - except ( - SessionExpired - ) as e: # Session expired is a transient error that can be retried - if retry_on_session_expired: - return self.query( - query, params=params, retry_on_session_expired=False + try: + data, _, _ = self._driver.execute_query( + query, database=self._database, parameters_=params + ) + return [r.data() for r in data] + except Neo4jError as e: + if not ( + ( + ( # isCallInTransactionError + e.code == "Neo.DatabaseError.Statement.ExecutionFailed" + or e.code + == "Neo.DatabaseError.Transaction.TransactionStartFailed" ) - else: - raise e + and "in an implicit transaction" in e.message + ) + or ( # isPeriodicCommitError + e.code == "Neo.ClientError.Statement.SemanticError" + and ( + "in an open transaction is not possible" in e.message + or "tried to execute in an explicit transaction" in e.message + ) + ) + ): + raise + # Fallback to allow implicit transactions + with self._driver.session() as session: + data = session.run(Query(text=query), params) + return [r.data() for r in data] def verify_version(self) -> None: """