Skip to content

Commit

Permalink
[SPARK-48056][CONNECT][PYTHON] Re-execute plan if a SESSION_NOT_FOUND…
Browse files Browse the repository at this point in the history
… error is raised and no partial response was received

### What changes were proposed in this pull request?

Similar to OPERATION_NOT_FOUND, re-attempt to execute
the original spark connect plan when a SESSION_NOT_FOUND is
received from the spark connect service and no partial responses
were previously received.

### Why are the changes needed?

This error has been noticed to occur during a cluster cold start
and when a request arrives when the connect service is not fully
initialized.

### Does this PR introduce _any_ user-facing change?

Prevoiusly, connect-based pyspark APIs would fail with the error code
"INVALID_HANDLE.SESSION_NOT_FOUND" in the very first request to
the service.
With this change, the client will now automatically retry.

### How was this patch tested?

Attached unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46297 from nija-at/session-not-found.

Authored-by: Niranjan Jayakar <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
nija-at authored and HyukjinKwon committed May 2, 2024
1 parent ae5da18 commit 2f31d05
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 5 deletions.
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/client/reattach.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,10 @@ def _call_iter(self, iter_fun: Callable) -> Any:
return iter_fun()
except grpc.RpcError as e:
status = rpc_status.from_call(cast(grpc.Call, e))
if status is not None and "INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message:
if status is not None and (
"INVALID_HANDLE.OPERATION_NOT_FOUND" in status.message
or "INVALID_HANDLE.SESSION_NOT_FOUND" in status.message
):
if self._last_returned_response_id is not None:
raise PySparkRuntimeError(
error_class="RESPONSE_ALREADY_RECEIVED",
Expand Down
85 changes: 81 additions & 4 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import unittest
import uuid
from collections.abc import Generator
from typing import Optional, Any
from typing import Optional, Any, Union

from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import eventually

if should_test_connect:
import grpc
from google.rpc import status_pb2
import pandas as pd
import pyarrow as pa
from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder
Expand All @@ -33,7 +34,7 @@
DefaultPolicy,
)
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.errors import RetriesExceeded
from pyspark.errors import PySparkRuntimeError, RetriesExceeded
import pyspark.sql.connect.proto as proto

class TestPolicy(DefaultPolicy):
Expand All @@ -50,18 +51,29 @@ def __init__(self):
class TestException(grpc.RpcError, grpc.Call):
"""Exception mock to test retryable exceptions."""

def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
def __init__(
self,
msg,
code=grpc.StatusCode.INTERNAL,
trailing_status: Union[status_pb2.Status, None] = None,
):
self.msg = msg
self._code = code
self._trailer: dict[str, Any] = {}
if trailing_status is not None:
self._trailer["grpc-status-details-bin"] = trailing_status.SerializeToString()

def code(self):
return self._code

def __str__(self):
return self.msg

def details(self):
return self.msg

def trailing_metadata(self):
return ()
return None if not self._trailer else self._trailer.items()

class ResponseGenerator(Generator):
"""This class is used to generate values that are returned by the streaming
Expand Down Expand Up @@ -340,6 +352,71 @@ def check():

eventually(timeout=1, catch_assertions=True)(check)()

def test_not_found_recovers(self):
"""SPARK-48056: Assert that the client recovers from session or operation not
found error if no partial responses were previously received.
"""

def not_found_recovers(error_code: str):
def not_found():
raise TestException(
error_code,
grpc.StatusCode.UNAVAILABLE,
trailing_status=status_pb2.Status(code=14, message=error_code, details=""),
)

stub = self._stub_with([not_found, self.finished])
ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])

for _ in ite:
pass

def checks():
self.assertEquals(2, stub.execute_calls)
self.assertEquals(0, stub.attach_calls)
self.assertEquals(0, stub.release_calls)
self.assertEquals(0, stub.release_until_calls)

eventually(timeout=1, catch_assertions=True)(checks)()

parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"]
for b in parameters:
not_found_recovers(b)

def test_not_found_fails(self):
"""SPARK-48056: Assert that the client fails from session or operation not found error
if a partial response was previously received.
"""

def not_found_fails(error_code: str):
def not_found():
raise TestException(
error_code,
grpc.StatusCode.UNAVAILABLE,
trailing_status=status_pb2.Status(code=14, message=error_code, details=""),
)

stub = self._stub_with([self.response], [not_found])

with self.assertRaises(PySparkRuntimeError) as e:
ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])
for _ in ite:
pass

self.assertTrue("RESPONSE_ALREADY_RECEIVED" in e.exception.getMessage())

def checks():
self.assertEquals(1, stub.execute_calls)
self.assertEquals(1, stub.attach_calls)
self.assertEquals(0, stub.release_calls)
self.assertEquals(0, stub.release_until_calls)

eventually(timeout=1, catch_assertions=True)(checks)()

parameters = ["INVALID_HANDLE.SESSION_NOT_FOUND", "INVALID_HANDLE.OPERATION_NOT_FOUND"]
for b in parameters:
not_found_fails(b)


if __name__ == "__main__":
from pyspark.sql.tests.connect.client.test_client import * # noqa: F401
Expand Down

0 comments on commit 2f31d05

Please sign in to comment.