Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new state persistence for state reads #14126

Merged
merged 21 commits into from
Jun 28, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.State;
import io.airbyte.config.StateType;
import io.airbyte.config.StateWrapper;
import io.airbyte.protocol.models.AirbyteStateMessage;
Expand Down Expand Up @@ -53,6 +54,25 @@ public static Optional<StateWrapper> getTypedState(final JsonNode state) {
}
}

/**
* Converts a StateWrapper to a State
*
* LegacyStates are directly serialized into the state. GlobalStates and StreamStates are serialized
* as a list of AirbyteStateMessage in the state attribute.
*
* @param stateWrapper the StateWrapper to convert
* @return the Converted State
*/
@SuppressWarnings("UnnecessaryDefault")
public static State getState(final StateWrapper stateWrapper) {
return switch (stateWrapper.getStateType()) {
case LEGACY -> new State().withState(stateWrapper.getLegacyState());
case STREAM -> new State().withState(Jsons.jsonNode(stateWrapper.getStateMessages()));
case GLOBAL -> new State().withState(Jsons.jsonNode(List.of(stateWrapper.getGlobal())));
default -> throw new RuntimeException("Unexpected StateType " + stateWrapper.getStateType());
};
}

private static StateWrapper getLegacyStateWrapper(final JsonNode state) {
return new StateWrapper()
.withStateType(StateType.LEGACY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

import com.google.common.collect.Lists;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.State;
import io.airbyte.config.StateType;
import io.airbyte.config.StateWrapper;
import io.airbyte.protocol.models.AirbyteGlobalState;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType;
import io.airbyte.protocol.models.AirbyteStreamState;
import io.airbyte.protocol.models.StreamDescriptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import org.assertj.core.api.Assertions;
Expand Down Expand Up @@ -120,4 +123,59 @@ public void testEmptyStateList() {
.isInstanceOf(IllegalStateException.class);
}

@Test
public void testLegacyStateConversion() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT/OPT:
"""
"""
let you create string and multiline string without escaping " and using +. It helps making the json some readable.

final StateWrapper stateWrapper = new StateWrapper()
.withStateType(StateType.LEGACY)
.withLegacyState(Jsons.deserialize("{\"json\": \"blob\"}"));
final State expectedState = new State().withState(Jsons.deserialize("{\"json\": \"blob\"}"));

final State convertedState = StateMessageHelper.getState(stateWrapper);
Assertions.assertThat(convertedState).isEqualTo(expectedState);
}

@Test
public void testGlobalStateConversion() {
final StateWrapper stateWrapper = new StateWrapper()
.withStateType(StateType.GLOBAL)
.withGlobal(
new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL).withGlobal(
new AirbyteGlobalState()
.withSharedState(Jsons.deserialize("\"shared\""))
.withStreamStates(Collections.singletonList(
new AirbyteStreamState()
.withStreamDescriptor(new StreamDescriptor().withNamespace("ns").withName("name"))
.withStreamState(Jsons.deserialize("\"stream state\""))))));
final State expectedState = new State().withState(Jsons.deserialize(
"[{\"type\":\"GLOBAL\",\"global\":{\"shared_state\":\"shared\",\"stream_states\":[" +
"{\"stream_descriptor\":{\"name\":\"name\",\"namespace\":\"ns\"},\"stream_state\":\"stream state\"}" +
"]}}]"));

final State convertedState = StateMessageHelper.getState(stateWrapper);
Assertions.assertThat(convertedState).isEqualTo(expectedState);
}

@Test
public void testStreamStateConversion() {
final StateWrapper stateWrapper = new StateWrapper()
.withStateType(StateType.STREAM)
.withStateMessages(Arrays.asList(
new AirbyteStateMessage().withType(AirbyteStateType.STREAM).withStream(
new AirbyteStreamState()
.withStreamDescriptor(new StreamDescriptor().withNamespace("ns1").withName("name1"))
.withStreamState(Jsons.deserialize("\"state1\""))),
new AirbyteStateMessage().withType(AirbyteStateType.STREAM).withStream(
new AirbyteStreamState()
.withStreamDescriptor(new StreamDescriptor().withNamespace("ns2").withName("name2"))
.withStreamState(Jsons.deserialize("\"state2\"")))));
final State expectedState = new State().withState(Jsons.deserialize(
"[" +
"{\"type\":\"STREAM\",\"stream\":{\"stream_descriptor\":{\"name\":\"name1\",\"namespace\":\"ns1\"},\"stream_state\":\"state1\"}}," +
"{\"type\":\"STREAM\",\"stream\":{\"stream_descriptor\":{\"name\":\"name2\",\"namespace\":\"ns2\"},\"stream_state\":\"state2\"}}" +
"]"));

final State convertedState = StateMessageHelper.getState(stateWrapper);
Assertions.assertThat(convertedState).isEqualTo(expectedState);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -757,20 +757,6 @@ public List<DestinationOAuthParameter> listDestinationOAuthParam() throws JsonVa
return persistence.listConfigs(ConfigSchema.DESTINATION_OAUTH_PARAM, DestinationOAuthParameter.class);
}

public Optional<State> getConnectionState(final UUID connectionId) throws IOException {
try {
final StandardSyncState connectionState = persistence.getConfig(
ConfigSchema.STANDARD_SYNC_STATE,
connectionId.toString(),
StandardSyncState.class);
return Optional.of(connectionState.getState());
} catch (final ConfigNotFoundException e) {
return Optional.empty();
} catch (final JsonValidationException e) {
throw new IllegalStateException(e);
}
}

public void updateConnectionState(final UUID connectionId, final State state) throws IOException {
LOGGER.info("Updating connection {} state: {}", connectionId, state);
final StandardSyncState connectionState = new StandardSyncState().withConnectionId(connectionId).withState(state);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.jooq.Result;
import org.junit.jupiter.api.AfterEach;
Expand Down Expand Up @@ -113,22 +112,6 @@ void testWorkspaceByConnectionId(final boolean isTombstone) throws ConfigNotFoun
verify(configRepository).getStandardWorkspace(WORKSPACE_ID, isTombstone);
}

@Test
void testGetConnectionState() throws Exception {
final UUID connectionId = UUID.randomUUID();
final State state = new State().withState(Jsons.deserialize("{ \"cursor\": 1000 }"));
final StandardSyncState connectionState = new StandardSyncState().withConnectionId(connectionId).withState(state);

when(configPersistence.getConfig(ConfigSchema.STANDARD_SYNC_STATE, connectionId.toString(), StandardSyncState.class))
.thenThrow(new ConfigNotFoundException(ConfigSchema.STANDARD_SYNC_STATE, connectionId));
assertEquals(Optional.empty(), configRepository.getConnectionState(connectionId));

reset(configPersistence);
when(configPersistence.getConfig(ConfigSchema.STANDARD_SYNC_STATE, connectionId.toString(), StandardSyncState.class))
.thenReturn(connectionState);
assertEquals(Optional.of(state), configRepository.getConnectionState(connectionId));
}

@Test
void testUpdateConnectionState() throws Exception {
final UUID connectionId = UUID.randomUUID();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import io.airbyte.config.StandardSyncOperation;
import io.airbyte.config.State;
import io.airbyte.config.StreamDescriptor;
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.helpers.StateMessageHelper;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.DestinationSyncMode;
import io.airbyte.protocol.models.SyncMode;
Expand All @@ -31,15 +32,15 @@
public class DefaultJobCreator implements JobCreator {

private final JobPersistence jobPersistence;
private final ConfigRepository configRepository;
private final ResourceRequirements workerResourceRequirements;
private final StatePersistence statePersistence;

public DefaultJobCreator(final JobPersistence jobPersistence,
final ConfigRepository configRepository,
final ResourceRequirements workerResourceRequirements) {
final ResourceRequirements workerResourceRequirements,
final StatePersistence statePersistence) {
this.jobPersistence = jobPersistence;
this.configRepository = configRepository;
this.workerResourceRequirements = workerResourceRequirements;
this.statePersistence = statePersistence;
}

@Override
Expand Down Expand Up @@ -126,10 +127,8 @@ public Optional<Long> createResetConnectionJob(final DestinationConnection desti
return jobPersistence.enqueueJob(standardSync.getConnectionId().toString(), jobConfig);
}

// TODO (https://github.com/airbytehq/airbyte/issues/13620): update this method implementation
// to fetch and serialize the new per-stream state format into a State object
private Optional<State> getCurrentConnectionState(final UUID connectionId) throws IOException {
return configRepository.getConnectionState(connectionId);
return statePersistence.getCurrentState(connectionId).map(StateMessageHelper::getState);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
import io.airbyte.config.StandardSyncOperation.OperatorType;
import io.airbyte.config.State;
import io.airbyte.config.StreamDescriptor;
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.helpers.StateMessageHelper;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.protocol.models.CatalogHelpers;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.ConfiguredAirbyteStream;
Expand Down Expand Up @@ -65,7 +66,7 @@ public class DefaultJobCreatorTest {
private static final StreamDescriptor STREAM_DESCRIPTOR2 = new StreamDescriptor().withName("stream 2").withNamespace("namespace 2");

private JobPersistence jobPersistence;
private ConfigRepository configRepository;
private StatePersistence statePersistence;
private JobCreator jobCreator;
private ResourceRequirements workerResourceRequirements;

Expand Down Expand Up @@ -126,13 +127,13 @@ public class DefaultJobCreatorTest {
@BeforeEach
void setup() {
jobPersistence = mock(JobPersistence.class);
configRepository = mock(ConfigRepository.class);
statePersistence = mock(StatePersistence.class);
workerResourceRequirements = new ResourceRequirements()
.withCpuLimit("0.2")
.withCpuRequest("0.2")
.withMemoryLimit("200Mi")
.withMemoryRequest("200Mi");
jobCreator = new DefaultJobCreator(jobPersistence, configRepository, workerResourceRequirements);
jobCreator = new DefaultJobCreator(jobPersistence, workerResourceRequirements, statePersistence);
}

@Test
Expand Down Expand Up @@ -336,7 +337,8 @@ void testCreateResetConnectionJob() throws IOException {
});

final State connectionState = new State().withState(Jsons.jsonNode(Map.of("key", "val")));
when(configRepository.getConnectionState(STANDARD_SYNC.getConnectionId())).thenReturn(Optional.of(connectionState));
when(statePersistence.getCurrentState(STANDARD_SYNC.getConnectionId()))
.thenReturn(StateMessageHelper.getTypedState(connectionState.getState()));

final JobResetConnectionConfig jobResetConnectionConfig = new JobResetConnectionConfig()
.withNamespaceDefinition(STANDARD_SYNC.getNamespaceDefinition())
Expand Down Expand Up @@ -379,7 +381,8 @@ void testCreateResetConnectionJobEnsureNoQueuing() throws IOException {
});

final State connectionState = new State().withState(Jsons.jsonNode(Map.of("key", "val")));
when(configRepository.getConnectionState(STANDARD_SYNC.getConnectionId())).thenReturn(Optional.of(connectionState));
when(statePersistence.getCurrentState(STANDARD_SYNC.getConnectionId()))
.thenReturn(StateMessageHelper.getTypedState(connectionState.getState()));

final JobResetConnectionConfig jobResetConnectionConfig = new JobResetConnectionConfig()
.withNamespaceDefinition(STANDARD_SYNC.getNamespaceDefinition())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.persistence.SecretsRepositoryReader;
import io.airbyte.config.persistence.SecretsRepositoryWriter;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.db.Database;
import io.airbyte.scheduler.client.EventRunner;
import io.airbyte.scheduler.client.SynchronousSchedulerClient;
Expand Down Expand Up @@ -46,6 +47,7 @@ public class ConfigurationApiFactory implements Factory<ConfigurationApi> {
private static EventRunner eventRunner;
private static Flyway configsFlyway;
private static Flyway jobsFlyway;
private static StatePersistence statePersistence;

public static void setValues(
final ConfigRepository configRepository,
Expand All @@ -66,7 +68,8 @@ public static void setValues(
final HttpClient httpClient,
final EventRunner eventRunner,
final Flyway configsFlyway,
final Flyway jobsFlyway) {
final Flyway jobsFlyway,
final StatePersistence statePersistence) {
ConfigurationApiFactory.configRepository = configRepository;
ConfigurationApiFactory.jobPersistence = jobPersistence;
ConfigurationApiFactory.seed = seed;
Expand All @@ -86,6 +89,7 @@ public static void setValues(
ConfigurationApiFactory.eventRunner = eventRunner;
ConfigurationApiFactory.configsFlyway = configsFlyway;
ConfigurationApiFactory.jobsFlyway = jobsFlyway;
ConfigurationApiFactory.statePersistence = statePersistence;
}

@Override
Expand All @@ -110,7 +114,8 @@ public ConfigurationApi provide() {
ConfigurationApiFactory.httpClient,
ConfigurationApiFactory.eventRunner,
ConfigurationApiFactory.configsFlyway,
ConfigurationApiFactory.jobsFlyway);
ConfigurationApiFactory.jobsFlyway,
ConfigurationApiFactory.statePersistence);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.airbyte.config.persistence.DatabaseConfigPersistence;
import io.airbyte.config.persistence.SecretsRepositoryReader;
import io.airbyte.config.persistence.SecretsRepositoryWriter;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.config.persistence.split_secrets.JsonSecretsProcessor;
import io.airbyte.config.persistence.split_secrets.SecretPersistence;
import io.airbyte.config.persistence.split_secrets.SecretsHydrator;
Expand Down Expand Up @@ -184,6 +185,7 @@ public static ServerRunnable getServer(final ServerFactory apiFactory,
LOGGER.info("Creating jobs persistence...");
final Database jobsDatabase = new Database(jobsDslContext);
final JobPersistence jobPersistence = new DefaultJobPersistence(jobsDatabase);
final StatePersistence statePersistence = new StatePersistence(configsDatabase);

TrackingClientSingleton.initialize(
configs.getTrackingStrategy(),
Expand Down Expand Up @@ -234,7 +236,8 @@ public static ServerRunnable getServer(final ServerFactory apiFactory,
httpClient,
eventRunner,
configsFlyway,
jobsFlyway);
jobsFlyway,
statePersistence);
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.persistence.SecretsRepositoryReader;
import io.airbyte.config.persistence.SecretsRepositoryWriter;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.db.Database;
import io.airbyte.scheduler.client.EventRunner;
import io.airbyte.scheduler.client.SynchronousSchedulerClient;
Expand Down Expand Up @@ -43,7 +44,8 @@ ServerRunnable create(SynchronousSchedulerClient cachingSchedulerClient,
HttpClient httpClient,
EventRunner eventRunner,
Flyway configsFlyway,
Flyway jobsFlyway);
Flyway jobsFlyway,
StatePersistence statePersistence);

class Api implements ServerFactory {

Expand All @@ -64,7 +66,8 @@ public ServerRunnable create(final SynchronousSchedulerClient synchronousSchedul
final HttpClient httpClient,
final EventRunner eventRunner,
final Flyway configsFlyway,
final Flyway jobsFlyway) {
final Flyway jobsFlyway,
final StatePersistence statePersistence) {
// set static values for factory
ConfigurationApiFactory.setValues(
configRepository,
Expand All @@ -85,7 +88,8 @@ public ServerRunnable create(final SynchronousSchedulerClient synchronousSchedul
httpClient,
eventRunner,
configsFlyway,
jobsFlyway);
jobsFlyway,
statePersistence);

// server configurations
final Set<Class<?>> componentClasses = Set.of(ConfigurationApi.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.persistence.SecretsRepositoryReader;
import io.airbyte.config.persistence.SecretsRepositoryWriter;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.db.Database;
import io.airbyte.scheduler.client.EventRunner;
import io.airbyte.scheduler.client.SynchronousSchedulerClient;
Expand Down Expand Up @@ -178,7 +179,8 @@ public ConfigurationApi(final ConfigRepository configRepository,
final HttpClient httpClient,
final EventRunner eventRunner,
final Flyway configsFlyway,
final Flyway jobsFlyway) {
final Flyway jobsFlyway,
final StatePersistence statePersistence) {
this.workerEnvironment = workerEnvironment;
this.logConfigs = logConfigs;
this.workspaceRoot = workspaceRoot;
Expand All @@ -195,7 +197,8 @@ public ConfigurationApi(final ConfigRepository configRepository,
jobPersistence,
workerEnvironment,
logConfigs,
eventRunner);
eventRunner,
statePersistence);

connectionsHandler = new ConnectionsHandler(
configRepository,
Expand Down
Loading