diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 4468582ca80ea..cc50e58926316 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -254,7 +254,10 @@ def _call_iter(self, iter_fun: Callable) -> Any: return iter_fun() except grpc.RpcError as e: status = rpc_status.from_call(cast(grpc.Call, e)) - if status is not None and "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message: + if status is not None and ( + "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message + or "INVALID_HANDLE.SESSION_NOT_FOUND" in status.message + ): if self._last_returned_response_id is not None: raise PySparkRuntimeError( error_class="RESPONSE_ALREADY_RECEIVED", diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index b96fc44d50a7e..4f54a0a67d8aa 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -18,13 +18,14 @@ import unittest import uuid from collections.abc import Generator -from typing import Optional, Any +from typing import Optional, Any, Union from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import eventually if should_test_connect: import grpc + from google.rpc import status_pb2 import pandas as pd import pyarrow as pa from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder @@ -33,7 +34,7 @@ DefaultPolicy, ) from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator - from pyspark.errors import RetriesExceeded + from pyspark.errors import PySparkRuntimeError, RetriesExceeded import pyspark.sql.connect.proto as proto class TestPolicy(DefaultPolicy): @@ -50,9 +51,17 @@ def __init__(self): class TestException(grpc.RpcError, grpc.Call): """Exception mock to test retryable exceptions.""" - def __init__(self, msg, code=grpc.StatusCode.INTERNAL): + def __init__( + self, + msg, + code=grpc.StatusCode.INTERNAL, + trailing_status: Union[status_pb2.Status, None] = None, + ): self.msg = msg self._code = code + self._trailer: dict[str, Any] = {} + if trailing_status is not None: + self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString() def code(self): return self._code @@ -60,8 +69,11 @@ def code(self): def __str__(self): return self.msg + def details(self): + return self.msg + def trailing_metadata(self): - return () + return None if not self._trailer else self._trailer.items() class ResponseGenerator(Generator): """This class is used to generate values that are returned by the streaming @@ -340,6 +352,71 @@ def check(): eventually(timeout=1, catch_assertions=True)(check)() + def test_not_found_recovers(self): + """SPARK-48056: Assert that the client recovers from session or operation not + found error if no partial responses were previously received. + """ + + def not_found_recovers(error_code: str): + def not_found(): + raise TestException( + error_code, + grpc.StatusCode.UNAVAILABLE, + trailing_status=status_pb2.Status(code=14, message=error_code, details=""), + ) + + stub = self._stub_with([not_found, self.finished]) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + + for _ in ite: + pass + + def checks(): + self.assertEquals(2, stub.execute_calls) + self.assertEquals(0, stub.attach_calls) + self.assertEquals(0, stub.release_calls) + self.assertEquals(0, stub.release_until_calls) + + eventually(timeout=1, catch_assertions=True)(checks)() + + parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] + for b in parameters: + not_found_recovers(b) + + def test_not_found_fails(self): + """SPARK-48056: Assert that the client fails from session or operation not found error + if a partial response was previously received. + """ + + def not_found_fails(error_code: str): + def not_found(): + raise TestException( + error_code, + grpc.StatusCode.UNAVAILABLE, + trailing_status=status_pb2.Status(code=14, message=error_code, details=""), + ) + + stub = self._stub_with([self.response], [not_found]) + + with self.assertRaises(PySparkRuntimeError) as e: + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for _ in ite: + pass + + self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage()) + + def checks(): + self.assertEquals(1, stub.execute_calls) + self.assertEquals(1, stub.attach_calls) + self.assertEquals(0, stub.release_calls) + self.assertEquals(0, stub.release_until_calls) + + eventually(timeout=1, catch_assertions=True)(checks)() + + parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"] + for b in parameters: + not_found_fails(b) + if __name__ == "__main__": from pyspark.sql.tests.connect.client.test_client import * # noqa: F401