diff --git a/spark/build.gradle b/spark/build.gradle
index c06b5b6ecf..c2c925ecaf 100644
--- a/spark/build.gradle
+++ b/spark/build.gradle
@@ -52,15 +52,38 @@ dependencies {
api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545'
implementation group: 'commons-io', name: 'commons-io', version: '2.8.0'
- testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
+ testImplementation(platform("org.junit:junit-bom:5.6.2"))
+
+ testImplementation('org.junit.jupiter:junit-jupiter')
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0'
testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0'
- testImplementation 'junit:junit:4.13.1'
- testImplementation "org.opensearch.test:framework:${opensearch_version}"
+
+ testCompileOnly('junit:junit:4.13.1') {
+ exclude group: 'org.hamcrest', module: 'hamcrest-core'
+ }
+ testRuntimeOnly("org.junit.vintage:junit-vintage-engine") {
+ exclude group: 'org.hamcrest', module: 'hamcrest-core'
+ }
+ testRuntimeOnly("org.junit.platform:junit-platform-launcher") {
+ because 'allows tests to run from IDEs that bundle older version of launcher'
+ }
+ testImplementation("org.opensearch.test:framework:${opensearch_version}")
}
test {
- useJUnitPlatform()
+ useJUnitPlatform {
+ includeEngines("junit-jupiter")
+ }
+ testLogging {
+ events "failed"
+ exceptionFormat "full"
+ }
+}
+task junit4(type: Test) {
+ useJUnitPlatform {
+ includeEngines("junit-vintage")
+ }
+ systemProperty 'tests.security.manager', 'false'
testLogging {
events "failed"
exceptionFormat "full"
@@ -68,6 +91,8 @@ test {
}
jacocoTestReport {
+ dependsOn test, junit4
+ executionData test, junit4
reports {
html.enabled true
xml.enabled true
@@ -78,9 +103,10 @@ jacocoTestReport {
}))
}
}
-test.finalizedBy(project.tasks.jacocoTestReport)
jacocoTestCoverageVerification {
+ dependsOn test, junit4
+ executionData test, junit4
violationRules {
rule {
element = 'CLASS'
@@ -92,6 +118,9 @@ jacocoTestCoverageVerification {
'org.opensearch.sql.spark.asyncquery.exceptions.*',
'org.opensearch.sql.spark.dispatcher.model.*',
'org.opensearch.sql.spark.flint.FlintIndexType',
+ // ignore because XContext IOException
+ 'org.opensearch.sql.spark.execution.statestore.SessionStateStore',
+ 'org.opensearch.sql.spark.execution.session.SessionModel'
]
limit {
counter = 'LINE'
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
new file mode 100644
index 0000000000..17e3346248
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import lombok.Data;
+import org.opensearch.sql.spark.client.StartJobRequest;
+
+@Data
+public class CreateSessionRequest {
+ private final StartJobRequest startJobRequest;
+ private final String datasourceName;
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
new file mode 100644
index 0000000000..620e46b9be
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession;
+
+import java.util.Optional;
+import lombok.Builder;
+import lombok.Getter;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.index.engine.VersionConflictEngineException;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.execution.statestore.SessionStateStore;
+
+/**
+ * Interactive session.
+ *
+ *
ENTRY_STATE: not_started
+ */
+@Getter
+@Builder
+public class InteractiveSession implements Session {
+ private static final Logger LOG = LogManager.getLogger();
+
+ private final SessionId sessionId;
+ private final SessionStateStore sessionStateStore;
+ private final EMRServerlessClient serverlessClient;
+
+ private SessionModel sessionModel;
+
+ @Override
+ public void open(CreateSessionRequest createSessionRequest) {
+ try {
+ String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest());
+ String applicationId = createSessionRequest.getStartJobRequest().getApplicationId();
+
+ sessionModel =
+ initInteractiveSession(
+ applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
+ sessionStateStore.create(sessionModel);
+ } catch (VersionConflictEngineException e) {
+ String errorMsg = "session already exist. " + sessionId;
+ LOG.error(errorMsg);
+ throw new IllegalStateException(errorMsg);
+ }
+ }
+
+ @Override
+ public void close() {
+ Optional model = sessionStateStore.get(sessionModel.getSessionId());
+ if (model.isEmpty()) {
+ throw new IllegalStateException("session not exist. " + sessionModel.getSessionId());
+ } else {
+ serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId());
+ }
+ }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java
new file mode 100644
index 0000000000..ec9775e60a
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+/** Session define the statement execution context. Each session is binding to one Spark Job. */
+public interface Session {
+ /** open session. */
+ void open(CreateSessionRequest createSessionRequest);
+
+ /** close session. */
+ void close();
+
+ SessionModel getSessionModel();
+
+ SessionId getSessionId();
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
new file mode 100644
index 0000000000..a2847cde18
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import lombok.Data;
+import org.apache.commons.lang3.RandomStringUtils;
+
+@Data
+public class SessionId {
+ private final String sessionId;
+
+ public static SessionId newSessionId() {
+ return new SessionId(RandomStringUtils.random(10, true, true));
+ }
+
+ @Override
+ public String toString() {
+ return "sessionId=" + sessionId;
+ }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
new file mode 100644
index 0000000000..3d0916bac8
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId;
+
+import java.util.Optional;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.execution.statestore.SessionStateStore;
+
+/**
+ * Singleton Class
+ *
+ * todo. add Session cache and Session sweeper.
+ */
+@RequiredArgsConstructor
+public class SessionManager {
+ private final SessionStateStore stateStore;
+ private final EMRServerlessClient emrServerlessClient;
+
+ public Session createSession(CreateSessionRequest request) {
+ InteractiveSession session =
+ InteractiveSession.builder()
+ .sessionId(newSessionId())
+ .sessionStateStore(stateStore)
+ .serverlessClient(emrServerlessClient)
+ .build();
+ session.open(request);
+ return session;
+ }
+
+ public Optional getSession(SessionId sid) {
+ Optional model = stateStore.get(sid);
+ if (model.isPresent()) {
+ InteractiveSession session =
+ InteractiveSession.builder()
+ .sessionId(sid)
+ .sessionStateStore(stateStore)
+ .serverlessClient(emrServerlessClient)
+ .sessionModel(model.get())
+ .build();
+ return Optional.ofNullable(session);
+ }
+ return Optional.empty();
+ }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
new file mode 100644
index 0000000000..656f0ec8ce
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED;
+import static org.opensearch.sql.spark.execution.session.SessionType.INTERACTIVE;
+
+import java.io.IOException;
+import lombok.Builder;
+import lombok.Data;
+import lombok.SneakyThrows;
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.core.xcontent.XContentParserUtils;
+import org.opensearch.index.seqno.SequenceNumbers;
+
+/** Session data in flint.ql.sessions index. */
+@Data
+@Builder
+public class SessionModel implements ToXContentObject {
+ public static final String VERSION = "version";
+ public static final String TYPE = "type";
+ public static final String SESSION_TYPE = "sessionType";
+ public static final String SESSION_ID = "sessionId";
+ public static final String SESSION_STATE = "state";
+ public static final String DATASOURCE_NAME = "dataSourceName";
+ public static final String LAST_UPDATE_TIME = "lastUpdateTime";
+ public static final String APPLICATION_ID = "applicationId";
+ public static final String JOB_ID = "jobId";
+ public static final String ERROR = "error";
+ public static final String UNKNOWN = "unknown";
+ public static final String SESSION_DOC_TYPE = "session";
+
+ private final String version;
+ private final SessionType sessionType;
+ private final SessionId sessionId;
+ private final SessionState sessionState;
+ private final String applicationId;
+ private final String jobId;
+ private final String datasourceName;
+ private final String error;
+ private final long lastUpdateTime;
+
+ private final long seqNo;
+ private final long primaryTerm;
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder
+ .startObject()
+ .field(VERSION, version)
+ .field(TYPE, SESSION_DOC_TYPE)
+ .field(SESSION_TYPE, sessionType.getSessionType())
+ .field(SESSION_ID, sessionId.getSessionId())
+ .field(SESSION_STATE, sessionState.getSessionState())
+ .field(DATASOURCE_NAME, datasourceName)
+ .field(APPLICATION_ID, applicationId)
+ .field(JOB_ID, jobId)
+ .field(LAST_UPDATE_TIME, lastUpdateTime)
+ .field(ERROR, error)
+ .endObject();
+ return builder;
+ }
+
+ public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) {
+ return builder()
+ .version(copy.version)
+ .sessionType(copy.sessionType)
+ .sessionId(new SessionId(copy.sessionId.getSessionId()))
+ .sessionState(copy.sessionState)
+ .datasourceName(copy.datasourceName)
+ .seqNo(seqNo)
+ .primaryTerm(primaryTerm)
+ .build();
+ }
+
+ @SneakyThrows
+ public static SessionModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) {
+ SessionModelBuilder builder = new SessionModelBuilder();
+ XContentParserUtils.ensureExpectedToken(
+ XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+ while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) {
+ String fieldName = parser.currentName();
+ parser.nextToken();
+ switch (fieldName) {
+ case VERSION:
+ builder.version(parser.text());
+ break;
+ case SESSION_TYPE:
+ builder.sessionType(SessionType.fromString(parser.text()));
+ break;
+ case SESSION_ID:
+ builder.sessionId(new SessionId(parser.text()));
+ break;
+ case SESSION_STATE:
+ builder.sessionState(SessionState.fromString(parser.text()));
+ break;
+ case DATASOURCE_NAME:
+ builder.datasourceName(parser.text());
+ break;
+ case ERROR:
+ builder.error(parser.text());
+ break;
+ case APPLICATION_ID:
+ builder.applicationId(parser.text());
+ break;
+ case JOB_ID:
+ builder.jobId(parser.text());
+ break;
+ case LAST_UPDATE_TIME:
+ builder.lastUpdateTime(parser.longValue());
+ break;
+ case TYPE:
+ // do nothing.
+ break;
+ }
+ }
+ builder.seqNo(seqNo);
+ builder.primaryTerm(primaryTerm);
+ return builder.build();
+ }
+
+ public static SessionModel initInteractiveSession(
+ String applicationId, String jobId, SessionId sid, String datasourceName) {
+ return builder()
+ .version("1.0")
+ .sessionType(INTERACTIVE)
+ .sessionId(sid)
+ .sessionState(NOT_STARTED)
+ .datasourceName(datasourceName)
+ .applicationId(applicationId)
+ .jobId(jobId)
+ .error(UNKNOWN)
+ .lastUpdateTime(System.currentTimeMillis())
+ .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO)
+ .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM)
+ .build();
+ }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java
new file mode 100644
index 0000000000..509d5105e9
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.stream.Collectors;
+import lombok.Getter;
+
+@Getter
+public enum SessionState {
+ NOT_STARTED("not_started"),
+ RUNNING("running"),
+ DEAD("dead"),
+ FAIL("fail");
+
+ private final String sessionState;
+
+ SessionState(String sessionState) {
+ this.sessionState = sessionState;
+ }
+
+ private static Map STATES =
+ Arrays.stream(SessionState.values())
+ .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
+
+ public static SessionState fromString(String key) {
+ if (STATES.containsKey(key)) {
+ return STATES.get(key);
+ }
+ throw new IllegalArgumentException("Invalid session state: " + key);
+ }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java
new file mode 100644
index 0000000000..dd179a1dc5
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.stream.Collectors;
+import lombok.Getter;
+
+@Getter
+public enum SessionType {
+ INTERACTIVE("interactive");
+
+ private final String sessionType;
+
+ SessionType(String sessionType) {
+ this.sessionType = sessionType;
+ }
+
+ private static Map TYPES =
+ Arrays.stream(SessionType.values())
+ .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
+
+ public static SessionType fromString(String key) {
+ if (TYPES.containsKey(key)) {
+ return TYPES.get(key);
+ }
+ throw new IllegalArgumentException("Invalid session type: " + key);
+ }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java
new file mode 100644
index 0000000000..6ddce55360
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statestore;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Optional;
+import lombok.RequiredArgsConstructor;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.action.DocWriteResponse;
+import org.opensearch.action.get.GetRequest;
+import org.opensearch.action.get.GetResponse;
+import org.opensearch.action.index.IndexRequest;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.support.WriteRequest;
+import org.opensearch.client.Client;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.XContentFactory;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.ToXContent;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.session.SessionModel;
+
+@RequiredArgsConstructor
+public class SessionStateStore {
+ private static final Logger LOG = LogManager.getLogger();
+
+ private final String indexName;
+ private final Client client;
+
+ public SessionModel create(SessionModel session) {
+ try {
+ IndexRequest indexRequest =
+ new IndexRequest(indexName)
+ .id(session.getSessionId().getSessionId())
+ .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS))
+ .setIfSeqNo(session.getSeqNo())
+ .setIfPrimaryTerm(session.getPrimaryTerm())
+ .create(true)
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
+ IndexResponse indexResponse = client.index(indexRequest).actionGet();
+ if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) {
+ LOG.debug("Successfully created doc. id: {}", session.getSessionId());
+ return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm());
+ } else {
+ throw new RuntimeException(
+ String.format(
+ Locale.ROOT,
+ "Failed create doc. id: %s, error: %s",
+ session.getSessionId(),
+ indexResponse.getResult().getLowercase()));
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public Optional get(SessionId sid) {
+ try {
+ GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId());
+ GetResponse getResponse = client.get(getRequest).actionGet();
+ if (getResponse.isExists()) {
+ XContentParser parser =
+ XContentType.JSON
+ .xContent()
+ .createParser(
+ NamedXContentRegistry.EMPTY,
+ LoggingDeprecationHandler.INSTANCE,
+ getResponse.getSourceAsString());
+ parser.nextToken();
+ return Optional.of(
+ SessionModel.fromXContent(
+ parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm()));
+ } else {
+ return Optional.empty();
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
new file mode 100644
index 0000000000..53dc211ded
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java
@@ -0,0 +1,213 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession;
+import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED;
+
+import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
+import com.amazonaws.services.emrserverless.model.GetJobRunResult;
+import java.util.HashMap;
+import java.util.Optional;
+import lombok.RequiredArgsConstructor;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
+import org.opensearch.action.delete.DeleteRequest;
+import org.opensearch.sql.spark.client.EMRServerlessClient;
+import org.opensearch.sql.spark.client.StartJobRequest;
+import org.opensearch.sql.spark.execution.statestore.SessionStateStore;
+import org.opensearch.test.OpenSearchSingleNodeTestCase;
+
+/** mock-maker-inline does not work with OpenSearchTestCase. */
+public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase {
+
+ private static final String indexName = "mockindex";
+
+ private TestEMRServerlessClient emrsClient;
+ private StartJobRequest startJobRequest;
+ private SessionStateStore stateStore;
+
+ @Before
+ public void setup() {
+ emrsClient = new TestEMRServerlessClient();
+ startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, "");
+ stateStore = new SessionStateStore(indexName, client());
+ createIndex(indexName);
+ }
+
+ @After
+ public void clean() {
+ client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet();
+ }
+
+ @Test
+ public void openCloseSession() {
+ InteractiveSession session =
+ InteractiveSession.builder()
+ .sessionId(SessionId.newSessionId())
+ .sessionStateStore(stateStore)
+ .serverlessClient(emrsClient)
+ .build();
+
+ // open session
+ TestSession testSession = testSession(session, stateStore);
+ testSession
+ .open(new CreateSessionRequest(startJobRequest, "datasource"))
+ .assertSessionState(NOT_STARTED)
+ .assertAppId("appId")
+ .assertJobId("jobId");
+ emrsClient.startJobRunCalled(1);
+
+ // close session
+ testSession.close();
+ emrsClient.cancelJobRunCalled(1);
+ }
+
+ @Test
+ public void openSessionFailedConflict() {
+ SessionId sessionId = new SessionId("duplicate-session-id");
+ InteractiveSession session =
+ InteractiveSession.builder()
+ .sessionId(sessionId)
+ .sessionStateStore(stateStore)
+ .serverlessClient(emrsClient)
+ .build();
+ session.open(new CreateSessionRequest(startJobRequest, "datasource"));
+
+ InteractiveSession duplicateSession =
+ InteractiveSession.builder()
+ .sessionId(sessionId)
+ .sessionStateStore(stateStore)
+ .serverlessClient(emrsClient)
+ .build();
+ IllegalStateException exception =
+ assertThrows(
+ IllegalStateException.class,
+ () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource")));
+ assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage());
+ }
+
+ @Test
+ public void closeNotExistSession() {
+ SessionId sessionId = SessionId.newSessionId();
+ InteractiveSession session =
+ InteractiveSession.builder()
+ .sessionId(sessionId)
+ .sessionStateStore(stateStore)
+ .serverlessClient(emrsClient)
+ .build();
+ session.open(new CreateSessionRequest(startJobRequest, "datasource"));
+
+ client().delete(new DeleteRequest(indexName, sessionId.getSessionId()));
+
+ IllegalStateException exception = assertThrows(IllegalStateException.class, session::close);
+ assertEquals("session not exist. " + sessionId, exception.getMessage());
+ emrsClient.cancelJobRunCalled(0);
+ }
+
+ @Test
+ public void sessionManagerCreateSession() {
+ Session session =
+ new SessionManager(stateStore, emrsClient)
+ .createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+
+ TestSession testSession = testSession(session, stateStore);
+ testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId");
+ }
+
+ @Test
+ public void sessionManagerGetSession() {
+ SessionManager sessionManager = new SessionManager(stateStore, emrsClient);
+ Session session =
+ sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource"));
+
+ Optional managerSession = sessionManager.getSession(session.getSessionId());
+ assertTrue(managerSession.isPresent());
+ assertEquals(session.getSessionId(), managerSession.get().getSessionId());
+ }
+
+ @Test
+ public void sessionManagerGetSessionNotExist() {
+ SessionManager sessionManager = new SessionManager(stateStore, emrsClient);
+
+ Optional managerSession = sessionManager.getSession(new SessionId("no-exist"));
+ assertTrue(managerSession.isEmpty());
+ }
+
+ @RequiredArgsConstructor
+ static class TestSession {
+ private final Session session;
+ private final SessionStateStore stateStore;
+
+ public static TestSession testSession(Session session, SessionStateStore stateStore) {
+ return new TestSession(session, stateStore);
+ }
+
+ public TestSession assertSessionState(SessionState expected) {
+ assertEquals(expected, session.getSessionModel().getSessionState());
+
+ Optional sessionStoreState =
+ stateStore.get(session.getSessionModel().getSessionId());
+ assertTrue(sessionStoreState.isPresent());
+ assertEquals(expected, sessionStoreState.get().getSessionState());
+
+ return this;
+ }
+
+ public TestSession assertAppId(String expected) {
+ assertEquals(expected, session.getSessionModel().getApplicationId());
+ return this;
+ }
+
+ public TestSession assertJobId(String expected) {
+ assertEquals(expected, session.getSessionModel().getJobId());
+ return this;
+ }
+
+ public TestSession open(CreateSessionRequest req) {
+ session.open(req);
+ return this;
+ }
+
+ public TestSession close() {
+ session.close();
+ return this;
+ }
+ }
+
+ static class TestEMRServerlessClient implements EMRServerlessClient {
+
+ private int startJobRunCalled = 0;
+ private int cancelJobRunCalled = 0;
+
+ @Override
+ public String startJobRun(StartJobRequest startJobRequest) {
+ startJobRunCalled++;
+ return "jobId";
+ }
+
+ @Override
+ public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
+ return null;
+ }
+
+ @Override
+ public CancelJobRunResult cancelJobRun(String applicationId, String jobId) {
+ cancelJobRunCalled++;
+ return null;
+ }
+
+ public void startJobRunCalled(int expectedTimes) {
+ assertEquals(expectedTimes, startJobRunCalled);
+ }
+
+ public void cancelJobRunCalled(int expectedTimes) {
+ assertEquals(expectedTimes, cancelJobRunCalled);
+ }
+ }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
new file mode 100644
index 0000000000..d35105f787
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.junit.After;
+import org.junit.Before;
+import org.mockito.MockMakers;
+import org.mockito.MockSettings;
+import org.mockito.Mockito;
+import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
+import org.opensearch.sql.spark.execution.statestore.SessionStateStore;
+import org.opensearch.test.OpenSearchSingleNodeTestCase;
+
+class SessionManagerTest extends OpenSearchSingleNodeTestCase {
+ private static final String indexName = "mockindex";
+
+ // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock.
+ private static final MockSettings mockSettings =
+ Mockito.withSettings().mockMaker(MockMakers.SUBCLASS);
+
+ private SessionStateStore stateStore;
+
+ @Before
+ public void setup() {
+ stateStore = new SessionStateStore(indexName, client());
+ createIndex(indexName);
+ }
+
+ @After
+ public void clean() {
+ client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet();
+ }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java
new file mode 100644
index 0000000000..a987c80d59
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import org.junit.jupiter.api.Test;
+
+class SessionStateTest {
+ @Test
+ public void invalidSessionType() {
+ IllegalArgumentException exception =
+ assertThrows(IllegalArgumentException.class, () -> SessionState.fromString("invalid"));
+ assertEquals("Invalid session state: invalid", exception.getMessage());
+ }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java
new file mode 100644
index 0000000000..a2ab43e709
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.session;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import org.junit.jupiter.api.Test;
+
+class SessionTypeTest {
+ @Test
+ public void invalidSessionType() {
+ IllegalArgumentException exception =
+ assertThrows(IllegalArgumentException.class, () -> SessionType.fromString("invalid"));
+ assertEquals("Invalid session type: invalid", exception.getMessage());
+ }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java
new file mode 100644
index 0000000000..9c779555d7
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.execution.statestore;
+
+import static org.junit.Assert.assertThrows;
+import static org.mockito.Answers.RETURNS_DEEP_STUBS;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.when;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.action.DocWriteResponse;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.client.Client;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.session.SessionModel;
+
+@ExtendWith(MockitoExtension.class)
+class SessionStateStoreTest {
+ @Mock(answer = RETURNS_DEEP_STUBS)
+ private Client client;
+
+ @Mock private IndexResponse indexResponse;
+
+ @Test
+ public void createWithException() {
+ when(client.index(any()).actionGet()).thenReturn(indexResponse);
+ doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult();
+ SessionModel sessionModel =
+ SessionModel.initInteractiveSession(
+ "appId", "jobId", SessionId.newSessionId(), "datasource");
+ SessionStateStore sessionStateStore = new SessionStateStore("indexName", client);
+
+ assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel));
+ }
+}