Skip to content

Commit

Permalink
Use new state persistence for state reads (#14126)
Browse files Browse the repository at this point in the history
* Inject StatePersistence into DefaultJobCreator
* Read the state from StatePersistence instead of ConfigRepository
* Add a conversion helper to convert StateWrapper to State
* Remove unused ConfigRepository.getConnectionState
  • Loading branch information
gosusnp authored Jun 28, 2022
1 parent 09798a1 commit 34ed33b
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.Iterables;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.State;
import io.airbyte.config.StateType;
import io.airbyte.config.StateWrapper;
import io.airbyte.protocol.models.AirbyteStateMessage;
Expand Down Expand Up @@ -74,6 +75,25 @@ public static Optional<StateWrapper> getTypedState(final JsonNode state, final b
}
}

/**
 * Builds the {@link State} equivalent of a {@link StateWrapper}.
 *
 * A LEGACY wrapper contributes its legacy state blob as-is. A STREAM wrapper contributes its list
 * of state messages converted to JSON, and a GLOBAL wrapper contributes its single global state
 * message converted to JSON as a one-element list.
 *
 * @param stateWrapper the StateWrapper to convert
 * @return a State carrying the converted payload
 * @throws RuntimeException if the wrapper's state type is not LEGACY, STREAM or GLOBAL
 */
@SuppressWarnings("UnnecessaryDefault")
public static State getState(final StateWrapper stateWrapper) {
  switch (stateWrapper.getStateType()) {
    case LEGACY:
      // Legacy payloads are already in the shape State expects.
      return new State().withState(stateWrapper.getLegacyState());
    case STREAM:
      return new State().withState(Jsons.jsonNode(stateWrapper.getStateMessages()));
    case GLOBAL:
      // The single global message is wrapped in a list before conversion.
      return new State().withState(Jsons.jsonNode(List.of(stateWrapper.getGlobal())));
    default:
      throw new RuntimeException("Unexpected StateType " + stateWrapper.getStateType());
  }
}

private static StateWrapper provideGlobalState(final AirbyteStateMessage stateMessages, final boolean useStreamCapableState) {
if (useStreamCapableState) {
return new StateWrapper()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.State;
import io.airbyte.config.StateType;
import io.airbyte.config.StateWrapper;
import io.airbyte.protocol.models.AirbyteGlobalState;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType;
import io.airbyte.protocol.models.AirbyteStreamState;
import io.airbyte.protocol.models.StreamDescriptor;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -181,4 +184,69 @@ public void testDuplicatedGlobalState() {
.isInstanceOf(IllegalStateException.class);
}

/**
 * A LEGACY wrapper must convert to a State whose payload equals the legacy blob.
 */
@Test
public void testLegacyStateConversion() {
  final String legacyBlob = "{\"json\": \"blob\"}";

  final StateWrapper legacyWrapper = new StateWrapper()
      .withStateType(StateType.LEGACY)
      .withLegacyState(Jsons.deserialize(legacyBlob));
  // Deserialized separately so the expectation does not share a node instance with the input.
  final State expected = new State().withState(Jsons.deserialize(legacyBlob));

  final State actual = StateMessageHelper.getState(legacyWrapper);

  Assertions.assertThat(actual).isEqualTo(expected);
}

/**
 * A GLOBAL wrapper must convert to a State whose payload is a one-element JSON list holding the
 * global state message.
 */
@Test
public void testGlobalStateConversion() {
  // Build the message from the inside out: stream entry -> global payload -> state message.
  final AirbyteStreamState streamEntry = new AirbyteStreamState()
      .withStreamDescriptor(new StreamDescriptor().withNamespace("ns").withName("name"))
      .withStreamState(Jsons.deserialize("\"stream state\""));
  final AirbyteGlobalState globalPayload = new AirbyteGlobalState()
      .withSharedState(Jsons.deserialize("\"shared\""))
      .withStreamStates(Collections.singletonList(streamEntry));
  final AirbyteStateMessage globalMessage = new AirbyteStateMessage()
      .withType(AirbyteStateType.GLOBAL)
      .withGlobal(globalPayload);
  final StateWrapper globalWrapper = new StateWrapper()
      .withStateType(StateType.GLOBAL)
      .withGlobal(globalMessage);

  final String expectedJson =
      """
      [{
      "type":"GLOBAL",
      "global":{
      "shared_state":"shared",
      "stream_states":[
      {"stream_descriptor":{"name":"name","namespace":"ns"},"stream_state":"stream state"}
      ]
      }
      }]
      """;
  final State expected = new State().withState(Jsons.deserialize(expectedJson));

  final State actual = StateMessageHelper.getState(globalWrapper);

  Assertions.assertThat(actual).isEqualTo(expected);
}

/**
 * A STREAM wrapper must convert to a State whose payload is the JSON list of its per-stream state
 * messages, in the same order.
 */
@Test
public void testStreamStateConversion() {
  final AirbyteStateMessage firstMessage = new AirbyteStateMessage()
      .withType(AirbyteStateType.STREAM)
      .withStream(new AirbyteStreamState()
          .withStreamDescriptor(new StreamDescriptor().withNamespace("ns1").withName("name1"))
          .withStreamState(Jsons.deserialize("\"state1\"")));
  final AirbyteStateMessage secondMessage = new AirbyteStateMessage()
      .withType(AirbyteStateType.STREAM)
      .withStream(new AirbyteStreamState()
          .withStreamDescriptor(new StreamDescriptor().withNamespace("ns2").withName("name2"))
          .withStreamState(Jsons.deserialize("\"state2\"")));
  final StateWrapper streamWrapper = new StateWrapper()
      .withStateType(StateType.STREAM)
      .withStateMessages(Arrays.asList(firstMessage, secondMessage));

  final String expectedJson =
      """
      [
      {"type":"STREAM","stream":{"stream_descriptor":{"name":"name1","namespace":"ns1"},"stream_state":"state1"}},
      {"type":"STREAM","stream":{"stream_descriptor":{"name":"name2","namespace":"ns2"},"stream_state":"state2"}}
      ]
      """;
  final State expected = new State().withState(Jsons.deserialize(expectedJson));

  final State actual = StateMessageHelper.getState(streamWrapper);

  Assertions.assertThat(actual).isEqualTo(expected);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -757,22 +757,6 @@ public List<DestinationOAuthParameter> listDestinationOAuthParam() throws JsonVa
return persistence.listConfigs(ConfigSchema.DESTINATION_OAUTH_PARAM, DestinationOAuthParameter.class);
}

/**
 * Looks up the state stored for a connection under the STANDARD_SYNC_STATE config schema.
 *
 * @param connectionId id of the connection whose state is requested
 * @return the stored state, or an empty Optional when no such config is found
 * @throws IOException declared by this method; presumably raised by the persistence layer (confirm)
 * @throws IllegalStateException if the stored config fails JSON validation
 * @deprecated use StatePersistence instead
 */
@Deprecated(forRemoval = true)
public Optional<State> getConnectionState(final UUID connectionId) throws IOException {
  final String configId = connectionId.toString();
  try {
    return Optional.of(
        persistence.getConfig(ConfigSchema.STANDARD_SYNC_STATE, configId, StandardSyncState.class).getState());
  } catch (final ConfigNotFoundException ignored) {
    // A missing config is reported as "no state" rather than as an error.
    return Optional.empty();
  } catch (final JsonValidationException invalid) {
    throw new IllegalStateException(invalid);
  }
}

@Deprecated(forRemoval = true)
// use StatePersistence instead
public void updateConnectionState(final UUID connectionId, final State state) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.jooq.Result;
import org.junit.jupiter.api.AfterEach;
Expand Down Expand Up @@ -113,22 +112,6 @@ void testWorkspaceByConnectionId(final boolean isTombstone) throws ConfigNotFoun
verify(configRepository).getStandardWorkspace(WORKSPACE_ID, isTombstone);
}

/**
 * getConnectionState must return an empty Optional when the persistence layer reports the config
 * as missing, and the stored State when the config exists.
 */
@Test
void testGetConnectionState() throws Exception {
  final UUID id = UUID.randomUUID();
  final State storedState = new State().withState(Jsons.deserialize("{ \"cursor\": 1000 }"));
  final StandardSyncState syncState = new StandardSyncState().withConnectionId(id).withState(storedState);

  // Phase 1: the config lookup fails with "not found".
  when(configPersistence.getConfig(ConfigSchema.STANDARD_SYNC_STATE, id.toString(), StandardSyncState.class))
      .thenThrow(new ConfigNotFoundException(ConfigSchema.STANDARD_SYNC_STATE, id));
  final Optional<State> whenMissing = configRepository.getConnectionState(id);
  assertEquals(Optional.empty(), whenMissing);

  // Phase 2: drop the previous stubbing and return a stored config instead.
  reset(configPersistence);
  when(configPersistence.getConfig(ConfigSchema.STANDARD_SYNC_STATE, id.toString(), StandardSyncState.class))
      .thenReturn(syncState);
  final Optional<State> whenPresent = configRepository.getConnectionState(id);
  assertEquals(Optional.of(storedState), whenPresent);
}

@Test
void testUpdateConnectionState() throws Exception {
final UUID connectionId = UUID.randomUUID();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.jooq.JSONB;
Expand Down Expand Up @@ -508,8 +509,15 @@ public void testStatePersistenceLegacyWriteConsistency() throws IOException {
final StateWrapper stateWrapper = new StateWrapper().withStateType(StateType.LEGACY).withLegacyState(jsonState);
statePersistence.updateOrCreateState(connectionId, stateWrapper);

final State readState = configRepository.getConnectionState(connectionId).orElseThrow();
Assertions.assertEquals(readState.getState(), stateWrapper.getLegacyState());
// Making sure we still follow the legacy format
final List<State> readStates = dslContext
.selectFrom("state")
.where(DSL.field("connection_id").eq(connectionId))
.fetch().map(r -> Jsons.deserialize(r.get(DSL.field("state", JSONB.class)).data(), State.class))
.stream().toList();
Assertions.assertEquals(1, readStates.size());

Assertions.assertEquals(readStates.get(0).getState(), stateWrapper.getLegacyState());
}

@BeforeEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import io.airbyte.config.StandardSyncOperation;
import io.airbyte.config.State;
import io.airbyte.config.StreamDescriptor;
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.helpers.StateMessageHelper;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.DestinationSyncMode;
import io.airbyte.protocol.models.SyncMode;
Expand All @@ -31,15 +32,15 @@
public class DefaultJobCreator implements JobCreator {

private final JobPersistence jobPersistence;
private final ConfigRepository configRepository;
private final ResourceRequirements workerResourceRequirements;
private final StatePersistence statePersistence;

public DefaultJobCreator(final JobPersistence jobPersistence,
final ConfigRepository configRepository,
final ResourceRequirements workerResourceRequirements) {
final ResourceRequirements workerResourceRequirements,
final StatePersistence statePersistence) {
this.jobPersistence = jobPersistence;
this.configRepository = configRepository;
this.workerResourceRequirements = workerResourceRequirements;
this.statePersistence = statePersistence;
}

@Override
Expand Down Expand Up @@ -126,10 +127,8 @@ public Optional<Long> createResetConnectionJob(final DestinationConnection desti
return jobPersistence.enqueueJob(standardSync.getConnectionId().toString(), jobConfig);
}

// TODO (https://github.com/airbytehq/airbyte/issues/13620): update this method implementation
// to fetch and serialize the new per-stream state format into a State object
private Optional<State> getCurrentConnectionState(final UUID connectionId) throws IOException {
return configRepository.getConnectionState(connectionId);
return statePersistence.getCurrentState(connectionId).map(StateMessageHelper::getState);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
import io.airbyte.config.StandardSyncOperation.OperatorType;
import io.airbyte.config.State;
import io.airbyte.config.StreamDescriptor;
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.helpers.StateMessageHelper;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.protocol.models.CatalogHelpers;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.ConfiguredAirbyteStream;
Expand Down Expand Up @@ -65,7 +66,7 @@ public class DefaultJobCreatorTest {
private static final StreamDescriptor STREAM_DESCRIPTOR2 = new StreamDescriptor().withName("stream 2").withNamespace("namespace 2");

private JobPersistence jobPersistence;
private ConfigRepository configRepository;
private StatePersistence statePersistence;
private JobCreator jobCreator;
private ResourceRequirements workerResourceRequirements;

Expand Down Expand Up @@ -126,13 +127,13 @@ public class DefaultJobCreatorTest {
@BeforeEach
void setup() {
jobPersistence = mock(JobPersistence.class);
configRepository = mock(ConfigRepository.class);
statePersistence = mock(StatePersistence.class);
workerResourceRequirements = new ResourceRequirements()
.withCpuLimit("0.2")
.withCpuRequest("0.2")
.withMemoryLimit("200Mi")
.withMemoryRequest("200Mi");
jobCreator = new DefaultJobCreator(jobPersistence, configRepository, workerResourceRequirements);
jobCreator = new DefaultJobCreator(jobPersistence, workerResourceRequirements, statePersistence);
}

@Test
Expand Down Expand Up @@ -336,7 +337,8 @@ void testCreateResetConnectionJob() throws IOException {
});

final State connectionState = new State().withState(Jsons.jsonNode(Map.of("key", "val")));
when(configRepository.getConnectionState(STANDARD_SYNC.getConnectionId())).thenReturn(Optional.of(connectionState));
when(statePersistence.getCurrentState(STANDARD_SYNC.getConnectionId()))
.thenReturn(StateMessageHelper.getTypedState(connectionState.getState(), false));

final JobResetConnectionConfig jobResetConnectionConfig = new JobResetConnectionConfig()
.withNamespaceDefinition(STANDARD_SYNC.getNamespaceDefinition())
Expand Down Expand Up @@ -379,7 +381,8 @@ void testCreateResetConnectionJobEnsureNoQueuing() throws IOException {
});

final State connectionState = new State().withState(Jsons.jsonNode(Map.of("key", "val")));
when(configRepository.getConnectionState(STANDARD_SYNC.getConnectionId())).thenReturn(Optional.of(connectionState));
when(statePersistence.getCurrentState(STANDARD_SYNC.getConnectionId()))
.thenReturn(StateMessageHelper.getTypedState(connectionState.getState(), false));

final JobResetConnectionConfig jobResetConnectionConfig = new JobResetConnectionConfig()
.withNamespaceDefinition(STANDARD_SYNC.getNamespaceDefinition())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ public static void setValues(
ConfigurationApiFactory.secretsRepositoryWriter = secretsRepositoryWriter;
ConfigurationApiFactory.synchronousSchedulerClient = synchronousSchedulerClient;
ConfigurationApiFactory.archiveTtlManager = archiveTtlManager;
ConfigurationApiFactory.statePersistence = statePersistence;
ConfigurationApiFactory.mdc = mdc;
ConfigurationApiFactory.configsDatabase = configsDatabase;
ConfigurationApiFactory.jobsDatabase = jobsDatabase;
Expand All @@ -90,6 +89,7 @@ public static void setValues(
ConfigurationApiFactory.eventRunner = eventRunner;
ConfigurationApiFactory.configsFlyway = configsFlyway;
ConfigurationApiFactory.jobsFlyway = jobsFlyway;
ConfigurationApiFactory.statePersistence = statePersistence;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ public ConfigurationApi(final ConfigRepository configRepository,
workerEnvironment,
logConfigs,
eventRunner);

stateHandler = new StateHandler(statePersistence);
connectionsHandler = new ConnectionsHandler(
configRepository,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import io.airbyte.config.persistence.ConfigNotFoundException;
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.persistence.SecretsRepositoryWriter;
import io.airbyte.config.persistence.StatePersistence;
import io.airbyte.protocol.models.AirbyteCatalog;
import io.airbyte.protocol.models.CatalogHelpers;
import io.airbyte.protocol.models.ConnectorSpecification;
Expand Down Expand Up @@ -123,6 +124,7 @@ class SchedulerHandlerTest {
private JobPersistence jobPersistence;
private EventRunner eventRunner;
private JobConverter jobConverter;
private StatePersistence statePersistence;

@BeforeEach
void setup() {
Expand All @@ -138,6 +140,7 @@ void setup() {
configRepository = mock(ConfigRepository.class);
secretsRepositoryWriter = mock(SecretsRepositoryWriter.class);
jobPersistence = mock(JobPersistence.class);
statePersistence = mock(StatePersistence.class);
eventRunner = mock(EventRunner.class);

jobConverter = spy(new JobConverter(WorkerEnvironment.DOCKER, LogConfigs.EMPTY));
Expand Down
12 changes: 9 additions & 3 deletions airbyte-workers/src/main/java/io/airbyte/workers/WorkerApp.java
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public class WorkerApp {
private final JobErrorReporter jobErrorReporter;
private final StreamResetPersistence streamResetPersistence;
private final FeatureFlags featureFlags;
private final JobCreator jobCreator;
private final StatePersistence statePersistence;

public void start() {
Expand Down Expand Up @@ -181,7 +182,6 @@ public void start() {
}

private void registerConnectionManager(final WorkerFactory factory) {
final JobCreator jobCreator = new DefaultJobCreator(jobPersistence, configRepository, defaultWorkerConfigs.getResourceRequirements());
final FeatureFlags featureFlags = new EnvVariableFeatureFlags();

final Worker connectionUpdaterWorker =
Expand Down Expand Up @@ -404,6 +404,12 @@ private static void launchWorkerApp(final Configs configs, final DSLContext conf
final Database jobDatabase = new Database(jobsDslContext);

final JobPersistence jobPersistence = new DefaultJobPersistence(jobDatabase);
final StatePersistence statePersistence = new StatePersistence(configDatabase);
final DefaultJobCreator jobCreator = new DefaultJobCreator(
jobPersistence,
defaultWorkerConfigs.getResourceRequirements(),
statePersistence);

TrackingClientSingleton.initialize(
configs.getTrackingStrategy(),
new Deployment(configs.getDeploymentMode(), jobPersistence.getDeployment().orElseThrow(), configs.getWorkerEnvironment()),
Expand All @@ -413,7 +419,7 @@ private static void launchWorkerApp(final Configs configs, final DSLContext conf
final TrackingClient trackingClient = TrackingClientSingleton.get();
final SyncJobFactory jobFactory = new DefaultSyncJobFactory(
configs.connectorSpecificResourceDefaultsEnabled(),
new DefaultJobCreator(jobPersistence, configRepository, defaultWorkerConfigs.getResourceRequirements()),
jobCreator,
configRepository,
new OAuthConfigSupplier(configRepository, trackingClient));

Expand Down Expand Up @@ -450,7 +456,6 @@ private static void launchWorkerApp(final Configs configs, final DSLContext conf

final StreamResetPersistence streamResetPersistence = new StreamResetPersistence(configDatabase);

final StatePersistence statePersistence = new StatePersistence(configDatabase);
new WorkerApp(
workspaceRoot,
defaultProcessFactory,
Expand Down Expand Up @@ -481,6 +486,7 @@ private static void launchWorkerApp(final Configs configs, final DSLContext conf
jobErrorReporter,
streamResetPersistence,
featureFlags,
jobCreator,
statePersistence).start();
}

Expand Down

0 comments on commit 34ed33b

Please sign in to comment.