From 87943d18a53f56447e314e8103b9178e710f7b08 Mon Sep 17 00:00:00 2001 From: JaySon Date: Wed, 28 Feb 2024 14:58:28 +0800 Subject: [PATCH 1/7] [cherry-pick] compute: Add Vector data type and functions (#153) * Add Vector data type (#141) Signed-off-by: Wish * compute: Add vector functions (#146) Signed-off-by: Wish * compute: Fix vector precision (#147) Signed-off-by: Wish * compute: Fix vector distance over NULLs (#148) Signed-off-by: Wish * [cherry-pick] Storages: Make DMFile ready for new column indexes/types (#149) Signed-off-by: Wish --------- Signed-off-by: Wish Co-authored-by: Wenxuan --- dbms/src/Columns/ColumnArray.cpp | 89 +++- dbms/src/Columns/ColumnArray.h | 9 + dbms/src/Columns/ColumnNullable.cpp | 29 +- dbms/src/Columns/ColumnNullable.h | 4 +- dbms/src/Columns/IColumn.h | 5 + dbms/src/DataTypes/DataTypeArray.h | 2 +- dbms/src/Debug/MockExecutor/AstToPB.cpp | 7 + dbms/src/Debug/dbgTools.cpp | 4 + dbms/src/Flash/Coprocessor/ArrowColCodec.cpp | 83 +++ dbms/src/Flash/Coprocessor/DAGCodec.cpp | 11 + dbms/src/Flash/Coprocessor/DAGCodec.h | 2 + .../Coprocessor/DAGExpressionAnalyzer.cpp | 6 + dbms/src/Flash/Coprocessor/DAGUtils.cpp | 62 +++ dbms/src/Flash/Coprocessor/DAGUtils.h | 1 + dbms/src/Flash/Coprocessor/TiDBColumn.h | 11 +- dbms/src/Functions/FunctionHelpers.h | 11 + dbms/src/Functions/FunctionsVector.cpp | 44 ++ dbms/src/Functions/FunctionsVector.h | 472 ++++++++++++++++++ dbms/src/Functions/registerFunctions.cpp | 2 + dbms/src/Functions/tests/gtest_vector.cpp | 361 ++++++++++++++ .../DeltaMerge/FilterParser/FilterParser.cpp | 1 + dbms/src/TiDB/Decode/DatumCodec.cpp | 48 ++ dbms/src/TiDB/Decode/DatumCodec.h | 4 + dbms/src/TiDB/Decode/RowCodec.cpp | 3 + dbms/src/TiDB/Decode/TypeMapping.cpp | 28 +- dbms/src/TiDB/Decode/Vector.cpp | 198 ++++++++ dbms/src/TiDB/Decode/Vector.h | 68 +++ dbms/src/TiDB/Schema/TiDB.cpp | 3 + dbms/src/TiDB/Schema/TiDB.h | 28 +- 29 files changed, 1570 insertions(+), 26 deletions(-) create mode 100644 dbms/src/Functions/FunctionsVector.cpp create mode 100644 dbms/src/Functions/FunctionsVector.h create mode 100644 dbms/src/Functions/tests/gtest_vector.cpp create mode 100644 dbms/src/TiDB/Decode/Vector.cpp create mode 100644 dbms/src/TiDB/Decode/Vector.h diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index 63bdc4df1df..1ae4c1bc726 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -25,8 +25,13 @@ #include #include #include +#include +#include +#include #include // memcpy +#include + namespace DB { namespace ErrorCodes @@ -798,10 +803,44 @@ void ColumnArray::getPermutation(bool reverse, size_t limit, int nan_direction_h } } -ColumnPtr ColumnArray::replicateRange(size_t /*start_row*/, size_t /*end_row*/, const IColumn::Offsets & /*offsets*/) +ColumnPtr ColumnArray::replicateRange(size_t start_row, size_t end_row, const IColumn::Offsets & replicate_offsets) const { - throw Exception("not implement.", ErrorCodes::NOT_IMPLEMENTED); + size_t col_size = size(); + if (col_size != replicate_offsets.size()) + throw Exception("Size of offsets doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH); + + // We only support replicate to full column. + RUNTIME_CHECK(start_row == 0, start_row); + RUNTIME_CHECK(end_row == replicate_offsets.size(), end_row, replicate_offsets.size()); + + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateConst(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNullable(replicate_offsets); + return replicateGeneric(replicate_offsets); } @@ -1048,4 +1087,50 @@ void ColumnArray::gather(ColumnGathererStream & gatherer) gatherer.gather(*this); } +bool ColumnArray::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool /* force_decode */) +{ + RUNTIME_CHECK(raw_value.size() >= cursor + length); + insertFromDatumData(raw_value.c_str() + cursor, length); + return true; +} + +void ColumnArray::insertFromDatumData(const char * data, size_t length) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + RUNTIME_CHECK(checkAndGetColumn>(&getData())); + RUNTIME_CHECK(getData().isFixedAndContiguous()); + + RUNTIME_CHECK(length >= sizeof(UInt32), length); + auto n = readLittleEndian(data); + data += sizeof(UInt32); + + auto precise_data_size = n * sizeof(Float32); + RUNTIME_CHECK(length >= sizeof(UInt32) + precise_data_size, n, length); + insertData(data, precise_data_size); +} + +size_t ColumnArray::encodeIntoDatumData(size_t element_idx, WriteBuffer & writer) const +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + size_t encoded_size = 0; + + RUNTIME_CHECK(checkAndGetColumn>(&getData())); + RUNTIME_CHECK(getData().isFixedAndContiguous()); + + auto n = static_cast(sizeAt(element_idx)); + + writeIntBinary(n, writer); + encoded_size += sizeof(UInt32); + + auto data = getDataAt(element_idx); + RUNTIME_CHECK(data.size == n * sizeof(Float32)); + writer.write(data.data, data.size); + encoded_size += data.size; + + return encoded_size; +} + + } // namespace DB diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index 564637f3595..91e7e824cdc 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -167,6 +167,15 @@ class ColumnArray final : public COWPtrHelper callback(data); } + bool canBeInsideNullable() const override { return true; } + + bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t /* length */, bool /* force_decode */) + override; + + void insertFromDatumData(const char * data, size_t length) override; + + size_t encodeIntoDatumData(size_t element_idx, WriteBuffer & writer) const; + private: ColumnPtr data; ColumnPtr offsets; diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index ecdfd60d57e..e814bfeb04d 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -204,14 +204,29 @@ void ColumnNullable::get(size_t n, Field & res) const getNestedColumn().get(n, res); } -StringRef ColumnNullable::getDataAt(size_t /*n*/) const +StringRef ColumnNullable::getDataAt(size_t n) const { - throw Exception(fmt::format("Method getDataAt is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED); + if (!isNullAt(n)) + return getNestedColumn().getDataAt(n); + + throw Exception( + ErrorCodes::NOT_IMPLEMENTED, + "Method getDataAt is not supported for {} in case if value is NULL", + getName()); } -void ColumnNullable::insertData(const char * /*pos*/, size_t /*length*/) +void ColumnNullable::insertData(const char * pos, size_t length) { - throw Exception(fmt::format("Method insertData is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED); + if (pos == nullptr) + { + getNestedColumn().insertDefault(); + getNullMapData().push_back(1); + } + else + { + getNestedColumn().insertData(pos, length); + getNullMapData().push_back(0); + } } bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode) @@ -222,6 +237,12 @@ bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_valu return true; } +void ColumnNullable::insertFromDatumData(const char * cursor, size_t len) +{ + getNestedColumn().insertFromDatumData(cursor, len); + getNullMapData().push_back(0); +} + StringRef ColumnNullable::serializeValueIntoArena( size_t n, Arena & arena, diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index 0622ce78a0c..f06b8c30b9b 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ b/dbms/src/Columns/ColumnNullable.h @@ -65,9 +65,11 @@ class ColumnNullable final : public COWPtrHelper Field operator[](size_t n) const override; void get(size_t n, Field & res) const override; UInt64 get64(size_t n) const override { return nested_column->get64(n); } - StringRef getDataAt(size_t n) const override; + StringRef getDataAt(size_t) const override; + /// Will insert null value if pos=nullptr void insertData(const char * pos, size_t length) override; bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode) override; + void insertFromDatumData(const char *, size_t) override; StringRef serializeValueIntoArena( size_t n, Arena & arena, diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 6e5db8abf2c..06a906fdb63 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -173,6 +173,11 @@ class IColumn : public COWPtr throw Exception("Method decodeTiDBRowV2Datum is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + virtual void insertFromDatumData(const char *, size_t) + { + throw Exception("Method insertFromDatumData is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); + } + /// Like getData, but has special behavior for columns that contain variable-length strings. /// In this special case inserting data should be zero-ending (i.e. length is 1 byte greater than real string size). virtual void insertDataWithTerminatingZero(const char * pos, size_t length) { insertData(pos, length); } diff --git a/dbms/src/DataTypes/DataTypeArray.h b/dbms/src/DataTypes/DataTypeArray.h index 4c7572fdc74..e0083bf1e0a 100644 --- a/dbms/src/DataTypes/DataTypeArray.h +++ b/dbms/src/DataTypes/DataTypeArray.h @@ -34,7 +34,7 @@ class DataTypeArray final : public IDataType const char * getFamilyName() const override { return "Array"; } - bool canBeInsideNullable() const override { return false; } + bool canBeInsideNullable() const override { return true; } TypeIndex getTypeId() const override { return TypeIndex::Array; } diff --git a/dbms/src/Debug/MockExecutor/AstToPB.cpp b/dbms/src/Debug/MockExecutor/AstToPB.cpp index d588da3621a..89b25a3fd85 100644 --- a/dbms/src/Debug/MockExecutor/AstToPB.cpp +++ b/dbms/src/Debug/MockExecutor/AstToPB.cpp @@ -110,6 +110,13 @@ void literalFieldToTiPBExpr(const ColumnInfo & ci, const Field & val_field, tipb encodeDAGInt64(val, ss); break; } + case TiDB::TypeTiDBVectorFloat32: + { + expr->set_tp(tipb::ExprType::TiDBVectorFloat32); + const auto & val = val_field.safeGet(); + encodeDAGVectorFloat32(val, ss); + break; + } default: throw Exception(fmt::format( "Type {} does not support literal in function unit test", diff --git a/dbms/src/Debug/dbgTools.cpp b/dbms/src/Debug/dbgTools.cpp index ee1e62bfc98..878b10c0d51 100644 --- a/dbms/src/Debug/dbgTools.cpp +++ b/dbms/src/Debug/dbgTools.cpp @@ -470,6 +470,10 @@ struct BatchCtrl throw Exception( "Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagJson", ErrorCodes::LOGICAL_ERROR); + case TiDB::CodecFlagVectorFloat32: + throw Exception( + "Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagVectorFloat32", + ErrorCodes::LOGICAL_ERROR); case TiDB::CodecFlagMax: throw Exception("Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagMax", ErrorCodes::LOGICAL_ERROR); case TiDB::CodecFlagDuration: diff --git a/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp b/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp index 901ca3ebad0..83a411a6181 100644 --- a/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp +++ b/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include #include #include +#include #include #include #include @@ -292,6 +294,34 @@ void flashStringColToArrowCol( } } +template +void flashArrayFloat32ColToArrowCol( + TiDBColumn & dag_column, + const IColumn * flash_col_untyped, + size_t start_index, + size_t end_index) +{ + // We only unwrap the NULLABLE() part. + const IColumn * nested_col = getNestedCol(flash_col_untyped); + const auto * flash_col = checkAndGetColumn(nested_col); + for (size_t i = start_index; i < end_index; i++) + { + // todo check if we can convert flash_col to DAG col directly since the internal representation is almost the same + if constexpr (is_nullable) + { + if (flash_col_untyped->isNullAt(i)) + { + dag_column.appendNull(); + continue; + } + } + + auto encoded_size = flash_col->encodeIntoDatumData(i, *dag_column.data); + RUNTIME_CHECK(encoded_size > 0); + dag_column.finishAppendVar(encoded_size); + } +} + template void flashBitColToArrowCol( TiDBColumn & dag_column, @@ -461,6 +491,20 @@ void flashColToArrowCol( else flashStringColToArrowCol(dag_column, col, start_index, end_index); break; + case TiDB::TypeTiDBVectorFloat32: + { + const auto * data_type = checkAndGetDataType(type); + if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32) + throw TiFlashException( + "Type un-matched during arrow encode, target col type is array and source column type is " + + type->getName(), + Errors::Coprocessor::Internal); + if (tidb_column_info.hasNotNullFlag()) + flashArrayFloat32ColToArrowCol(dag_column, col, start_index, end_index); + else + flashArrayFloat32ColToArrowCol(dag_column, col, start_index, end_index); + break; + } case TiDB::TypeBit: if (!checkDataType(type)) throw TiFlashException( @@ -525,6 +569,35 @@ const char * arrowStringColToFlashCol( return pos + offsets[length]; } +const char * arrowArrayFloat32ColToFlashCol( + const char * pos, + UInt8, + UInt32 null_count, + const std::vector & null_bitmap, + const std::vector & offsets, + const ColumnWithTypeAndName & col, + const ColumnInfo &, + UInt32 length) +{ + const auto * data_type = checkAndGetDataType(&*col.type); + if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32) + throw TiFlashException( + "Type un-matched during arrow decode, target col type is array and source column type is " + + col.type->getName(), + Errors::Coprocessor::Internal); + + for (UInt32 i = 0; i < length; i++) + { + if (checkNull(i, null_count, null_bitmap, col)) + continue; + + auto arrow_data_size = offsets[i + 1] - offsets[i]; + const auto * base_offset = pos + offsets[i]; + col.column->assumeMutable()->insertFromDatumData(base_offset, arrow_data_size); + } + return pos + offsets[length]; +} + const char * arrowEnumColToFlashCol( const char * pos, UInt8, @@ -819,6 +892,16 @@ const char * arrowColToFlashCol( length); case TiDB::TypeBit: return arrowBitColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length); + case TiDB::TypeTiDBVectorFloat32: + return arrowArrayFloat32ColToFlashCol( + pos, + field_length, + null_count, + null_bitmap, + offsets, + flash_col, + col_info, + length); case TiDB::TypeEnum: return arrowEnumColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length); default: diff --git a/dbms/src/Flash/Coprocessor/DAGCodec.cpp b/dbms/src/Flash/Coprocessor/DAGCodec.cpp index ef8dc4d7c2e..2b3e7ce10ec 100644 --- a/dbms/src/Flash/Coprocessor/DAGCodec.cpp +++ b/dbms/src/Flash/Coprocessor/DAGCodec.cpp @@ -53,6 +53,11 @@ void encodeDAGDecimal(const Field & field, WriteBuffer & ss) EncodeDecimal(field, ss); } +void encodeDAGVectorFloat32(const Array & v, WriteBuffer & ss) +{ + EncodeVectorFloat32(v, ss); +} + Int64 decodeDAGInt64(const String & s) { auto u = *(reinterpret_cast(s.data())); @@ -93,4 +98,10 @@ Field decodeDAGDecimal(const String & s) return DecodeDecimal(cursor, s); } +Field decodeDAGVectorFloat32(const String & s) +{ + size_t cursor = 0; + return DecodeVectorFloat32(cursor, s); +} + } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGCodec.h b/dbms/src/Flash/Coprocessor/DAGCodec.h index 66cf4b83eda..e0fb33b703c 100644 --- a/dbms/src/Flash/Coprocessor/DAGCodec.h +++ b/dbms/src/Flash/Coprocessor/DAGCodec.h @@ -26,6 +26,7 @@ void encodeDAGFloat64(Float64, WriteBuffer &); void encodeDAGString(const String &, WriteBuffer &); void encodeDAGBytes(const String &, WriteBuffer &); void encodeDAGDecimal(const Field &, WriteBuffer &); +void encodeDAGVectorFloat32(const Array &, WriteBuffer &); Int64 decodeDAGInt64(const String &); UInt64 decodeDAGUInt64(const String &); @@ -34,5 +35,6 @@ Float64 decodeDAGFloat64(const String &); String decodeDAGString(const String &); String decodeDAGBytes(const String &); Field decodeDAGDecimal(const String &); +Field decodeDAGVectorFloat32(const String &); } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 17e72774a17..b495a10a0e1 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -1050,6 +1050,12 @@ String DAGExpressionAnalyzer::convertToUInt8(const ExpressionActionsPtr & action auto const_expr_name = getActions(const_expr, actions); return applyFunction("notEquals", {column_name, const_expr_name}, actions, nullptr); } + else if (checkDataTypeArray(org_type.get())) + { + tipb::Expr const_expr = constructZeroVectorFloat32TiExpr(); + auto const_expr_name = getActions(const_expr, actions); + return applyFunction("notEquals", {column_name, const_expr_name}, actions, nullptr); + } throw TiFlashException( fmt::format("Filter on {} is not supported.", org_type->getName()), Errors::Coprocessor::Unimplemented); diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 9cd748e1b3b..4c3e7846bb1 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -21,8 +21,10 @@ #include #include #include +#include #include #include +#include #include @@ -129,6 +131,9 @@ const std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::CastJsonAsDuration, "cast"}, {tipb::ScalarFuncSig::CastJsonAsJson, "cast_json_as_json"}, + {tipb::ScalarFuncSig::CastVectorFloat32AsString, "cast_vector_float32_as_string"}, + {tipb::ScalarFuncSig::CastVectorFloat32AsVectorFloat32, "cast_vector_float32_as_vector_float32"}, + {tipb::ScalarFuncSig::CoalesceInt, "coalesce"}, {tipb::ScalarFuncSig::CoalesceReal, "coalesce"}, {tipb::ScalarFuncSig::CoalesceString, "coalesce"}, @@ -144,6 +149,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::LTTime, "less"}, {tipb::ScalarFuncSig::LTDuration, "less"}, {tipb::ScalarFuncSig::LTJson, "less"}, + {tipb::ScalarFuncSig::LTVectorFloat32, "less"}, {tipb::ScalarFuncSig::LEInt, "lessOrEquals"}, {tipb::ScalarFuncSig::LEReal, "lessOrEquals"}, @@ -152,6 +158,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::LETime, "lessOrEquals"}, {tipb::ScalarFuncSig::LEDuration, "lessOrEquals"}, {tipb::ScalarFuncSig::LEJson, "lessOrEquals"}, + {tipb::ScalarFuncSig::LEVectorFloat32, "lessOrEquals"}, {tipb::ScalarFuncSig::GTInt, "greater"}, {tipb::ScalarFuncSig::GTReal, "greater"}, @@ -160,6 +167,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::GTTime, "greater"}, {tipb::ScalarFuncSig::GTDuration, "greater"}, {tipb::ScalarFuncSig::GTJson, "greater"}, + {tipb::ScalarFuncSig::GTVectorFloat32, "greater"}, {tipb::ScalarFuncSig::GreatestInt, "tidbGreatest"}, {tipb::ScalarFuncSig::GreatestReal, "tidbGreatest"}, @@ -183,6 +191,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::GETime, "greaterOrEquals"}, {tipb::ScalarFuncSig::GEDuration, "greaterOrEquals"}, {tipb::ScalarFuncSig::GEJson, "greaterOrEquals"}, + {tipb::ScalarFuncSig::GEVectorFloat32, "greaterOrEquals"}, {tipb::ScalarFuncSig::EQInt, "equals"}, {tipb::ScalarFuncSig::EQReal, "equals"}, @@ -191,6 +200,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::EQTime, "equals"}, {tipb::ScalarFuncSig::EQDuration, "equals"}, {tipb::ScalarFuncSig::EQJson, "equals"}, + {tipb::ScalarFuncSig::EQVectorFloat32, "equals"}, {tipb::ScalarFuncSig::NEInt, "notEquals"}, {tipb::ScalarFuncSig::NEReal, "notEquals"}, @@ -199,6 +209,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::NETime, "notEquals"}, {tipb::ScalarFuncSig::NEDuration, "notEquals"}, {tipb::ScalarFuncSig::NEJson, "notEquals"}, + {tipb::ScalarFuncSig::NEVectorFloat32, "notEquals"}, //{tipb::ScalarFuncSig::NullEQInt, "cast"}, //{tipb::ScalarFuncSig::NullEQReal, "cast"}, @@ -316,6 +327,7 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::TimeIsNull, "isNull"}, {tipb::ScalarFuncSig::IntIsNull, "isNull"}, {tipb::ScalarFuncSig::JsonIsNull, "isNull"}, + {tipb::ScalarFuncSig::VectorFloat32IsNull, "isNull"}, {tipb::ScalarFuncSig::BitAndSig, "bitAnd"}, {tipb::ScalarFuncSig::BitOrSig, "bitOr"}, @@ -686,6 +698,14 @@ const std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::CharLength, "upper"}, {tipb::ScalarFuncSig::GroupingSig, "grouping"}, + + {tipb::ScalarFuncSig::VecAsTextSig, "vecAsText"}, + {tipb::ScalarFuncSig::VecDimsSig, "vecDims"}, + {tipb::ScalarFuncSig::VecL1DistanceSig, "vecL1Distance"}, + {tipb::ScalarFuncSig::VecL2DistanceSig, "vecL2Distance"}, + {tipb::ScalarFuncSig::VecNegativeInnerProductSig, "vecNegativeInnerProduct"}, + {tipb::ScalarFuncSig::VecCosineDistanceSig, "vecCosineDistance"}, + {tipb::ScalarFuncSig::VecL2NormSig, "vecL2Norm"}, }); template @@ -948,6 +968,24 @@ String exprToString(const tipb::Expr & expr, const std::vector = std::to_string(TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field().get()); return ret; } + case tipb::ExprType::TiDBVectorFloat32: + { + if (!expr.has_field_type()) + throw TiFlashException( + "MySQL Duration literal without field_type" + expr.DebugString(), + Errors::Coprocessor::BadRequest); + auto t = decodeDAGVectorFloat32(expr.val()); + auto arr = t.safeGet(); + String ret = "["; + for (size_t i = 0; i < arr.size(); ++i) + { + if (i > 0) + ret += ","; + ret += std::to_string(arr[i].safeGet::Type>()); + } + ret += "]"; + return ret; + } case tipb::ExprType::ColumnRef: return getColumnNameForColumnExpr(expr, input_col); case tipb::ExprType::Count: @@ -1085,6 +1123,7 @@ bool isLiteralExpr(const tipb::Expr & expr) case tipb::ExprType::MysqlTime: case tipb::ExprType::MysqlJson: case tipb::ExprType::ValueList: + case tipb::ExprType::TiDBVectorFloat32: return true; default: return false; @@ -1134,6 +1173,14 @@ Field decodeLiteral(const tipb::Expr & expr) auto t = decodeDAGInt64(expr.val()); return TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field(); } + case tipb::ExprType::TiDBVectorFloat32: + { + if (!expr.has_field_type()) + throw TiFlashException( + "MySQL Duration literal without field_type" + expr.DebugString(), + Errors::Coprocessor::BadRequest); + return decodeDAGVectorFloat32(expr.val()); + } case tipb::ExprType::MysqlBit: case tipb::ExprType::MysqlEnum: case tipb::ExprType::MysqlHex: @@ -1324,6 +1371,7 @@ UInt8 getFieldLengthForArrowEncode(Int32 tp) case TiDB::TypeBit: case TiDB::TypeEnum: case TiDB::TypeJSON: + case TiDB::TypeTiDBVectorFloat32: return VAR_SIZE; default: throw TiFlashException( @@ -1378,6 +1426,20 @@ tipb::Expr constructNULLLiteralTiExpr() return expr; } +tipb::Expr constructZeroVectorFloat32TiExpr() +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + tipb::Expr expr; + expr.set_tp(tipb::ExprType::TiDBVectorFloat32); + WriteBufferFromOwnString ss; + writeIntBinary(static_cast(0), ss); + expr.set_val(ss.releaseStr()); + auto * field_type = expr.mutable_field_type(); + field_type->set_tp(TiDB::TypeTiDBVectorFloat32); + field_type->set_flag(TiDB::ColumnFlagNotNull); + return expr; +} + TiDB::TiDBCollatorPtr getCollatorFromExpr(const tipb::Expr & expr) { if (expr.has_field_type()) diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.h b/dbms/src/Flash/Coprocessor/DAGUtils.h index 9ab68493ad7..31124b45cf7 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.h +++ b/dbms/src/Flash/Coprocessor/DAGUtils.h @@ -61,6 +61,7 @@ tipb::Expr constructStringLiteralTiExpr(const String & value); tipb::Expr constructInt64LiteralTiExpr(Int64 value); tipb::Expr constructDateTimeLiteralTiExpr(UInt64 packed_value); tipb::Expr constructNULLLiteralTiExpr(); +tipb::Expr constructZeroVectorFloat32TiExpr(); DataTypePtr inferDataType4Literal(const tipb::Expr & expr); SortDescription getSortDescription( const std::vector & order_columns, diff --git a/dbms/src/Flash/Coprocessor/TiDBColumn.h b/dbms/src/Flash/Coprocessor/TiDBColumn.h index bb29a09a954..9e66db40b82 100644 --- a/dbms/src/Flash/Coprocessor/TiDBColumn.h +++ b/dbms/src/Flash/Coprocessor/TiDBColumn.h @@ -27,7 +27,7 @@ namespace DB class TiDBColumn { public: - TiDBColumn(Int8 element_len); + explicit TiDBColumn(Int8 element_len); void appendNull(); void append(Int64 value); @@ -44,10 +44,13 @@ class TiDBColumn void encodeColumn(WriteBuffer & ss); void clear(); -private: - bool isFixed() { return fixed_size != VAR_SIZE; }; + std::unique_ptr data; void finishAppendFixed(); void finishAppendVar(UInt32 size); + +private: + bool isFixed() const { return fixed_size != VAR_SIZE; }; + void appendNullBitMap(bool value); UInt32 length; @@ -55,7 +58,7 @@ class TiDBColumn std::vector null_bitmap; std::vector var_offsets; // WriteBufferFromOwnString is not moveable. - std::unique_ptr data; + std::string default_value; UInt64 current_data_size; Int8 fixed_size; diff --git a/dbms/src/Functions/FunctionHelpers.h b/dbms/src/Functions/FunctionHelpers.h index 77c5f790a79..7b9b89685de 100644 --- a/dbms/src/Functions/FunctionHelpers.h +++ b/dbms/src/Functions/FunctionHelpers.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -42,6 +43,16 @@ bool checkDataType(const IDataType * data_type) return checkAndGetDataType(data_type); } +template +bool checkDataTypeArray(const IDataType * data_type) +{ + const auto * array_type = checkAndGetDataType(data_type); + if unlikely (!array_type) + return false; + + const DataTypePtr & inner_type = array_type->getNestedType(); + return checkDataType(inner_type.get()); +} template const Type * checkAndGetColumn(const IColumn * column) diff --git a/dbms/src/Functions/FunctionsVector.cpp b/dbms/src/Functions/FunctionsVector.cpp new file mode 100644 index 00000000000..795b679cf79 --- /dev/null +++ b/dbms/src/Functions/FunctionsVector.cpp @@ -0,0 +1,44 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace DB +{ + +void registerFunctionsVector(FunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); +} + +} // namespace DB diff --git a/dbms/src/Functions/FunctionsVector.h b/dbms/src/Functions/FunctionsVector.h new file mode 100644 index 00000000000..2e830338952 --- /dev/null +++ b/dbms/src/Functions/FunctionsVector.h @@ -0,0 +1,472 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int ILLEGAL_COLUMN; +} + +class FunctionsCastVectorFloat32AsString : public IFunction +{ +public: + static constexpr auto name = "cast_vector_float32_as_string"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnString::create(); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); + col_result->insert(v.toString()); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsCastVectorFloat32AsVectorFloat32 : public IFunction +{ +public: + static constexpr auto name = "cast_vector_float32_as_vector_float32"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnArray::create(ColumnFloat32::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); // Still construct a VectorFloat32Ref to do sanity checks + UNUSED(v); + col_result->insertData(data.data, data.size); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecAsText : public IFunction +{ +public: + static constexpr auto name = "vecAsText"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnString::create(); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); + col_result->insert(v.toString()); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecDims : public IFunction +{ +public: + static constexpr auto name = "vecDims"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnUInt32::create(); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto data = col_a->getDataAt(i); + auto v = VectorFloat32Ref(data); + col_result->insert(v.size()); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecL1Distance : public IFunction +{ +public: + static constexpr auto name = "vecL1Distance"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.l1Distance(v2); + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecL2Distance : public IFunction +{ +public: + static constexpr auto name = "vecL2Distance"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.l2Distance(v2); + + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecCosineDistance : public IFunction +{ +public: + static constexpr auto name = "vecCosineDistance"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.cosineDistance(v2); + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecNegativeInnerProduct : public IFunction +{ +public: + static constexpr auto name = "vecNegativeInnerProduct"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + // Calculating vectors with different dimensions is disallowed, so that we cannot use the default impl. + bool useDefaultImplementationForNulls() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(removeNullable(arguments[0]).get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if unlikely (!checkDataTypeArray(removeNullable(arguments[1]).get())) + throw Exception( + "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_b = block.safeGetByPosition(arguments[1]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + if (col_a->isNullAt(i) || col_b->isNullAt(i)) + { + col_result->insertDefault(); + continue; + } + + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto v2 = VectorFloat32Ref(col_b->getDataAt(i)); + auto d = v1.innerProduct(v2) * -1; + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +class FunctionsVecL2Norm : public IFunction +{ +public: + static constexpr auto name = "vecL2Norm"; + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 1; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if unlikely (!checkDataTypeArray(arguments[0].get())) + throw Exception( + "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return makeNullable(std::make_shared()); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + auto col_a = block.safeGetByPosition(arguments[0]).column; + auto col_result = ColumnNullable::create(ColumnFloat64::create(), ColumnUInt8::create()); + col_result->reserve(block.rows()); + + for (size_t i = 0; i < block.rows(); ++i) + { + RUNTIME_CHECK(!col_a->isNullAt(i)); + auto v1 = VectorFloat32Ref(col_a->getDataAt(i)); + auto d = v1.l2Norm(); + if (std::isnan(d)) + col_result->insertDefault(); + else + col_result->insert(d); + } + + block.safeGetByPosition(result).column = std::move(col_result); + } +}; + +} // namespace DB diff --git a/dbms/src/Functions/registerFunctions.cpp b/dbms/src/Functions/registerFunctions.cpp index 6957f887805..cb43ae7760d 100644 --- a/dbms/src/Functions/registerFunctions.cpp +++ b/dbms/src/Functions/registerFunctions.cpp @@ -50,6 +50,7 @@ void registerFunctionsRegexpInstr(FunctionFactory &); void registerFunctionsRegexpSubstr(FunctionFactory &); void registerFunctionsRegexpReplace(FunctionFactory &); void registerFunctionsGrouping(FunctionFactory &); +void registerFunctionsVector(FunctionFactory &); void registerFunctions() { @@ -83,6 +84,7 @@ void registerFunctions() registerFunctionsJson(factory); registerFunctionsIsIPAddr(factory); registerFunctionsGrouping(factory); + registerFunctionsVector(factory); } } // namespace DB diff --git a/dbms/src/Functions/tests/gtest_vector.cpp b/dbms/src/Functions/tests/gtest_vector.cpp new file mode 100644 index 00000000000..d67eb683540 --- /dev/null +++ b/dbms/src/Functions/tests/gtest_vector.cpp @@ -0,0 +1,361 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include + +namespace DB +{ +namespace tests +{ +class Vector : public DB::tests::FunctionTest +{ +}; + +TEST_F(Vector, Dims) +try +{ + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn>({0, 2, 3, std::nullopt}), + executeFunction( + "vecDims", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}, Array{1.0, 2.0, 3.0}, std::nullopt}))); + + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn({0, 2, 3}), + executeFunction( + "vecDims", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}, Array{1.0, 2.0, 3.0}}))); + + // Fn(Const) + ASSERT_COLUMN_EQ( + createConstColumn(3, 2), + executeFunction( + "vecDims", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, // + Array{1.0, 2.0}))); +} +CATCH + +TEST_F(Vector, L2Norm) +try +{ + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn>({0.0, 5.0, 1.0}), + executeFunction( + "vecL2Norm", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{3.0, 4.0}, Array{0.0, 1.0}}))); + + // Fn(Column) + ASSERT_COLUMN_EQ( + createColumn>({0.0, 5.0, 1.0, std::nullopt}), + executeFunction( + "vecL2Norm", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{}, Array{3.0, 4.0}, Array{0.0, 1.0}, std::nullopt}))); + + // Fn(Const) + ASSERT_COLUMN_EQ( + createConstColumn>(3, 5.0), + executeFunction( + "vecL2Norm", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, // + Array{3.0, 4.0}))); +} +CATCH + +TEST_F(Vector, L2Distance) +try +{ + // Fn(NullableColumn, Column) + ASSERT_COLUMN_EQ( + createColumn>({5.0, 1.0, INFINITY, std::nullopt}), + executeFunction( + "vecL2Distance", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{0.0, 0.0}, Array{0.0, 0.0}, Array{3e38}, std::nullopt}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{-3e38}, Array{1}}))); + + // Fn(Column, Column) + ASSERT_COLUMN_EQ( + createColumn>({5.0, 1.0, INFINITY}), + executeFunction( + "vecL2Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{0.0, 0.0}, Array{0.0, 0.0}, Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{-3e38}}))); + + ASSERT_THROW( + executeFunction( + "vecL2Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); + + // Fn(Const, Const) + ASSERT_COLUMN_EQ( + createConstColumn>(3, 5.0), + executeFunction( + "vecL2Distance", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{0.0, 0.0}), + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{3.0, 4.0}))); + + // Fn(Const, Column) + ASSERT_COLUMN_EQ( + createColumn>({5.0, 1.0, 1.0}), + executeFunction( + "vecL2Distance", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{0.0, 0.0}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{0.0, 1.0}}))); +} +CATCH + +TEST_F(Vector, NegativeInnerProduct) +try +{ + ASSERT_COLUMN_EQ( + createColumn>({-11.0, -INFINITY}), + executeFunction( + "vecNegativeInnerProduct", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{3e38}}))); + + ASSERT_COLUMN_EQ( + createConstColumn>(3, -11.0), + executeFunction( + "vecNegativeInnerProduct", + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{1.0, 2.0}), + createConstColumn( + std::make_tuple(std::make_shared()), // + 3, + Array{3.0, 4.0}))); + + ASSERT_THROW( + executeFunction( + "vecNegativeInnerProduct", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); +} +CATCH + +TEST_F(Vector, CosineDistance) +try +{ + ASSERT_COLUMN_EQ( + createColumn>({0.0, std::nullopt, 0.0, 1.0, 2.0, 0.0, 2.0, std::nullopt}), + executeFunction( + "vecCosineDistance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, + Array{1.0, 2.0}, + Array{1.0, 1.0}, + Array{1.0, 0.0}, + Array{1.0, 1.0}, + Array{1.0, 1.0}, + Array{1.0, 1.0}, + Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{2.0, 4.0}, + Array{0.0, 0.0}, + Array{1.0, 1.0}, + Array{0.0, 2.0}, + Array{-1.0, -1.0}, + Array{1.1, 1.1}, + Array{-1.1, -1.1}, + Array{3e38}}))); + + ASSERT_THROW( + executeFunction( + "vecCosineDistance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); +} +CATCH + +TEST_F(Vector, L1Distance) +try +{ + ASSERT_COLUMN_EQ( + createColumn>({7.0, 1.0, INFINITY}), + executeFunction( + "vecL1Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{0.0, 0.0}, Array{0.0, 0.0}, Array{3e38}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0, 4.0}, Array{0.0, 1.0}, Array{-3e38}}))); + + ASSERT_THROW( + executeFunction( + "vecL1Distance", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{3.0}})), + Exception); +} +CATCH + +TEST_F(Vector, IsNull) +try +{ + ASSERT_COLUMN_EQ( + createColumn({0, 1}), + executeFunction( + "isNull", + createColumn>( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, std::nullopt}))); +} +CATCH + +TEST_F(Vector, CastAsString) +try +{ + ASSERT_COLUMN_EQ( + createColumn({"[]", "[1,2]"}), + executeFunction( + "cast_vector_float32_as_string", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}}))); +} +CATCH + +TEST_F(Vector, CastAsVector) +try +{ + ASSERT_COLUMN_EQ( + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}}), + executeFunction( + "cast_vector_float32_as_vector_float32", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{}, Array{1.0, 2.0}}))); +} +CATCH + +TEST_F(Vector, Compare) +try +{ + ASSERT_COLUMN_EQ( + createColumn({0, 1, 0, 1, 0, 1}), + executeFunction( + "less", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 1.0}, Array{1.0, 2.0}, Array{}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, + Array{2.0, 4.0}, + Array{0.0, 1.0}, + Array{1.0, 1.0, 1.0}, + Array{0.0, 2.0, 3.0}, + Array{0.0}}))); + + ASSERT_COLUMN_EQ( + createColumn({0, 0, 1, 0, 1, 0}), + executeFunction( + "greater", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 2.0}, Array{1.0, 1.0}, Array{1.0, 2.0}, Array{}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, + Array{2.0, 4.0}, + Array{0.0, 1.0}, + Array{1.0, 1.0, 1.0}, + Array{0.0, 2.0, 3.0}, + Array{0.0}}))); + + // equals + ASSERT_COLUMN_EQ( + createColumn({1, 0, 1, 0}), + executeFunction( + "equals", + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{1.0, 2.0}, Array{}, Array{}}), + createColumn( + std::make_tuple(std::make_shared()), // + {Array{1.0, 2.0}, Array{2.0, 4.0}, Array{}, Array{1.0, 1.0, 1.0}}))); +} +CATCH + +} // namespace tests +} // namespace DB diff --git a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp index 28a86e9e6c5..646954ae78b 100644 --- a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp +++ b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp @@ -65,6 +65,7 @@ inline bool isRoughSetFilterSupportType(const Int32 field_type) case TiDB::TypeString: return false; // Unknown. + case TiDB::TypeTiDBVectorFloat32: case TiDB::TypeDecimal: case TiDB::TypeNewDecimal: case TiDB::TypeFloat: diff --git a/dbms/src/TiDB/Decode/DatumCodec.cpp b/dbms/src/TiDB/Decode/DatumCodec.cpp index ccfd422aabb..7151a98b1f7 100644 --- a/dbms/src/TiDB/Decode/DatumCodec.cpp +++ b/dbms/src/TiDB/Decode/DatumCodec.cpp @@ -13,10 +13,13 @@ // limitations under the License. #include +#include #include +#include #include #include #include +#include #include namespace DB @@ -342,6 +345,44 @@ Field DecodeDatumForCHRow(size_t & cursor, const String & raw_value, const TiDB: } } +void EncodeVectorFloat32(const Array & val, WriteBuffer & ss) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + writeIntBinary(static_cast(val.size()), ss); + for (const auto & s : val) + writeFloatBinary(static_cast(s.safeGet::Type>()), ss); +} + +void SkipVectorFloat32(size_t & cursor, const String & raw_value) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + auto elements_n = readLittleEndian(&raw_value[cursor]); + auto size = sizeof(elements_n) + elements_n * sizeof(Float32); + cursor += size; +} + +Field DecodeVectorFloat32(size_t & cursor, const String & raw_value) +{ + RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); + + auto n = readLittleEndian(&raw_value[cursor]); + cursor += sizeof(UInt32); + + Array res; + res.reserve(n); + + for (size_t i = 0; i < n; i++) + { + auto v = readLittleEndian(&raw_value[cursor]); + res.emplace_back(static_cast(v)); + cursor += sizeof(Float32); + } + + return res; +} + Field DecodeDatum(size_t & cursor, const String & raw_value) { switch (raw_value[cursor++]) @@ -368,6 +409,8 @@ Field DecodeDatum(size_t & cursor, const String & raw_value) return DecodeDecimal(cursor, raw_value); case TiDB::CodecFlagJson: return JsonBinary::DecodeJsonAsBinary(cursor, raw_value); + case TiDB::CodecFlagVectorFloat32: + return DecodeVectorFloat32(cursor, raw_value); default: throw Exception("Unknown Type:" + std::to_string(raw_value[cursor - 1]), ErrorCodes::LOGICAL_ERROR); } @@ -409,6 +452,9 @@ void SkipDatum(size_t & cursor, const String & raw_value) case TiDB::CodecFlagJson: JsonBinary::SkipJson(cursor, raw_value); return; + case TiDB::CodecFlagVectorFloat32: + SkipVectorFloat32(cursor, raw_value); + return; default: throw Exception("Unknown Type:" + std::to_string(raw_value[cursor - 1]), ErrorCodes::LOGICAL_ERROR); } @@ -662,6 +708,8 @@ void EncodeDatum(const Field & field, TiDB::CodecFlag flag, WriteBuffer & ss) return EncodeInt64(field.safeGet(), ss); case TiDB::CodecFlagJson: return EncodeJSON(field.safeGet(), ss); + case TiDB::CodecFlagVectorFloat32: + return EncodeVectorFloat32(field.safeGet(), ss); case TiDB::CodecFlagNil: return; default: diff --git a/dbms/src/TiDB/Decode/DatumCodec.h b/dbms/src/TiDB/Decode/DatumCodec.h index 2fafb6af490..9dcea6b5c89 100644 --- a/dbms/src/TiDB/Decode/DatumCodec.h +++ b/dbms/src/TiDB/Decode/DatumCodec.h @@ -59,6 +59,8 @@ UInt64 DecodeVarUInt(size_t & cursor, const StringRef & raw_value); Int64 DecodeVarInt(size_t & cursor, const String & raw_value); +Field DecodeVectorFloat32(size_t & cursor, const String & raw_value); + Field DecodeDecimal(size_t & cursor, const String & raw_value); Field DecodeDecimalForCHRow(size_t & cursor, const String & raw_value, const TiDB::ColumnInfo & column_info); @@ -89,6 +91,8 @@ void EncodeCompactBytes(const String & str, WriteBuffer & ss); void EncodeJSON(const String & str, WriteBuffer & ss); +void EncodeVectorFloat32(const Array & val, WriteBuffer & ss); + void EncodeVarUInt(UInt64 num, WriteBuffer & ss); void EncodeVarInt(Int64 num, WriteBuffer & ss); diff --git a/dbms/src/TiDB/Decode/RowCodec.cpp b/dbms/src/TiDB/Decode/RowCodec.cpp index 3b15785b72d..7bea5042843 100644 --- a/dbms/src/TiDB/Decode/RowCodec.cpp +++ b/dbms/src/TiDB/Decode/RowCodec.cpp @@ -123,6 +123,9 @@ TiKVValue::Base encodeNotNullColumn(const Field & field, const ColumnInfo & colu case TiDB::TypeLongBlob: case TiDB::TypeJSON: return field.safeGet(); + case TiDB::TypeTiDBVectorFloat32: + // unsupported, only used in tests. + throw Exception("unsupported encode TiDBVectorFloat32"); case TiDB::TypeNewDecimal: EncodeDecimalForRow(field, ss, column_info); break; diff --git a/dbms/src/TiDB/Decode/TypeMapping.cpp b/dbms/src/TiDB/Decode/TypeMapping.cpp index 442ae15e7ac..2b62bf47275 100644 --- a/dbms/src/TiDB/Decode/TypeMapping.cpp +++ b/dbms/src/TiDB/Decode/TypeMapping.cpp @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include +#include #include #include #include @@ -34,6 +36,8 @@ #include #include +#include +#include #include namespace DB @@ -104,9 +108,21 @@ struct EnumType : public std::true_type template inline constexpr bool IsEnumType = EnumType::value; +template +struct ArrayType : public std::false_type +{ +}; +template <> +struct ArrayType : public std::true_type +{ +}; +template +inline constexpr bool IsArrayType = ArrayType::value; + template std::enable_if_t< - !IsSignedType && !IsDecimalType && !IsEnumType && !std::is_same_v, + !IsSignedType && !IsDecimalType && !IsEnumType && !std::is_same_v + && !IsArrayType, DataTypePtr> // getDataTypeByColumnInfoBase(const ColumnInfo &, const T *) { @@ -132,6 +148,13 @@ std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(cons return createDecimal(column_info.flen, column_info.decimal); } +template +std::enable_if_t, DataTypePtr> getDataTypeByColumnInfoBase(const ColumnInfo & column_info, const T *) +{ + RUNTIME_CHECK(column_info.tp == TiDB::TypeTiDBVectorFloat32, magic_enum::enum_name(column_info.tp)); + const auto nested_type = std::make_shared(); + return std::make_shared(nested_type); +} template std::enable_if_t, DataTypePtr> // @@ -425,6 +448,9 @@ ColumnInfo reverseGetColumnInfo(const NameAndTypePair & column, ColumnID id, con case TypeIndex::Enum16: column_info.tp = TiDB::TypeEnum; break; + case TypeIndex::Array: + column_info.tp = TiDB::TypeTiDBVectorFloat32; + break; default: throw DB::Exception( "Unable reverse map TiFlash type " + nested_type->getName() + " to TiDB type", diff --git a/dbms/src/TiDB/Decode/Vector.cpp b/dbms/src/TiDB/Decode/Vector.cpp new file mode 100644 index 00000000000..19151c78143 --- /dev/null +++ b/dbms/src/TiDB/Decode/Vector.cpp @@ -0,0 +1,198 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} // namespace ErrorCodes + +VectorFloat32Ref::VectorFloat32Ref(const Float32 * elements, size_t n) + : elements(elements) + , elements_n(n) +{ + for (size_t i = 0; i < n; ++i) + { + if (std::isnan(elements[i])) + throw Exception("NaN not allowed in vector", ErrorCodes::BAD_ARGUMENTS); + if (std::isinf(elements[i])) + throw Exception("infinite value not allowed in vector", ErrorCodes::BAD_ARGUMENTS); + } +} + +void VectorFloat32Ref::checkDims(VectorFloat32Ref b) const +{ + if (size() != b.size()) + throw Exception( + fmt::format("vectors have different dimensions: {} and {}", size(), b.size()), + ErrorCodes::BAD_ARGUMENTS); +} + +Float64 VectorFloat32Ref::l2SquaredDistance(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + Float32 diff; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + diff = elements[i] - b[i]; + distance += diff * diff; + } + + return distance; +} + +Float64 VectorFloat32Ref::innerProduct(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + distance += elements[i] * b[i]; + } + + return distance; +} + +Float64 VectorFloat32Ref::cosineDistance(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + Float32 norma = 0.0; + Float32 normb = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + distance += elements[i] * b[i]; + norma += elements[i] * elements[i]; + normb += b[i] * b[i]; + } + + Float64 similarity + = static_cast(distance) / std::sqrt(static_cast(norma) * static_cast(normb)); + + if (std::isnan(similarity)) + { + // Divide by zero + return std::nan(""); + } + + if (similarity > 1.0) + { + similarity = 1.0; + } + else if (similarity < -1.0) + { + similarity = -1.0; + } + + return 1.0 - similarity; +} + +Float64 VectorFloat32Ref::l1Distance(VectorFloat32Ref b) const +{ + checkDims(b); + + Float32 distance = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + Float32 diff = std::abs(elements[i] - b[i]); + distance += diff; + } + + return distance; +} + +Float64 VectorFloat32Ref::l2Norm() const +{ + // Note: We align the impl with pgvector: Only l2_norm use double + // precision during calculation. + + Float64 norm = 0.0; + + for (size_t i = 0, i_max = size(); i < i_max; ++i) + { + // Hope this can be vectorized. + norm += static_cast(elements[i]) * static_cast(elements[i]); + } + + return std::sqrt(norm); +} + +std::strong_ordering VectorFloat32Ref::operator<=>(const VectorFloat32Ref & b) const +{ + auto la = size(); + auto lb = b.size(); + auto common_len = std::min(la, lb); + + const auto * va = elements; + const auto * vb = b.elements; + + for (size_t i = 0; i < common_len; i++) + { + if (va[i] < vb[i]) + return std::strong_ordering::less; + else if (va[i] > vb[i]) + return std::strong_ordering::greater; + } + if (la < lb) + return std::strong_ordering::less; + else if (la > lb) + return std::strong_ordering::greater; + else + return std::strong_ordering::equal; +} + +String VectorFloat32Ref::toString() const +{ + WriteBufferFromOwnString write_buffer; + toStringInBuffer(write_buffer); + write_buffer.finalize(); + return write_buffer.releaseStr(); +} + +void VectorFloat32Ref::toStringInBuffer(WriteBuffer & write_buffer) const +{ + write_buffer.write('['); + for (size_t i = 0; i < elements_n; i++) + { + if (i > 0) + { + write_buffer.write(','); + } + writeFloatText(elements[i], write_buffer); + } + write_buffer.write(']'); +} + +} // namespace DB diff --git a/dbms/src/TiDB/Decode/Vector.h b/dbms/src/TiDB/Decode/Vector.h new file mode 100644 index 00000000000..6ed4578fa60 --- /dev/null +++ b/dbms/src/TiDB/Decode/Vector.h @@ -0,0 +1,68 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +namespace DB +{ + +class VectorFloat32Ref +{ +public: + explicit VectorFloat32Ref(const Float32 * elements, size_t n); + + explicit VectorFloat32Ref(const StringRef & data) + : VectorFloat32Ref(reinterpret_cast(data.data), data.size / sizeof(Float32)) + {} + + size_t size() const { return elements_n; } + + bool empty() const { return size() == 0; } + + const Float32 & operator[](size_t n) const { return elements[n]; } + + void checkDims(VectorFloat32Ref b) const; + + Float64 l2SquaredDistance(VectorFloat32Ref b) const; + + Float64 l2Distance(VectorFloat32Ref b) const { return std::sqrt(l2SquaredDistance(b)); } + + Float64 innerProduct(VectorFloat32Ref b) const; + + Float64 negativeInnerProduct(VectorFloat32Ref b) const { return innerProduct(b) * -1; } + + Float64 cosineDistance(VectorFloat32Ref b) const; + + Float64 l1Distance(VectorFloat32Ref b) const; + + Float64 l2Norm() const; + + std::strong_ordering operator<=>(const VectorFloat32Ref & b) const; + + String toString() const; + + void toStringInBuffer(WriteBuffer & write_buffer) const; + +private: + const Float32 * elements; + const size_t elements_n; +}; + +} // namespace DB diff --git a/dbms/src/TiDB/Schema/TiDB.cpp b/dbms/src/TiDB/Schema/TiDB.cpp index df2baf53995..c85ec2f45a5 100644 --- a/dbms/src/TiDB/Schema/TiDB.cpp +++ b/dbms/src/TiDB/Schema/TiDB.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -76,6 +77,8 @@ Field GenDefaultField(const TiDB::ColumnInfo & col_info) return Field(static_cast(0)); case TiDB::CodecFlagJson: return TiDB::genJsonNull(); + case TiDB::CodecFlagVectorFloat32: + return Field(Array(0)); case TiDB::CodecFlagDuration: return Field(static_cast(0)); default: diff --git a/dbms/src/TiDB/Schema/TiDB.h b/dbms/src/TiDB/Schema/TiDB.h index 156625494fe..9dc16e769b0 100644 --- a/dbms/src/TiDB/Schema/TiDB.h +++ b/dbms/src/TiDB/Schema/TiDB.h @@ -87,7 +87,8 @@ using DB::Timestamp; M(Blob, 0xfc, CompactBytes, String) \ M(VarString, 0xfd, CompactBytes, String) \ M(String, 0xfe, CompactBytes, String) \ - M(Geometry, 0xff, CompactBytes, String) + M(Geometry, 0xff, CompactBytes, String) \ + M(TiDBVectorFloat32, 0xe1, VectorFloat32, Array) enum TP { @@ -139,18 +140,19 @@ enum ColumnFlag #ifdef M #error "Please undefine macro M first." #endif -#define CODEC_FLAGS(M) \ - M(Nil, 0) \ - M(Bytes, 1) \ - M(CompactBytes, 2) \ - M(Int, 3) \ - M(UInt, 4) \ - M(Float, 5) \ - M(Decimal, 6) \ - M(Duration, 7) \ - M(VarInt, 8) \ - M(VarUInt, 9) \ - M(Json, 10) \ +#define CODEC_FLAGS(M) \ + M(Nil, 0) \ + M(Bytes, 1) \ + M(CompactBytes, 2) \ + M(Int, 3) \ + M(UInt, 4) \ + M(Float, 5) \ + M(Decimal, 6) \ + M(Duration, 7) \ + M(VarInt, 8) \ + M(VarUInt, 9) \ + M(Json, 10) \ + M(VectorFloat32, 20) \ M(Max, 250) enum CodecFlag From b4633dc186a36fb6be50ee499e21e29b51be6d8b Mon Sep 17 00:00:00 2001 From: Lloyd-Pottiger Date: Mon, 29 Jul 2024 13:34:46 +0800 Subject: [PATCH 2/7] format Signed-off-by: Lloyd-Pottiger --- dbms/src/DataTypes/DataTypeArray.h | 2 +- dbms/src/Flash/Coprocessor/TiDBColumn.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dbms/src/DataTypes/DataTypeArray.h b/dbms/src/DataTypes/DataTypeArray.h index e0083bf1e0a..bd71e9233f3 100644 --- a/dbms/src/DataTypes/DataTypeArray.h +++ b/dbms/src/DataTypes/DataTypeArray.h @@ -98,7 +98,7 @@ class DataTypeArray final : public IDataType bool haveSubtypes() const override { return true; } bool cannotBeStoredInTables() const override { return nested->cannotBeStoredInTables(); } bool textCanContainOnlyValidUTF8() const override { return nested->textCanContainOnlyValidUTF8(); } - bool isComparable() const override { return nested->isComparable(); }; + bool isComparable() const override { return nested->isComparable(); } bool canBeComparedWithCollation() const override { return nested->canBeComparedWithCollation(); } bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override diff --git a/dbms/src/Flash/Coprocessor/TiDBColumn.h b/dbms/src/Flash/Coprocessor/TiDBColumn.h index 9e66db40b82..dc4eea99212 100644 --- a/dbms/src/Flash/Coprocessor/TiDBColumn.h +++ b/dbms/src/Flash/Coprocessor/TiDBColumn.h @@ -49,7 +49,7 @@ class TiDBColumn void finishAppendVar(UInt32 size); private: - bool isFixed() const { return fixed_size != VAR_SIZE; }; + bool isFixed() const { return fixed_size != VAR_SIZE; } void appendNullBitMap(bool value); From 4492abbc60d1121439feb08aa3e9c5dda49df62e Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Wed, 31 Jul 2024 15:42:22 +0800 Subject: [PATCH 3/7] Add some comments --- dbms/src/Columns/ColumnArray.cpp | 4 +--- dbms/src/Flash/Coprocessor/ArrowColCodec.cpp | 12 ++++++------ dbms/src/Flash/Coprocessor/TiDBColumn.h | 8 ++++---- dbms/src/TiDB/Schema/TiDB.h | 1 + 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index 1ae4c1bc726..c81d1787cc5 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -1114,15 +1114,13 @@ size_t ColumnArray::encodeIntoDatumData(size_t element_idx, WriteBuffer & writer { RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little); - size_t encoded_size = 0; - RUNTIME_CHECK(checkAndGetColumn>(&getData())); RUNTIME_CHECK(getData().isFixedAndContiguous()); auto n = static_cast(sizeAt(element_idx)); writeIntBinary(n, writer); - encoded_size += sizeof(UInt32); + size_t encoded_size = sizeof(UInt32); auto data = getDataAt(element_idx); RUNTIME_CHECK(data.size == n * sizeof(Float32)); diff --git a/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp b/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp index 83a411a6181..750ec173814 100644 --- a/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp +++ b/dbms/src/Flash/Coprocessor/ArrowColCodec.cpp @@ -496,9 +496,9 @@ void flashColToArrowCol( const auto * data_type = checkAndGetDataType(type); if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32) throw TiFlashException( - "Type un-matched during arrow encode, target col type is array and source column type is " - + type->getName(), - Errors::Coprocessor::Internal); + Errors::Coprocessor::Internal, + "Type un-matched during arrow encode, target col type is array and source column type is {}", + type->getName()); if (tidb_column_info.hasNotNullFlag()) flashArrayFloat32ColToArrowCol(dag_column, col, start_index, end_index); else @@ -582,9 +582,9 @@ const char * arrowArrayFloat32ColToFlashCol( const auto * data_type = checkAndGetDataType(&*col.type); if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32) throw TiFlashException( - "Type un-matched during arrow decode, target col type is array and source column type is " - + col.type->getName(), - Errors::Coprocessor::Internal); + Errors::Coprocessor::Internal, + "Type un-matched during arrow decode, target col type is array and source column type is {}", + col.type->getName()); for (UInt32 i = 0; i < length; i++) { diff --git a/dbms/src/Flash/Coprocessor/TiDBColumn.h b/dbms/src/Flash/Coprocessor/TiDBColumn.h index dc4eea99212..c5c36b75489 100644 --- a/dbms/src/Flash/Coprocessor/TiDBColumn.h +++ b/dbms/src/Flash/Coprocessor/TiDBColumn.h @@ -44,11 +44,12 @@ class TiDBColumn void encodeColumn(WriteBuffer & ss); void clear(); - std::unique_ptr data; - void finishAppendFixed(); + // FIXME: expose for ColumnArray::encodeIntoDatumData, find a better way to implement it. void finishAppendVar(UInt32 size); - + // WriteBufferFromOwnString is not moveable. + std::unique_ptr data; private: + void finishAppendFixed(); bool isFixed() const { return fixed_size != VAR_SIZE; } void appendNullBitMap(bool value); @@ -57,7 +58,6 @@ class TiDBColumn UInt32 null_cnt; std::vector null_bitmap; std::vector var_offsets; - // WriteBufferFromOwnString is not moveable. std::string default_value; UInt64 current_data_size; diff --git a/dbms/src/TiDB/Schema/TiDB.h b/dbms/src/TiDB/Schema/TiDB.h index 9dc16e769b0..4d1f033564f 100644 --- a/dbms/src/TiDB/Schema/TiDB.h +++ b/dbms/src/TiDB/Schema/TiDB.h @@ -155,6 +155,7 @@ enum ColumnFlag M(VectorFloat32, 20) \ M(Max, 250) +// Defined in pingcap/tidb pkg/util/codec/codec.go enum CodecFlag { #ifdef M From ab8149a6226c17ab04e44539fdd4a1c5f3705d17 Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Wed, 31 Jul 2024 16:46:05 +0800 Subject: [PATCH 4/7] Format files --- dbms/src/Flash/Coprocessor/TiDBColumn.h | 1 + dbms/src/TiDB/Schema/TiDB.h | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dbms/src/Flash/Coprocessor/TiDBColumn.h b/dbms/src/Flash/Coprocessor/TiDBColumn.h index c5c36b75489..2095daa9f6a 100644 --- a/dbms/src/Flash/Coprocessor/TiDBColumn.h +++ b/dbms/src/Flash/Coprocessor/TiDBColumn.h @@ -48,6 +48,7 @@ class TiDBColumn void finishAppendVar(UInt32 size); // WriteBufferFromOwnString is not moveable. std::unique_ptr data; + private: void finishAppendFixed(); bool isFixed() const { return fixed_size != VAR_SIZE; } diff --git a/dbms/src/TiDB/Schema/TiDB.h b/dbms/src/TiDB/Schema/TiDB.h index 4d1f033564f..84efebea820 100644 --- a/dbms/src/TiDB/Schema/TiDB.h +++ b/dbms/src/TiDB/Schema/TiDB.h @@ -137,6 +137,7 @@ enum ColumnFlag // Codec flags. // In format: TiDB codec flag, int value. +// Defined in pingcap/tidb pkg/util/codec/codec.go #ifdef M #error "Please undefine macro M first." #endif @@ -155,7 +156,6 @@ enum ColumnFlag M(VectorFloat32, 20) \ M(Max, 250) -// Defined in pingcap/tidb pkg/util/codec/codec.go enum CodecFlag { #ifdef M From eac72333ffdf1af687f8fe7deabcca93b68373f1 Mon Sep 17 00:00:00 2001 From: Lloyd-Pottiger Date: Wed, 31 Jul 2024 17:09:42 +0800 Subject: [PATCH 5/7] address comments Signed-off-by: Lloyd-Pottiger --- dbms/src/Columns/ColumnNullable.cpp | 2 +- dbms/src/Flash/Coprocessor/TiDBColumn.h | 1 + dbms/src/TiDB/Decode/Vector.cpp | 30 +++++++------------------ 3 files changed, 10 insertions(+), 23 deletions(-) diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index e814bfeb04d..a18ff7bcaf1 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -206,7 +206,7 @@ void ColumnNullable::get(size_t n, Field & res) const StringRef ColumnNullable::getDataAt(size_t n) const { - if (!isNullAt(n)) + if (likely(!isNullAt(n))) return getNestedColumn().getDataAt(n); throw Exception( diff --git a/dbms/src/Flash/Coprocessor/TiDBColumn.h b/dbms/src/Flash/Coprocessor/TiDBColumn.h index c5c36b75489..2095daa9f6a 100644 --- a/dbms/src/Flash/Coprocessor/TiDBColumn.h +++ b/dbms/src/Flash/Coprocessor/TiDBColumn.h @@ -48,6 +48,7 @@ class TiDBColumn void finishAppendVar(UInt32 size); // WriteBufferFromOwnString is not moveable. std::unique_ptr data; + private: void finishAppendFixed(); bool isFixed() const { return fixed_size != VAR_SIZE; } diff --git a/dbms/src/TiDB/Decode/Vector.cpp b/dbms/src/TiDB/Decode/Vector.cpp index 19151c78143..6a11c5a0737 100644 --- a/dbms/src/TiDB/Decode/Vector.cpp +++ b/dbms/src/TiDB/Decode/Vector.cpp @@ -33,19 +33,17 @@ VectorFloat32Ref::VectorFloat32Ref(const Float32 * elements, size_t n) { for (size_t i = 0; i < n; ++i) { - if (std::isnan(elements[i])) + if (unlikely(std::isnan(elements[i]))) throw Exception("NaN not allowed in vector", ErrorCodes::BAD_ARGUMENTS); - if (std::isinf(elements[i])) + if (unlikely(std::isinf(elements[i]))) throw Exception("infinite value not allowed in vector", ErrorCodes::BAD_ARGUMENTS); } } void VectorFloat32Ref::checkDims(VectorFloat32Ref b) const { - if (size() != b.size()) - throw Exception( - fmt::format("vectors have different dimensions: {} and {}", size(), b.size()), - ErrorCodes::BAD_ARGUMENTS); + if (unlikely(size() != b.size())) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "vectors have different dimensions: {} and {}", size(), b.size()); } Float64 VectorFloat32Ref::l2SquaredDistance(VectorFloat32Ref b) const @@ -101,19 +99,12 @@ Float64 VectorFloat32Ref::cosineDistance(VectorFloat32Ref b) const if (std::isnan(similarity)) { - // Divide by zero + // When norma or normb is zero, distance is zero, and similarity is NaN. + // similarity can not be Inf in this case. return std::nan(""); } - if (similarity > 1.0) - { - similarity = 1.0; - } - else if (similarity < -1.0) - { - similarity = -1.0; - } - + similarity = std::clamp(similarity, -1.0, 1.0); return 1.0 - similarity; } @@ -165,12 +156,7 @@ std::strong_ordering VectorFloat32Ref::operator<=>(const VectorFloat32Ref & b) c else if (va[i] > vb[i]) return std::strong_ordering::greater; } - if (la < lb) - return std::strong_ordering::less; - else if (la > lb) - return std::strong_ordering::greater; - else - return std::strong_ordering::equal; + return la <=> lb; } String VectorFloat32Ref::toString() const From feb28b7d99bef70787a14b4f6ca8647f04c30440 Mon Sep 17 00:00:00 2001 From: Lloyd-Pottiger Date: Wed, 31 Jul 2024 18:45:58 +0800 Subject: [PATCH 6/7] fix ft Signed-off-by: Lloyd-Pottiger --- tests/docker/util.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/docker/util.sh b/tests/docker/util.sh index ed33cdedd4c..42698580e25 100644 --- a/tests/docker/util.sh +++ b/tests/docker/util.sh @@ -94,9 +94,9 @@ function set_branch() { # XYZ_BRANCH: pd/tikv/tidb hash, default to `master` # BRANCH: hash short cut, default to `master` if [ -n "$BRANCH" ]; then - [ -z "$PD_BRANCH" ] && export PD_BRANCH="$BRANCH" - [ -z "$TIKV_BRANCH" ] && export TIKV_BRANCH="$BRANCH" - [ -z "$TIDB_BRANCH" ] && export TIDB_BRANCH="$BRANCH" + [ -z "$PD_BRANCH" ] && export PD_BRANCH="master" + [ -z "$TIKV_BRANCH" ] && export TIKV_BRANCH="master" + [ -z "$TIDB_BRANCH" ] && export TIDB_BRANCH="master" fi echo "use branch \`${BRANCH-master}\` for ci test" } From 5fb3da8a61978ed8b548486b9343b99426ad076f Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Wed, 31 Jul 2024 22:01:49 +0800 Subject: [PATCH 7/7] Revert "fix ft" This reverts commit feb28b7d99bef70787a14b4f6ca8647f04c30440. --- tests/docker/util.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/docker/util.sh b/tests/docker/util.sh index 42698580e25..ed33cdedd4c 100644 --- a/tests/docker/util.sh +++ b/tests/docker/util.sh @@ -94,9 +94,9 @@ function set_branch() { # XYZ_BRANCH: pd/tikv/tidb hash, default to `master` # BRANCH: hash short cut, default to `master` if [ -n "$BRANCH" ]; then - [ -z "$PD_BRANCH" ] && export PD_BRANCH="master" - [ -z "$TIKV_BRANCH" ] && export TIKV_BRANCH="master" - [ -z "$TIDB_BRANCH" ] && export TIDB_BRANCH="master" + [ -z "$PD_BRANCH" ] && export PD_BRANCH="$BRANCH" + [ -z "$TIKV_BRANCH" ] && export TIKV_BRANCH="$BRANCH" + [ -z "$TIDB_BRANCH" ] && export TIDB_BRANCH="$BRANCH" fi echo "use branch \`${BRANCH-master}\` for ci test" }