Skip to content

Commit

Permalink
Refactor state management out of BufferStrategy (#13669)
Browse files Browse the repository at this point in the history
Co-authored-by: Edward Gao <[email protected]>
  • Loading branch information
cgardens and edgao authored Jun 11, 2022
1 parent edb74ec commit 0886ee0
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ public class BufferedStreamConsumer extends FailureTrackingAirbyteMessageConsume
private boolean hasStarted;
private boolean hasClosed;

private AirbyteMessage lastFlushedState;
// represents the last state message for which all of it records have been flushed to tmp storage in
// the destination.
private AirbyteMessage lastFlushedToTmpDstState;
// presents the last state message whose state is waiting to be flushed to tmp storage in the
// destination.
private AirbyteMessage pendingState;

public BufferedStreamConsumer(final Consumer<AirbyteMessage> outputRecordCollector,
Expand All @@ -103,7 +107,6 @@ public BufferedStreamConsumer(final Consumer<AirbyteMessage> outputRecordCollect
this.isValidRecord = isValidRecord;
this.streamToIgnoredRecordCount = new HashMap<>();
this.bufferingStrategy = bufferingStrategy;
bufferingStrategy.registerFlushAllEventHook(this::flushQueueToDestination);
}

@Override
Expand Down Expand Up @@ -134,7 +137,11 @@ protected void acceptTracked(final AirbyteMessage message) throws Exception {
return;
}

bufferingStrategy.addRecord(stream, message);
// if the buffer flushes, update the states appropriately.
if (bufferingStrategy.addRecord(stream, message)) {
markStatesAsFlushedToTmpDestination();
}

} else if (message.getType() == Type.STATE) {
pendingState = message;
} else {
Expand All @@ -143,9 +150,9 @@ protected void acceptTracked(final AirbyteMessage message) throws Exception {

}

private void flushQueueToDestination() {
private void markStatesAsFlushedToTmpDestination() {
if (pendingState != null) {
lastFlushedState = pendingState;
lastFlushedToTmpDstState = pendingState;
pendingState = null;
}
}
Expand All @@ -169,13 +176,14 @@ protected void close(final boolean hasFailed) throws Exception {
} else {
LOGGER.info("executing on success close procedure.");
bufferingStrategy.flushAll();
markStatesAsFlushedToTmpDestination();
}
bufferingStrategy.close();

try {
// if no state was emitted (i.e. full refresh), if there were still no failures, then we can
// still succeed.
if (lastFlushedState == null) {
if (lastFlushedToTmpDstState == null) {
onClose.accept(hasFailed);
} else {
// if any state message flushed that means we can still go for at least a partial success.
Expand All @@ -184,8 +192,8 @@ protected void close(final boolean hasFailed) throws Exception {

// if onClose succeeds without exception then we can emit the state record because it means its
// records were not only flushed, but committed.
if (lastFlushedState != null) {
outputRecordCollector.accept(lastFlushedState);
if (lastFlushedToTmpDstState != null) {
outputRecordCollector.accept(lastFlushedToTmpDstState);
}
} catch (final Exception e) {
LOGGER.error("Close failed.", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

package io.airbyte.integrations.destination.record_buffer;

import io.airbyte.commons.concurrency.VoidCallable;
import io.airbyte.integrations.base.AirbyteStreamNameNamespacePair;
import io.airbyte.protocol.models.AirbyteMessage;

Expand All @@ -22,8 +21,13 @@ public interface BufferingStrategy extends AutoCloseable {

/**
* Add a new message to the buffer while consuming streams
*
* @param stream - stream associated with record
* @param message - message to buffer
* @return true if this record cause ALL records in the buffer to flush, otherwise false.
* @throws Exception throw on failure
*/
void addRecord(AirbyteStreamNameNamespacePair stream, AirbyteMessage message) throws Exception;
boolean addRecord(AirbyteStreamNameNamespacePair stream, AirbyteMessage message) throws Exception;

/**
* Flush buffered messages in a writer from a particular stream
Expand All @@ -40,12 +44,4 @@ public interface BufferingStrategy extends AutoCloseable {
*/
void clear() throws Exception;

/**
* When all buffers are being flushed, we can signal some parent function of this event for further
* processing.
*
* THis install such a hook to be triggered when that happens.
*/
void registerFlushAllEventHook(VoidCallable onFlushAllEventHook);

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

package io.airbyte.integrations.destination.record_buffer;

import io.airbyte.commons.concurrency.VoidCallable;
import io.airbyte.integrations.base.AirbyteStreamNameNamespacePair;
import io.airbyte.integrations.base.sentry.AirbyteSentry;
import io.airbyte.integrations.destination.buffered_stream_consumer.CheckAndRemoveRecordWriter;
Expand Down Expand Up @@ -39,7 +38,6 @@ public class InMemoryRecordBufferingStrategy implements BufferingStrategy {
private final RecordSizeEstimator recordSizeEstimator;
private final long maxQueueSizeInBytes;
private long bufferSizeInBytes;
private VoidCallable onFlushAllEventHook;

public InMemoryRecordBufferingStrategy(final RecordWriter<AirbyteRecordMessage> recordWriter,
final long maxQueueSizeInBytes) {
Expand All @@ -55,20 +53,24 @@ public InMemoryRecordBufferingStrategy(final RecordWriter<AirbyteRecordMessage>
this.maxQueueSizeInBytes = maxQueueSizeInBytes;
this.bufferSizeInBytes = 0;
this.recordSizeEstimator = new RecordSizeEstimator();
this.onFlushAllEventHook = null;
}

@Override
public void addRecord(final AirbyteStreamNameNamespacePair stream, final AirbyteMessage message) throws Exception {
public boolean addRecord(final AirbyteStreamNameNamespacePair stream, final AirbyteMessage message) throws Exception {
boolean didFlush = false;

final long messageSizeInBytes = recordSizeEstimator.getEstimatedByteSize(message.getRecord());
if (bufferSizeInBytes + messageSizeInBytes > maxQueueSizeInBytes) {
flushAll();
didFlush = true;
bufferSizeInBytes = 0;
}

final List<AirbyteRecordMessage> bufferedRecords = streamBuffer.computeIfAbsent(stream, k -> new ArrayList<>());
bufferedRecords.add(message.getRecord());
bufferSizeInBytes += messageSizeInBytes;

return didFlush;
}

@Override
Expand All @@ -91,22 +93,13 @@ public void flushAll() throws Exception {
}, Map.of("bufferSizeInBytes", bufferSizeInBytes));
close();
clear();

if (onFlushAllEventHook != null) {
onFlushAllEventHook.call();
}
}

@Override
public void clear() {
streamBuffer = new HashMap<>();
}

@Override
public void registerFlushAllEventHook(final VoidCallable onFlushAllEventHook) {
this.onFlushAllEventHook = onFlushAllEventHook;
}

@Override
public void close() throws Exception {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

package io.airbyte.integrations.destination.record_buffer;

import io.airbyte.commons.concurrency.VoidCallable;
import io.airbyte.commons.functional.CheckedBiConsumer;
import io.airbyte.commons.functional.CheckedBiFunction;
import io.airbyte.commons.string.Strings;
Expand All @@ -27,7 +26,6 @@ public class SerializedBufferingStrategy implements BufferingStrategy {

private final CheckedBiFunction<AirbyteStreamNameNamespacePair, ConfiguredAirbyteCatalog, SerializableBuffer, Exception> onCreateBuffer;
private final CheckedBiConsumer<AirbyteStreamNameNamespacePair, SerializableBuffer, Exception> onStreamFlush;
private VoidCallable onFlushAllEventHook;

private Map<AirbyteStreamNameNamespacePair, SerializableBuffer> allBuffers = new HashMap<>();
private long totalBufferSizeInBytes;
Expand All @@ -40,16 +38,11 @@ public SerializedBufferingStrategy(final CheckedBiFunction<AirbyteStreamNameName
this.catalog = catalog;
this.onStreamFlush = onStreamFlush;
this.totalBufferSizeInBytes = 0;
this.onFlushAllEventHook = null;
}

@Override
public void registerFlushAllEventHook(final VoidCallable onFlushAllEventHook) {
this.onFlushAllEventHook = onFlushAllEventHook;
}

@Override
public void addRecord(final AirbyteStreamNameNamespacePair stream, final AirbyteMessage message) throws Exception {
public boolean addRecord(final AirbyteStreamNameNamespacePair stream, final AirbyteMessage message) throws Exception {
boolean didFlush = false;

final SerializableBuffer streamBuffer = allBuffers.computeIfAbsent(stream, k -> {
LOGGER.info("Starting a new buffer for stream {} (current state: {} in {} buffers)",
Expand All @@ -71,10 +64,28 @@ public void addRecord(final AirbyteStreamNameNamespacePair stream, final Airbyte
if (totalBufferSizeInBytes >= streamBuffer.getMaxTotalBufferSizeInBytes()
|| allBuffers.size() >= streamBuffer.getMaxConcurrentStreamsInBuffer()) {
flushAll();
didFlush = true;
totalBufferSizeInBytes = 0;
} else if (streamBuffer.getByteCount() >= streamBuffer.getMaxPerStreamBufferSizeInBytes()) {
flushWriter(stream, streamBuffer);
/*
* Note: We intentionally do not mark didFlush as true in the branch of this conditional. Because
* this branch flushes individual streams, there is no guaranteee that it will flush records in the
* same order that state messages were received. The outcome here is that records get flushed but
* our updating of which state messages have been flushed falls behind.
*
* This is not ideal from a checkpoint point of view, because it means in the case where there is a
* failure, we will not be able to report that those records that were flushed and committed were
* committed because there corresponding state messages weren't marked as flushed. Thus, it weakens
* checkpointing, but it does not cause a correctness issue.
*
* In non-failure cases, using this conditional branch relies on the state messages getting flushed
* by some other means. That can be caused by the previous branch in this conditional. It is
* guaranteed by the fact that we always flush all state messages at the end of a sync.
*/
}

return didFlush;
}

@Override
Expand All @@ -99,9 +110,6 @@ public void flushAll() throws Exception {
clear();
}, Map.of("bufferSizeInBytes", totalBufferSizeInBytes));

if (onFlushAllEventHook != null) {
onFlushAllEventHook.call();
}
totalBufferSizeInBytes = 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

package io.airbyte.integrations.destination.record_buffer;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.commons.concurrency.VoidCallable;
import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.base.AirbyteStreamNameNamespacePair;
import io.airbyte.integrations.destination.buffered_stream_consumer.RecordWriter;
Expand All @@ -25,6 +26,7 @@ public class InMemoryRecordBufferingStrategyTest {
// instances
private static final int MAX_QUEUE_SIZE_IN_BYTES = 130;

@SuppressWarnings("unchecked")
private final RecordWriter<AirbyteRecordMessage> recordWriter = mock(RecordWriter.class);

@Test
Expand All @@ -36,25 +38,19 @@ public void testBuffering() throws Exception {
final AirbyteMessage message2 = generateMessage(stream2);
final AirbyteMessage message3 = generateMessage(stream2);
final AirbyteMessage message4 = generateMessage(stream2);
final VoidCallable hook = mock(VoidCallable.class);
buffering.registerFlushAllEventHook(hook);

buffering.addRecord(stream1, message1);
buffering.addRecord(stream2, message2);
assertFalse(buffering.addRecord(stream1, message1));
assertFalse(buffering.addRecord(stream2, message2));
// Buffer still has room
verify(hook, times(0)).call();

buffering.addRecord(stream2, message3);
assertTrue(buffering.addRecord(stream2, message3));
// Buffer limit reach, flushing all messages so far before adding the new incoming one
verify(hook, times(1)).call();
verify(recordWriter, times(1)).accept(stream1, List.of(message1.getRecord()));
verify(recordWriter, times(1)).accept(stream2, List.of(message2.getRecord()));

buffering.addRecord(stream2, message4);

// force flush to terminate test
buffering.flushAll();
verify(hook, times(2)).call();
verify(recordWriter, times(1)).accept(stream2, List.of(message3.getRecord(), message4.getRecord()));
}

Expand Down
Loading

0 comments on commit 0886ee0

Please sign in to comment.