Skip to content

Commit

Permalink
[SNOW-943288] Do not skip records when we're expecting the offset to …
Browse files Browse the repository at this point in the history
…be reset (#729)
  • Loading branch information
sfc-gh-rcheng authored Nov 9, 2023
1 parent 7edb215 commit c9a3b2c
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ public boolean hasSchemaEvolutionPermission(String tableName, String role) {
public void appendColumnsToTable(String tableName, Map<String, String> columnToType) {
checkConnection();
InternalUtils.assertNotEmpty("tableName", tableName);
StringBuilder appendColumnQuery = new StringBuilder("alter table identifier(?) add column if not exists ");
StringBuilder appendColumnQuery =
new StringBuilder("alter table identifier(?) add column if not exists ");
boolean first = true;
StringBuilder logColumn = new StringBuilder("[");
for (String columnName : columnToType.keySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ public class TopicPartitionChannel {
* <li>If channel fails to fetch offsetToken from Snowflake, we reopen the channel and try to
* fetch offset from Snowflake again
* <li>If channel fails to ingest a buffer(Buffer containing rows/offsets), we reopen the
* channel and try to fetch offset from Snowflake again
* channel and try to fetch offset from Snowflake again. Schematization purposefully fails
* the first buffer insert in order to alter the table, and then expects Kafka to resend
* data
* </ol>
*
* <p>In both cases above, we ask Kafka to send back offsets, strictly from offset number after
Expand All @@ -124,7 +126,7 @@ public class TopicPartitionChannel {
* <p>This boolean is used to indicate that we reset offset in kafka and we will only buffer once
* we see the offset which is one more than an offset present in Snowflake.
*/
private boolean isOffsetResetInKafka;
private boolean isOffsetResetInKafka = false;

private final SnowflakeStreamingIngestClient streamingIngestClient;

Expand Down Expand Up @@ -391,14 +393,13 @@ public void insertRecordToBuffer(SinkRecord kafkaSinkRecord) {
*
* @param kafkaSinkRecord Record to check for above condition only in case of failures
* (isOffsetResetInKafka = true)
* @param currentProcessedOffset The current processed offset
* @return true if this record can be skipped to add into buffer, false otherwise.
*/
private boolean shouldIgnoreAddingRecordToBuffer(
SinkRecord kafkaSinkRecord, long currentProcessedOffset) {
// Don't skip rows if there is no offset reset or there is no offset token information in the
// channel
if (!isOffsetResetInKafka
|| currentProcessedOffset == NO_OFFSET_TOKEN_REGISTERED_IN_SNOWFLAKE) {
SinkRecord kafkaSinkRecord, final long currentProcessedOffset) {
// Don't skip rows if there is no offset reset
if (!isOffsetResetInKafka) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import net.snowflake.ingest.streaming.OpenChannelRequest;
import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClient;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaAndValue;
import org.apache.kafka.connect.json.JsonConverter;
import org.apache.kafka.connect.sink.SinkRecord;
import org.junit.After;
import org.junit.Assert;
Expand Down Expand Up @@ -485,4 +488,78 @@ public void testSimpleInsertRowsFailureWithArrowBDECFormat() throws Exception {
service.insert(records);
service.closeAll();
}

@Test
public void testPartialBatchChannelInvalidationIngestion_schematization() throws Exception {
Map<String, String> config = TestUtils.getConfForStreaming();
config.put(
SnowflakeSinkConnectorConfig.BUFFER_COUNT_RECORDS, "500"); // we want to flush on record
config.put(SnowflakeSinkConnectorConfig.BUFFER_FLUSH_TIME_SEC, "500000");
config.put(SnowflakeSinkConnectorConfig.BUFFER_SIZE_BYTES, "500000");
config.put(
SnowflakeSinkConnectorConfig.ENABLE_SCHEMATIZATION_CONFIG,
"true"); // using schematization to invalidate

// setup
InMemorySinkTaskContext inMemorySinkTaskContext =
new InMemorySinkTaskContext(Collections.singleton(topicPartition));
SnowflakeSinkService service =
SnowflakeSinkServiceFactory.builder(conn, IngestionMethodConfig.SNOWPIPE_STREAMING, config)
.setRecordNumber(1)
.setErrorReporter(new InMemoryKafkaRecordErrorReporter())
.setSinkTaskContext(inMemorySinkTaskContext)
.addTask(testTableName, topicPartition)
.build();

final long firstBatchCount = 18;
final long secondBatchCount = 500;

// create 18 blank records that do not kick off schematization
JsonConverter converter = new JsonConverter();
HashMap<String, String> converterConfig = new HashMap<>();
converterConfig.put("schemas.enable", "false");
converter.configure(converterConfig, false);
SchemaAndValue schemaInputValue = converter.toConnectData("test", null);

List<SinkRecord> firstBatch = new ArrayList<>();
for (int i = 0; i < firstBatchCount; i++) {
firstBatch.add(
new SinkRecord(
topic,
PARTITION,
Schema.STRING_SCHEMA,
"test",
schemaInputValue.schema(),
schemaInputValue.value(),
i));
}

service.insert(firstBatch);

// send batch with 500, should kick off a record based flush and schematization on record 19,
// which will fail the batches
List<SinkRecord> secondBatch =
TestUtils.createNativeJsonSinkRecords(firstBatchCount, secondBatchCount, topic, PARTITION);
service.insert(secondBatch);

// resend batch 1 and 2 because 2 failed for schematization
service.insert(firstBatch);
service.insert(secondBatch);

// ensure all data was ingested
TestUtils.assertWithRetry(
() ->
service.getOffset(new TopicPartition(topic, PARTITION))
== firstBatchCount + secondBatchCount,
20,
5);
assert TestUtils.tableSize(testTableName) == firstBatchCount + secondBatchCount
: "expected: "
+ firstBatchCount
+ secondBatchCount
+ " actual: "
+ TestUtils.tableSize(testTableName);

service.closeAll();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -397,15 +397,24 @@ public void testFetchOffsetTokenWithRetry_RuntimeException() {
/* Only SFExceptions goes into fallback -> reopens channel, fetch offsetToken and throws Appropriate exception */
@Test
public void testInsertRows_SuccessAfterReopenChannel() throws Exception {
final int noOfRecords = 5;
int expectedInsertRowsCount = 0;
int expectedOpenChannelCount = 0;
int expectedGetOffsetCount = 0;

// setup mocks to fail first insert and return two null snowflake offsets (open channel and
// failed insert) before succeeding
Mockito.when(
mockStreamingChannel.insertRows(
ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)))
.thenThrow(SF_EXCEPTION);

// get null from snowflake first time it is called and null for second time too since insert
// rows was failure
Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken()).thenReturn(null);
.thenThrow(SF_EXCEPTION)
.thenReturn(new InsertValidationResponse());
Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken())
.thenReturn(null)
.thenReturn(null)
.thenReturn(Long.toString(noOfRecords - 1));

// create tpchannel
TopicPartitionChannel topicPartitionChannel =
new TopicPartitionChannel(
mockStreamingClient,
Expand All @@ -417,37 +426,47 @@ public void testInsertRows_SuccessAfterReopenChannel() throws Exception {
mockKafkaRecordErrorReporter,
mockSinkTaskContext,
mockTelemetryService);
final int noOfRecords = 5;
// Since record 0 was not able to ingest, all records in this batch will not be added into the
// buffer.
expectedOpenChannelCount++;
expectedGetOffsetCount++;

// verify initial mock counts after tpchannel creation
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(expectedInsertRowsCount))
.insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class));
Mockito.verify(mockStreamingClient, Mockito.times(expectedOpenChannelCount))
.openChannel(ArgumentMatchers.any());
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(expectedGetOffsetCount))
.getLatestCommittedOffsetToken();

// Test inserting record 0, which should fail to ingest so the other records are ignored
List<SinkRecord> records =
TestUtils.createJsonStringSinkRecords(0, noOfRecords, TOPIC, PARTITION);

records.forEach(topicPartitionChannel::insertRecordToBuffer);
expectedInsertRowsCount++;
expectedOpenChannelCount++;
expectedGetOffsetCount++;

Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(noOfRecords))
// verify mocks only tried ingesting once
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(expectedInsertRowsCount))
.insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class));
Mockito.verify(mockStreamingClient, Mockito.times(noOfRecords + 1))
Mockito.verify(mockStreamingClient, Mockito.times(expectedOpenChannelCount))
.openChannel(ArgumentMatchers.any());
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(noOfRecords + 1))
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(expectedGetOffsetCount))
.getLatestCommittedOffsetToken();

// Now, it should be successful
Mockito.when(
mockStreamingChannel.insertRows(
ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class)))
.thenReturn(new InsertValidationResponse());

Mockito.when(mockStreamingChannel.getLatestCommittedOffsetToken())
.thenReturn(Long.toString(noOfRecords - 1));

// Retry the insert again, now everything should be ingested and the offset token should be
// noOfRecords-1
records.forEach(topicPartitionChannel::insertRecordToBuffer);
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(noOfRecords * 2))
.insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class));

Assert.assertEquals(noOfRecords - 1, topicPartitionChannel.fetchOffsetTokenWithRetry());
expectedInsertRowsCount += noOfRecords;
expectedGetOffsetCount++;

// verify mocks ingested each record
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(expectedInsertRowsCount))
.insertRows(ArgumentMatchers.any(Iterable.class), ArgumentMatchers.any(String.class));
Mockito.verify(mockStreamingClient, Mockito.times(expectedOpenChannelCount))
.openChannel(ArgumentMatchers.any());
Mockito.verify(topicPartitionChannel.getChannel(), Mockito.times(expectedGetOffsetCount))
.getLatestCommittedOffsetToken();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"name": "SNOWFLAKE_CONNECTOR_NAME",
"config": {
"connector.class": "com.snowflake.kafka.connector.SnowflakeSinkConnector",
"topics": "SNOWFLAKE_TEST_TOPIC0,SNOWFLAKE_TEST_TOPIC1",
"snowflake.topic2table.map": "SNOWFLAKE_TEST_TOPIC0:SNOWFLAKE_CONNECTOR_NAME,SNOWFLAKE_TEST_TOPIC1:SNOWFLAKE_CONNECTOR_NAME",
"tasks.max": "1",
"buffer.flush.time": "60",
"buffer.count.records": "300",
"buffer.size.bytes": "5000000",
"snowflake.url.name": "SNOWFLAKE_HOST",
"snowflake.user.name": "SNOWFLAKE_USER",
"snowflake.private.key": "SNOWFLAKE_PRIVATE_KEY",
"snowflake.database.name": "SNOWFLAKE_DATABASE",
"snowflake.schema.name": "SNOWFLAKE_SCHEMA",
"snowflake.role.name": "SNOWFLAKE_ROLE",
"snowflake.ingestion.method": "SNOWPIPE_STREAMING",
"key.converter": "org.apache.kafka.connect.storage.StringConverter",
"value.converter": "org.apache.kafka.connect.json.JsonConverter",
"value.converter.schemas.enable": "false",
"jmx": "true",
"errors.tolerance": "all",
"errors.log.enable": true,
"errors.deadletterqueue.topic.name": "DLQ_TOPIC",
"errors.deadletterqueue.topic.replication.factor": 1,
"snowflake.enable.schematization": true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"topics": "SNOWFLAKE_TEST_TOPIC0,SNOWFLAKE_TEST_TOPIC1",
"snowflake.topic2table.map": "SNOWFLAKE_TEST_TOPIC0:SNOWFLAKE_CONNECTOR_NAME,SNOWFLAKE_TEST_TOPIC1:SNOWFLAKE_CONNECTOR_NAME",
"tasks.max": "1",
"buffer.flush.time": "10",
"buffer.count.records": "100",
"buffer.flush.time": "60",
"buffer.count.records": "300",
"buffer.size.bytes": "5000000",
"snowflake.url.name": "SNOWFLAKE_HOST",
"snowflake.user.name": "SNOWFLAKE_USER",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"topics": "SNOWFLAKE_TEST_TOPIC0,SNOWFLAKE_TEST_TOPIC1",
"snowflake.topic2table.map": "SNOWFLAKE_TEST_TOPIC0:SNOWFLAKE_CONNECTOR_NAME,SNOWFLAKE_TEST_TOPIC1:SNOWFLAKE_CONNECTOR_NAME",
"tasks.max": "1",
"buffer.flush.time": "10",
"buffer.count.records": "100",
"buffer.flush.time": "60",
"buffer.count.records": "300",
"buffer.size.bytes": "5000000",
"snowflake.url.name": "SNOWFLAKE_HOST",
"snowflake.user.name": "SNOWFLAKE_USER",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def __init__(self, driver, nameSalt):
self.fileName = "travis_correct_schema_evolution_w_auto_table_creation_avro_sr"
self.topics = []
self.table = self.fileName + nameSalt
self.recordNum = 100

# records
self.initialRecordCount = 12
self.flushRecordCount = 300
self.recordNum = self.initialRecordCount + self.flushRecordCount

for i in range(2):
self.topics.append(self.table + str(i))
Expand Down Expand Up @@ -78,8 +82,15 @@ def getConfigFileName(self):

def send(self):
for i, topic in enumerate(self.topics):
# send initial batch
value = []
for _ in range(self.initialRecordCount):
value.append(self.records[i])
self.driver.sendAvroSRData(topic, value, self.valueSchema[i], key=[], key_schema="", partition=0)

# send second batch that should flush
value = []
for _ in range(self.recordNum):
for _ in range(self.flushRecordCount):
value.append(self.records[i])
self.driver.sendAvroSRData(topic, value, self.valueSchema[i], key=[], key_schema="", partition=0)

Expand Down
17 changes: 15 additions & 2 deletions test/test_suit/test_schema_evolution_w_auto_table_creation_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def __init__(self, driver, nameSalt):
self.fileName = "travis_correct_schema_evolution_w_auto_table_creation_json"
self.topics = []
self.table = self.fileName + nameSalt
self.recordNum = 100

# records
self.initialRecordCount = 12
self.flushRecordCount = 300
self.recordNum = self.initialRecordCount + self.flushRecordCount

for i in range(2):
self.topics.append(self.table + str(i))
Expand Down Expand Up @@ -48,9 +52,18 @@ def getConfigFileName(self):

def send(self):
for i, topic in enumerate(self.topics):
# send initial batch
key = []
value = []
for e in range(self.initialRecordCount):
key.append(json.dumps({'number': str(e)}).encode('utf-8'))
value.append(json.dumps(self.records[i]).encode('utf-8'))
self.driver.sendBytesData(topic, value, key)

# send second batch that should flush
key = []
value = []
for e in range(self.recordNum):
for e in range(self.flushRecordCount):
key.append(json.dumps({'number': str(e)}).encode('utf-8'))
value.append(json.dumps(self.records[i]).encode('utf-8'))
self.driver.sendBytesData(topic, value, key)
Expand Down
Loading

0 comments on commit c9a3b2c

Please sign in to comment.