Skip to content

Commit

Permalink
update milvus connector to support dynamic schema, failed retry, etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
nianliuu committed Oct 24, 2024
1 parent 4406fbc commit 1d72624
Show file tree
Hide file tree
Showing 23 changed files with 1,778 additions and 843 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,25 @@ public static PhysicalColumn of(
String comment,
String sourceType,
Map<String, Object> options) {
return new PhysicalColumn(
name, dataType, columnLength, nullable, defaultValue, comment, sourceType, options);
}

public static PhysicalColumn of(
String name,
SeaTunnelDataType<?> dataType,
Long columnLength,
Integer scale,
boolean nullable,
Object defaultValue,
String comment,
String sourceType,
Map<String, Object> options) {
return new PhysicalColumn(
name,
dataType,
columnLength,
null,
scale,
nullable,
defaultValue,
comment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

Expand All @@ -35,6 +36,8 @@ public final class SeaTunnelRow implements Serializable {

private volatile int size;

private Map<String, Object> options = new HashMap<>();

public SeaTunnelRow(int arity) {
this.fields = new Object[arity];
}
Expand All @@ -55,6 +58,10 @@ public void setRowKind(RowKind rowKind) {
this.rowKind = rowKind;
}

public void setOptions(Map<String, Object> options) {
this.options = options;
}

public int getArity() {
return fields.length;
}
Expand All @@ -67,6 +74,10 @@ public RowKind getRowKind() {
return this.rowKind;
}

public Map<String, Object> getOptions() {
return options;
}

public Object[] getFields() {
return fields;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
* limitations under the License.
*/

package org.apache.seatunnel.connectors.seatunnel.milvus.sink.batch;
package org.apache.seatunnel.common.constants;

import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import lombok.Getter;

public interface MilvusBatchWriter {
@Getter
public enum CommonOptions {
JSON("Json"),
METADATA("Metadata"),
PARTITION("Partition"),;

void addToBatch(SeaTunnelRow element);
private final String name;

boolean needFlush();

boolean flush();

void close();
CommonOptions(String name) {
this.name = name;
}
}
25 changes: 10 additions & 15 deletions seatunnel-connectors-v2/connector-milvus/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,20 @@

<artifactId>connector-milvus</artifactId>
<name>SeaTunnel : Connectors V2 : Milvus</name>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.10.1</version>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.4.3</version>
<version>2.4.5</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
Expand All @@ -42,19 +50,6 @@
</exclusions>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>4.11.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>4.11.0</version>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.seatunnel.api.table.catalog.ConstraintKey;
import org.apache.seatunnel.api.table.catalog.InfoPreviewResult;
import org.apache.seatunnel.api.table.catalog.PreviewResult;
import org.apache.seatunnel.api.table.catalog.PrimaryKey;
import org.apache.seatunnel.api.table.catalog.TablePath;
import org.apache.seatunnel.api.table.catalog.TableSchema;
import org.apache.seatunnel.api.table.catalog.VectorIndex;
Expand All @@ -33,20 +32,21 @@
import org.apache.seatunnel.api.table.catalog.exception.DatabaseNotExistException;
import org.apache.seatunnel.api.table.catalog.exception.TableAlreadyExistException;
import org.apache.seatunnel.api.table.catalog.exception.TableNotExistException;
import org.apache.seatunnel.api.table.type.ArrayType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.common.constants.CommonOptions;
import org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSinkConfig;
import org.apache.seatunnel.connectors.seatunnel.milvus.convert.MilvusConvertUtils;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectionErrorCode;
import org.apache.seatunnel.connectors.seatunnel.milvus.exception.MilvusConnectorException;
import org.apache.seatunnel.connectors.seatunnel.milvus.utils.sink.MilvusSinkConverter;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import com.google.protobuf.ProtocolStringList;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.DataType;
import io.milvus.grpc.ListDatabasesResponse;
import io.milvus.grpc.ShowCollectionsResponse;
import io.milvus.grpc.ShowPartitionsResponse;
import io.milvus.grpc.ShowType;
import io.milvus.param.ConnectParam;
import io.milvus.param.IndexType;
Expand All @@ -61,6 +61,8 @@
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.ShowCollectionsParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.partition.CreatePartitionParam;
import io.milvus.param.partition.ShowPartitionsParam;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
Expand All @@ -70,6 +72,7 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkNotNull;
import static org.apache.seatunnel.connectors.seatunnel.milvus.config.MilvusSinkConfig.CREATE_INDEX;

@Slf4j
public class MilvusCatalog implements Catalog {
Expand Down Expand Up @@ -196,7 +199,8 @@ public void createTable(TablePath tablePath, CatalogTable catalogTable, boolean
checkNotNull(tableSchema, "tableSchema must not be null");
createTableInternal(tablePath, catalogTable);

if (CollectionUtils.isNotEmpty(tableSchema.getConstraintKeys())) {
if (CollectionUtils.isNotEmpty(tableSchema.getConstraintKeys())
&& config.get(CREATE_INDEX)) {
for (ConstraintKey constraintKey : tableSchema.getConstraintKeys()) {
if (constraintKey
.getConstraintType()
Expand Down Expand Up @@ -231,27 +235,61 @@ private void createIndexInternal(

public void createTableInternal(TablePath tablePath, CatalogTable catalogTable) {
try {
Map<String, String> options = catalogTable.getOptions();

// partition key logic
boolean existPartitionKeyField = options.containsKey(MilvusOptions.PARTITION_KEY_FIELD);
String partitionKeyField =
existPartitionKeyField ? options.get(MilvusOptions.PARTITION_KEY_FIELD) : null;
// if options set, will overwrite aut read
if (StringUtils.isNotEmpty(config.get(MilvusSinkConfig.PARTITION_KEY))) {
existPartitionKeyField = true;
partitionKeyField = config.get(MilvusSinkConfig.PARTITION_KEY);
}

TableSchema tableSchema = catalogTable.getTableSchema();
List<FieldType> fieldTypes = new ArrayList<>();
for (Column column : tableSchema.getColumns()) {
fieldTypes.add(convertToFieldType(column, tableSchema.getPrimaryKey()));
if (column.getOptions() != null
&& column.getOptions().containsKey(CommonOptions.METADATA.getName())
&& (Boolean) column.getOptions().get(CommonOptions.METADATA.getName())) {
// skip dynamic field
continue;
}
FieldType fieldType =
MilvusSinkConverter.convertToFieldType(
column,
tableSchema.getPrimaryKey(),
partitionKeyField,
config.get(MilvusSinkConfig.ENABLE_AUTO_ID));
fieldTypes.add(fieldType);
}

Map<String, String> options = catalogTable.getOptions();
Boolean enableDynamicField =
(options.containsKey(MilvusOptions.ENABLE_DYNAMIC_FIELD))
? Boolean.valueOf(options.get(MilvusOptions.ENABLE_DYNAMIC_FIELD))
: config.get(MilvusSinkConfig.ENABLE_DYNAMIC_FIELD);

String collectionDescription = "";
if (config.get(MilvusSinkConfig.COLLECTION_DESCRIPTION) != null
&& config.get(MilvusSinkConfig.COLLECTION_DESCRIPTION)
.containsKey(tablePath.getTableName())) {
// use description from config first
collectionDescription =
config.get(MilvusSinkConfig.COLLECTION_DESCRIPTION)
.get(tablePath.getTableName());
} else if (null != catalogTable.getComment()) {
collectionDescription = catalogTable.getComment();
}
CreateCollectionParam.Builder builder =
CreateCollectionParam.newBuilder()
.withDatabaseName(tablePath.getDatabaseName())
.withCollectionName(tablePath.getTableName())
.withDescription(collectionDescription)
.withFieldTypes(fieldTypes)
.withEnableDynamicField(enableDynamicField)
.withConsistencyLevel(ConsistencyLevelEnum.BOUNDED);
if (null != catalogTable.getComment()) {
builder.withDescription(catalogTable.getComment());
if (StringUtils.isNotEmpty(options.get(MilvusOptions.SHARDS_NUM))) {
builder.withShardsNum(Integer.parseInt(options.get(MilvusOptions.SHARDS_NUM)));
}

CreateCollectionParam createCollectionParam = builder.build();
Expand All @@ -260,89 +298,51 @@ public void createTableInternal(TablePath tablePath, CatalogTable catalogTable)
throw new MilvusConnectorException(
MilvusConnectionErrorCode.CREATE_COLLECTION_ERROR, response.getMessage());
}

// not exist partition key field, will read show partitions to create
if (!existPartitionKeyField && options.containsKey(MilvusOptions.PARTITION_KEY_FIELD)) {
createPartitionInternal(options.get(MilvusOptions.PARTITION_KEY_FIELD), tablePath);
}

} catch (Exception e) {
throw new MilvusConnectorException(
MilvusConnectionErrorCode.CREATE_COLLECTION_ERROR, e);
}
}

private FieldType convertToFieldType(Column column, PrimaryKey primaryKey) {
SeaTunnelDataType<?> seaTunnelDataType = column.getDataType();
FieldType.Builder build =
FieldType.newBuilder()
.withName(column.getName())
.withDataType(
MilvusConvertUtils.convertSqlTypeToDataType(
seaTunnelDataType.getSqlType()));
switch (seaTunnelDataType.getSqlType()) {
case ROW:
build.withMaxLength(65535);
break;
case DATE:
build.withMaxLength(20);
break;
case INT:
build.withDataType(DataType.Int32);
break;
case SMALLINT:
build.withDataType(DataType.Int16);
break;
case TINYINT:
build.withDataType(DataType.Int8);
break;
case FLOAT:
build.withDataType(DataType.Float);
break;
case DOUBLE:
build.withDataType(DataType.Double);
break;
case MAP:
build.withDataType(DataType.JSON);
break;
case BOOLEAN:
build.withDataType(DataType.Bool);
break;
case STRING:
if (column.getColumnLength() == 0) {
build.withMaxLength(512);
} else {
build.withMaxLength((int) (column.getColumnLength() / 4));
}
break;
case ARRAY:
ArrayType arrayType = (ArrayType) column.getDataType();
SeaTunnelDataType elementType = arrayType.getElementType();
build.withElementType(
MilvusConvertUtils.convertSqlTypeToDataType(elementType.getSqlType()));
build.withMaxCapacity(4095);
switch (elementType.getSqlType()) {
case STRING:
if (column.getColumnLength() == 0) {
build.withMaxLength(512);
} else {
build.withMaxLength((int) (column.getColumnLength() / 4));
}
break;
}
break;
case BINARY_VECTOR:
case FLOAT_VECTOR:
case FLOAT16_VECTOR:
case BFLOAT16_VECTOR:
build.withDimension(column.getScale());
break;
private void createPartitionInternal(String partitionNames, TablePath tablePath) {
R<ShowPartitionsResponse> showPartitionsResponseR =
this.client.showPartitions(
ShowPartitionsParam.newBuilder()
.withDatabaseName(tablePath.getDatabaseName())
.withCollectionName(tablePath.getTableName())
.build());
if (!Objects.equals(showPartitionsResponseR.getStatus(), R.success().getStatus())) {
throw new MilvusConnectorException(
MilvusConnectionErrorCode.SHOW_PARTITION_ERROR,
showPartitionsResponseR.getMessage());
}

if (null != primaryKey && primaryKey.getColumnNames().contains(column.getName())) {
build.withPrimaryKey(true);
if (null != primaryKey.getEnableAutoId()) {
build.withAutoID(primaryKey.getEnableAutoId());
} else {
build.withAutoID(config.get(MilvusSinkConfig.ENABLE_AUTO_ID));
ProtocolStringList existPartitionNames =
showPartitionsResponseR.getData().getPartitionNamesList();

// start to loop create partition
String[] partitionNameArray = partitionNames.split(",");
for (String partitionName : partitionNameArray) {
if (existPartitionNames.contains(partitionName)) {
continue;
}
R<RpcStatus> response =
this.client.createPartition(
CreatePartitionParam.newBuilder()
.withDatabaseName(tablePath.getDatabaseName())
.withCollectionName(tablePath.getTableName())
.withPartitionName(partitionName)
.build());
if (!R.success().getStatus().equals(response.getStatus())) {
throw new MilvusConnectorException(
MilvusConnectionErrorCode.CREATE_PARTITION_ERROR, response.getMessage());
}
}

return build.build();
}

@Override
Expand Down
Loading

0 comments on commit 1d72624

Please sign in to comment.