[SPARK-48056][CONNECT][PYTHON] Re-execute plan if a SESSION_NOT_FOUND error is raised and no partial response was received #46297

Closed
wants to merge 14 commits
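In short, the change teaches the client-side ExecutePlanResponseReattachableIterator to treat INVALID_HANDLE.SESSION_NOT_FOUND the same way it already treats INVALID_HANDLE.OPERATION_NOT_FOUND: if the caller has not yet received any response, the plan is re-executed transparently; if a partial response was already delivered, the client raises RESPONSE_ALREADY_RECEIVED rather than risk returning duplicate data. A minimal sketch of that decision logic follows (illustrative only; `decide` and its bare `last_returned_response_id` argument are hypothetical stand-ins, not the actual pyspark API):

# Illustrative sketch of the retry decision this PR adds; not the actual
# pyspark implementation. `decide` and its arguments are hypothetical.
RETRYABLE_NOT_FOUND = (
    "INVALID_HANDLE.OPERATION_NOT_FOUND",
    "INVALID_HANDLE.SESSION_NOT_FOUND",
)


def decide(status_message: str, last_returned_response_id) -> str:
    if not any(code in status_message for code in RETRYABLE_NOT_FOUND):
        return "unrelated error: normal retry/raise path"
    if last_returned_response_id is not None:
        # The caller already consumed part of the result stream; re-running
        # the plan could silently produce duplicated or inconsistent data.
        return "raise RESPONSE_ALREADY_RECEIVED"
    # Nothing was delivered yet, so the whole plan can safely be re-executed.
    return "re-execute the plan"


assert decide("INVALID_HANDLE.SESSION_NOT_FOUND", None) == "re-execute the plan"
assert decide("INVALID_HANDLE.SESSION_NOT_FOUND", "op-1") == "raise RESPONSE_ALREADY_RECEIVED"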
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/client/reattach.py
@@ -254,7 +254,10 @@ def _call_iter(self, iter_fun: Callable) -> Any:
             return iter_fun()
         except grpc.RpcError as e:
             status = rpc_status.from_call(cast(grpc.Call, e))
-            if status is not None and "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message:
+            if status is not None and (
+                "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message
+                or "INVALID_HANDLE.SESSION_NOT_FOUND" in status.message
+            ):
                 if self._last_returned_response_id is not None:
                     raise PySparkRuntimeError(
                         error_class="RESPONSE_ALREADY_RECEIVED",
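The status.message checked above comes from the rich error details that the Spark Connect server attaches to the gRPC trailing metadata under the grpc-status-details-bin key, which rpc_status.from_call decodes into a google.rpc.status_pb2.Status; the test changes below mock exactly that shape. As a standalone illustration of the round trip (assuming the grpcio-status and googleapis-common-protos packages are installed), a Status proto can be serialized into the trailer and parsed back like this:

from google.rpc import status_pb2

# Rich status a server might attach when the session handle has expired.
# Code 14 corresponds to UNAVAILABLE in the canonical gRPC/google.rpc code space.
rich = status_pb2.Status(code=14, message="INVALID_HANDLE.SESSION_NOT_FOUND")

# Servers place the serialized proto in the trailing metadata under this key;
# grpc_status.rpc_status.from_call looks it up there on the client side.
trailer = {"grpc-status-details-bin": rich.SerializeToString()}

decoded = status_pb2.Status.FromString(trailer["grpc-status-details-bin"])
assert "INVALID_HANDLE.SESSION_NOT_FOUND" in decoded.message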
85 changes: 81 additions & 4 deletions python/pyspark/sql/tests/connect/client/test_client.py
@@ -18,13 +18,14 @@
 import unittest
 import uuid
 from collections.abc import Generator
-from typing import Optional, Any
+from typing import Optional, Any, Union
 
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import eventually
 
 if should_test_connect:
     import grpc
+    from google.rpc import status_pb2
     import pandas as pd
     import pyarrow as pa
     from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
@@ -33,7 +34,7 @@
         DefaultPolicy,
     )
     from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
-    from pyspark.errors import RetriesExceeded
+    from pyspark.errors import PySparkRuntimeError, RetriesExceeded
     import pyspark.sql.connect.proto as proto
 
     class TestPolicy(DefaultPolicy):
@@ -50,18 +51,29 @@ def __init__(self):
     class TestException(grpc.RpcError, grpc.Call):
         """Exception mock to test retryable exceptions."""
 
-        def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
+        def __init__(
+            self,
+            msg,
+            code=grpc.StatusCode.INTERNAL,
+            trailing_status: Union[status_pb2.Status, None] = None,
+        ):
             self.msg = msg
             self._code = code
+            self._trailer: dict[str, Any] = {}
+            if trailing_status is not None:
+                self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString()
 
         def code(self):
             return self._code
 
         def __str__(self):
             return self.msg
 
+        def details(self):
+            return self.msg
+
         def trailing_metadata(self):
-            return ()
+            return None if not self._trailer else self._trailer.items()
 
     class ResponseGenerator(Generator):
         """This class is used to generate values that are returned by the streaming
@@ -340,6 +352,71 @@ def check():
 
         eventually(timeout=1, catch_assertions=True)(check)()
 
+    def test_not_found_recovers(self):
+        """SPARK-48056: Assert that the client recovers from session or operation not
+        found error if no partial responses were previously received.
+        """
+
+        def not_found_recovers(error_code: str):
+            def not_found():
+                raise TestException(
+                    error_code,
+                    grpc.StatusCode.UNAVAILABLE,
+                    trailing_status=status_pb2.Status(code=14, message=error_code, details=""),
+                )
+
+            stub = self._stub_with([not_found, self.finished])
+            ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])
+
+            for _ in ite:
+                pass
+
+            def checks():
+                self.assertEquals(2, stub.execute_calls)
+                self.assertEquals(0, stub.attach_calls)
+                self.assertEquals(0, stub.release_calls)
+                self.assertEquals(0, stub.release_until_calls)
+
+            eventually(timeout=1, catch_assertions=True)(checks)()
+
+        parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"]
+        for b in parameters:
+            not_found_recovers(b)
+
+    def test_not_found_fails(self):
+        """SPARK-48056: Assert that the client fails from session or operation not found error
+        if a partial response was previously received.
+        """
+
+        def not_found_fails(error_code: str):
+            def not_found():
+                raise TestException(
+                    error_code,
+                    grpc.StatusCode.UNAVAILABLE,
+                    trailing_status=status_pb2.Status(code=14, message=error_code, details=""),
+                )
+
+            stub = self._stub_with([self.response], [not_found])
+
+            with self.assertRaises(PySparkRuntimeError) as e:
+                ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])
+                for _ in ite:
+                    pass
+
+            self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage())
+
+            def checks():
+                self.assertEquals(1, stub.execute_calls)
+                self.assertEquals(1, stub.attach_calls)
+                self.assertEquals(0, stub.release_calls)
+                self.assertEquals(0, stub.release_until_calls)
+
+            eventually(timeout=1, catch_assertions=True)(checks)()
+
+        parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"]
+        for b in parameters:
+            not_found_fails(b)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_client import *  # noqa: F401