Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create new session if client provided session is invalid #2368

Merged
merged 4 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,9 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
// get session from request
SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId());
Optional<Session> createdSession = sessionManager.getSession(sessionId);
if (createdSession.isEmpty()) {
throw new IllegalArgumentException("no session found. " + sessionId);
if (createdSession.isPresent()) {
session = createdSession.get();
}
session = createdSession.get();
}
if (session == null || !session.isReady()) {
// create session if not exist or session dead/fail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class StatementModel extends StateModel {
public static final String QUERY_ID = "queryId";
public static final String SUBMIT_TIME = "submitTime";
public static final String ERROR = "error";
public static final String UNKNOWN = "unknown";
public static final String UNKNOWN = "";
public static final String STATEMENT_DOC_TYPE = "statement";

private final String version;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.Strings;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.plugins.Plugin;
Expand Down Expand Up @@ -227,6 +228,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() {
// 2. fetch async query result.
AsyncQueryExecutionResponse asyncQueryResults =
asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId());
assertTrue(Strings.isEmpty(asyncQueryResults.getError()));
assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus());

// 3. cancel async query.
Expand Down Expand Up @@ -460,24 +462,22 @@ public void recreateSessionIfNotReady() {
}

@Test
public void submitQueryInInvalidSessionThrowException() {
public void submitQueryInInvalidSessionWillCreateNewSession() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

// enable session
enableSession(true);

// 1. create async query.
SessionId sessionId = SessionId.newSessionId(DATASOURCE);
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId())));
assertEquals("no session found. " + sessionId, exception.getMessage());
// 1. create async query with invalid sessionId
SessionId invalidSessionId = SessionId.newSessionId(DATASOURCE);
CreateAsyncQueryResponse asyncQuery =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select 1", DATASOURCE, LangType.SQL, invalidSessionId.getSessionId()));
assertNotNull(asyncQuery.getSessionId());
assertNotEquals(invalidSessionId.getSessionId(), asyncQuery.getSessionId());
}

private DataSourceServiceImpl createDataSourceService() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,26 +327,6 @@ void testDispatchSelectQueryReuseSession() {
Assertions.assertEquals(MOCK_SESSION_ID, dispatchQueryResponse.getSessionId());
}

@Test
void testDispatchSelectQueryInvalidSession() {
String query = "select * from my_glue.default.http_logs";
DispatchQueryRequest queryRequest = dispatchQueryRequestWithSessionId(query, "invalid");

doReturn(true).when(sessionManager).isEnabled();
doReturn(Optional.empty()).when(sessionManager).getSession(any());
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
IllegalArgumentException exception =
Assertions.assertThrows(
IllegalArgumentException.class, () -> sparkQueryDispatcher.dispatch(queryRequest));

verifyNoInteractions(emrServerlessClient);
verify(sessionManager, never()).createSession(any());
Assertions.assertEquals(
"no session found. " + new SessionId("invalid"), exception.getMessage());
}

@Test
void testDispatchSelectQueryFailedCreateSession() {
String query = "select * from my_glue.default.http_logs";
Expand Down
Loading