diff --git a/spark/build.gradle b/spark/build.gradle index c06b5b6ecf..c2c925ecaf 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -52,15 +52,38 @@ dependencies { api group: 'com.amazonaws', name: 'aws-java-sdk-emrserverless', version: '1.12.545' implementation group: 'commons-io', name: 'commons-io', version: '2.8.0' - testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') + testImplementation(platform("org.junit:junit-bom:5.6.2")) + + testImplementation('org.junit.jupiter:junit-jupiter') testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.2.0' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '5.2.0' - testImplementation 'junit:junit:4.13.1' - testImplementation "org.opensearch.test:framework:${opensearch_version}" + + testCompileOnly('junit:junit:4.13.1') { + exclude group: 'org.hamcrest', module: 'hamcrest-core' + } + testRuntimeOnly("org.junit.vintage:junit-vintage-engine") { + exclude group: 'org.hamcrest', module: 'hamcrest-core' + } + testRuntimeOnly("org.junit.platform:junit-platform-launcher") { + because 'allows tests to run from IDEs that bundle older version of launcher' + } + testImplementation("org.opensearch.test:framework:${opensearch_version}") } test { - useJUnitPlatform() + useJUnitPlatform { + includeEngines("junit-jupiter") + } + testLogging { + events "failed" + exceptionFormat "full" + } +} +task junit4(type: Test) { + useJUnitPlatform { + includeEngines("junit-vintage") + } + systemProperty 'tests.security.manager', 'false' testLogging { events "failed" exceptionFormat "full" @@ -68,6 +91,8 @@ test { } jacocoTestReport { + dependsOn test, junit4 + executionData test, junit4 reports { html.enabled true xml.enabled true @@ -78,9 +103,10 @@ jacocoTestReport { })) } } -test.finalizedBy(project.tasks.jacocoTestReport) jacocoTestCoverageVerification { + dependsOn test, junit4 + executionData test, junit4 violationRules { rule { element = 'CLASS' @@ -92,6 +118,9 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.asyncquery.exceptions.*', 'org.opensearch.sql.spark.dispatcher.model.*', 'org.opensearch.sql.spark.flint.FlintIndexType', + // ignore because XContext IOException + 'org.opensearch.sql.spark.execution.statestore.SessionStateStore', + 'org.opensearch.sql.spark.execution.session.SessionModel' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java new file mode 100644 index 0000000000..17e3346248 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/CreateSessionRequest.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import lombok.Data; +import org.opensearch.sql.spark.client.StartJobRequest; + +@Data +public class CreateSessionRequest { + private final StartJobRequest startJobRequest; + private final String datasourceName; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java new file mode 100644 index 0000000000..620e46b9be --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; + +import java.util.Optional; +import lombok.Builder; +import lombok.Getter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; + +/** + * Interactive session. + * + *

ENTRY_STATE: not_started + */ +@Getter +@Builder +public class InteractiveSession implements Session { + private static final Logger LOG = LogManager.getLogger(); + + private final SessionId sessionId; + private final SessionStateStore sessionStateStore; + private final EMRServerlessClient serverlessClient; + + private SessionModel sessionModel; + + @Override + public void open(CreateSessionRequest createSessionRequest) { + try { + String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest()); + String applicationId = createSessionRequest.getStartJobRequest().getApplicationId(); + + sessionModel = + initInteractiveSession( + applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); + sessionStateStore.create(sessionModel); + } catch (VersionConflictEngineException e) { + String errorMsg = "session already exist. " + sessionId; + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + @Override + public void close() { + Optional model = sessionStateStore.get(sessionModel.getSessionId()); + if (model.isEmpty()) { + throw new IllegalStateException("session not exist. " + sessionModel.getSessionId()); + } else { + serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId()); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java new file mode 100644 index 0000000000..ec9775e60a --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +/** Session define the statement execution context. Each session is binding to one Spark Job. */ +public interface Session { + /** open session. */ + void open(CreateSessionRequest createSessionRequest); + + /** close session. */ + void close(); + + SessionModel getSessionModel(); + + SessionId getSessionId(); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java new file mode 100644 index 0000000000..a2847cde18 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionId.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import lombok.Data; +import org.apache.commons.lang3.RandomStringUtils; + +@Data +public class SessionId { + private final String sessionId; + + public static SessionId newSessionId() { + return new SessionId(RandomStringUtils.random(10, true, true)); + } + + @Override + public String toString() { + return "sessionId=" + sessionId; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java new file mode 100644 index 0000000000..3d0916bac8 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId; + +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; + +/** + * Singleton Class + * + *

todo. add Session cache and Session sweeper. + */ +@RequiredArgsConstructor +public class SessionManager { + private final SessionStateStore stateStore; + private final EMRServerlessClient emrServerlessClient; + + public Session createSession(CreateSessionRequest request) { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(newSessionId()) + .sessionStateStore(stateStore) + .serverlessClient(emrServerlessClient) + .build(); + session.open(request); + return session; + } + + public Optional getSession(SessionId sid) { + Optional model = stateStore.get(sid); + if (model.isPresent()) { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sid) + .sessionStateStore(stateStore) + .serverlessClient(emrServerlessClient) + .sessionModel(model.get()) + .build(); + return Optional.ofNullable(session); + } + return Optional.empty(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java new file mode 100644 index 0000000000..656f0ec8ce --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.session.SessionType.INTERACTIVE; + +import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.index.seqno.SequenceNumbers; + +/** Session data in flint.ql.sessions index. */ +@Data +@Builder +public class SessionModel implements ToXContentObject { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String SESSION_TYPE = "sessionType"; + public static final String SESSION_ID = "sessionId"; + public static final String SESSION_STATE = "state"; + public static final String DATASOURCE_NAME = "dataSourceName"; + public static final String LAST_UPDATE_TIME = "lastUpdateTime"; + public static final String APPLICATION_ID = "applicationId"; + public static final String JOB_ID = "jobId"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String SESSION_DOC_TYPE = "session"; + + private final String version; + private final SessionType sessionType; + private final SessionId sessionId; + private final SessionState sessionState; + private final String applicationId; + private final String jobId; + private final String datasourceName; + private final String error; + private final long lastUpdateTime; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, SESSION_DOC_TYPE) + .field(SESSION_TYPE, sessionType.getSessionType()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(SESSION_STATE, sessionState.getSessionState()) + .field(DATASOURCE_NAME, datasourceName) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) + .field(LAST_UPDATE_TIME, lastUpdateTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(copy.sessionState) + .datasourceName(copy.datasourceName) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static SessionModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + SessionModelBuilder builder = new SessionModelBuilder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case SESSION_TYPE: + builder.sessionType(SessionType.fromString(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case SESSION_STATE: + builder.sessionState(SessionState.fromString(parser.text())); + break; + case DATASOURCE_NAME: + builder.datasourceName(parser.text()); + break; + case ERROR: + builder.error(parser.text()); + break; + case APPLICATION_ID: + builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; + case LAST_UPDATE_TIME: + builder.lastUpdateTime(parser.longValue()); + break; + case TYPE: + // do nothing. + break; + } + } + builder.seqNo(seqNo); + builder.primaryTerm(primaryTerm); + return builder.build(); + } + + public static SessionModel initInteractiveSession( + String applicationId, String jobId, SessionId sid, String datasourceName) { + return builder() + .version("1.0") + .sessionType(INTERACTIVE) + .sessionId(sid) + .sessionState(NOT_STARTED) + .datasourceName(datasourceName) + .applicationId(applicationId) + .jobId(jobId) + .error(UNKNOWN) + .lastUpdateTime(System.currentTimeMillis()) + .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) + .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) + .build(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java new file mode 100644 index 0000000000..509d5105e9 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +@Getter +public enum SessionState { + NOT_STARTED("not_started"), + RUNNING("running"), + DEAD("dead"), + FAIL("fail"); + + private final String sessionState; + + SessionState(String sessionState) { + this.sessionState = sessionState; + } + + private static Map STATES = + Arrays.stream(SessionState.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static SessionState fromString(String key) { + if (STATES.containsKey(key)) { + return STATES.get(key); + } + throw new IllegalArgumentException("Invalid session state: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java new file mode 100644 index 0000000000..dd179a1dc5 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionType.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +@Getter +public enum SessionType { + INTERACTIVE("interactive"); + + private final String sessionType; + + SessionType(String sessionType) { + this.sessionType = sessionType; + } + + private static Map TYPES = + Arrays.stream(SessionType.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static SessionType fromString(String key) { + if (TYPES.containsKey(key)) { + return TYPES.get(key); + } + throw new IllegalArgumentException("Invalid session type: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java new file mode 100644 index 0000000000..6ddce55360 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@RequiredArgsConstructor +public class SessionStateStore { + private static final Logger LOG = LogManager.getLogger(); + + private final String indexName; + private final Client client; + + public SessionModel create(SessionModel session) { + try { + IndexRequest indexRequest = + new IndexRequest(indexName) + .id(session.getSessionId().getSessionId()) + .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .setIfSeqNo(session.getSeqNo()) + .setIfPrimaryTerm(session.getPrimaryTerm()) + .create(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", session.getSessionId()); + return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. id: %s, error: %s", + session.getSessionId(), + indexResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public Optional get(SessionId sid) { + try { + GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId()); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + SessionModel.fromXContent( + parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java new file mode 100644 index 0000000000..53dc211ded --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -0,0 +1,213 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; +import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; + +import com.amazonaws.services.emrserverless.model.CancelJobRunResult; +import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import java.util.HashMap; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +/** mock-maker-inline does not work with OpenSearchTestCase. */ +public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { + + private static final String indexName = "mockindex"; + + private TestEMRServerlessClient emrsClient; + private StartJobRequest startJobRequest; + private SessionStateStore stateStore; + + @Before + public void setup() { + emrsClient = new TestEMRServerlessClient(); + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new SessionStateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + + @Test + public void openCloseSession() { + InteractiveSession session = + InteractiveSession.builder() + .sessionId(SessionId.newSessionId()) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + + // open session + TestSession testSession = testSession(session, stateStore); + testSession + .open(new CreateSessionRequest(startJobRequest, "datasource")) + .assertSessionState(NOT_STARTED) + .assertAppId("appId") + .assertJobId("jobId"); + emrsClient.startJobRunCalled(1); + + // close session + testSession.close(); + emrsClient.cancelJobRunCalled(1); + } + + @Test + public void openSessionFailedConflict() { + SessionId sessionId = new SessionId("duplicate-session-id"); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); + + InteractiveSession duplicateSession = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> duplicateSession.open(new CreateSessionRequest(startJobRequest, "datasource"))); + assertEquals("session already exist. sessionId=duplicate-session-id", exception.getMessage()); + } + + @Test + public void closeNotExistSession() { + SessionId sessionId = SessionId.newSessionId(); + InteractiveSession session = + InteractiveSession.builder() + .sessionId(sessionId) + .sessionStateStore(stateStore) + .serverlessClient(emrsClient) + .build(); + session.open(new CreateSessionRequest(startJobRequest, "datasource")); + + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())); + + IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); + assertEquals("session not exist. " + sessionId, exception.getMessage()); + emrsClient.cancelJobRunCalled(0); + } + + @Test + public void sessionManagerCreateSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + TestSession testSession = testSession(session, stateStore); + testSession.assertSessionState(NOT_STARTED).assertAppId("appId").assertJobId("jobId"); + } + + @Test + public void sessionManagerGetSession() { + SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + Session session = + sessionManager.createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + Optional managerSession = sessionManager.getSession(session.getSessionId()); + assertTrue(managerSession.isPresent()); + assertEquals(session.getSessionId(), managerSession.get().getSessionId()); + } + + @Test + public void sessionManagerGetSessionNotExist() { + SessionManager sessionManager = new SessionManager(stateStore, emrsClient); + + Optional managerSession = sessionManager.getSession(new SessionId("no-exist")); + assertTrue(managerSession.isEmpty()); + } + + @RequiredArgsConstructor + static class TestSession { + private final Session session; + private final SessionStateStore stateStore; + + public static TestSession testSession(Session session, SessionStateStore stateStore) { + return new TestSession(session, stateStore); + } + + public TestSession assertSessionState(SessionState expected) { + assertEquals(expected, session.getSessionModel().getSessionState()); + + Optional sessionStoreState = + stateStore.get(session.getSessionModel().getSessionId()); + assertTrue(sessionStoreState.isPresent()); + assertEquals(expected, sessionStoreState.get().getSessionState()); + + return this; + } + + public TestSession assertAppId(String expected) { + assertEquals(expected, session.getSessionModel().getApplicationId()); + return this; + } + + public TestSession assertJobId(String expected) { + assertEquals(expected, session.getSessionModel().getJobId()); + return this; + } + + public TestSession open(CreateSessionRequest req) { + session.open(req); + return this; + } + + public TestSession close() { + session.close(); + return this; + } + } + + static class TestEMRServerlessClient implements EMRServerlessClient { + + private int startJobRunCalled = 0; + private int cancelJobRunCalled = 0; + + @Override + public String startJobRun(StartJobRequest startJobRequest) { + startJobRunCalled++; + return "jobId"; + } + + @Override + public GetJobRunResult getJobRunResult(String applicationId, String jobId) { + return null; + } + + @Override + public CancelJobRunResult cancelJobRun(String applicationId, String jobId) { + cancelJobRunCalled++; + return null; + } + + public void startJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, startJobRunCalled); + } + + public void cancelJobRunCalled(int expectedTimes) { + assertEquals(expectedTimes, cancelJobRunCalled); + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java new file mode 100644 index 0000000000..d35105f787 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.After; +import org.junit.Before; +import org.mockito.MockMakers; +import org.mockito.MockSettings; +import org.mockito.Mockito; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +class SessionManagerTest extends OpenSearchSingleNodeTestCase { + private static final String indexName = "mockindex"; + + // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock. + private static final MockSettings mockSettings = + Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); + + private SessionStateStore stateStore; + + @Before + public void setup() { + stateStore = new SessionStateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java new file mode 100644 index 0000000000..a987c80d59 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionStateTest { + @Test + public void invalidSessionType() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionState.fromString("invalid")); + assertEquals("Invalid session state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java new file mode 100644 index 0000000000..a2ab43e709 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionTypeTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.session; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Test; + +class SessionTypeTest { + @Test + public void invalidSessionType() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> SessionType.fromString("invalid")); + assertEquals("Invalid session type: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java new file mode 100644 index 0000000000..9c779555d7 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.client.Client; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@ExtendWith(MockitoExtension.class) +class SessionStateStoreTest { + @Mock(answer = RETURNS_DEEP_STUBS) + private Client client; + + @Mock private IndexResponse indexResponse; + + @Test + public void createWithException() { + when(client.index(any()).actionGet()).thenReturn(indexResponse); + doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult(); + SessionModel sessionModel = + SessionModel.initInteractiveSession( + "appId", "jobId", SessionId.newSessionId(), "datasource"); + SessionStateStore sessionStateStore = new SessionStateStore("indexName", client); + + assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel)); + } +}