Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compute: Add Vector data type and functions #9262

Merged
merged 9 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions dbms/src/Columns/ColumnArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@
#include <Common/SipHash.h>
#include <Common/typeid_cast.h>
#include <DataStreams/ColumnGathererStream.h>
#include <Functions/FunctionHelpers.h>
#include <IO/Endian.h>
#include <IO/WriteHelpers.h>
#include <string.h> // memcpy

#include <memory>

namespace DB
{
namespace ErrorCodes
Expand Down Expand Up @@ -798,10 +803,44 @@ void ColumnArray::getPermutation(bool reverse, size_t limit, int nan_direction_h
}
}

ColumnPtr ColumnArray::replicateRange(size_t /*start_row*/, size_t /*end_row*/, const IColumn::Offsets & /*offsets*/)
ColumnPtr ColumnArray::replicateRange(size_t start_row, size_t end_row, const IColumn::Offsets & replicate_offsets)
const
{
throw Exception("not implement.", ErrorCodes::NOT_IMPLEMENTED);
size_t col_size = size();
if (col_size != replicate_offsets.size())
throw Exception("Size of offsets doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

// We only support replicate to full column.
RUNTIME_CHECK(start_row == 0, start_row);
RUNTIME_CHECK(end_row == replicate_offsets.size(), end_row, replicate_offsets.size());

if (typeid_cast<const ColumnUInt8 *>(data.get()))
return replicateNumber<UInt8>(replicate_offsets);
if (typeid_cast<const ColumnUInt16 *>(data.get()))
return replicateNumber<UInt16>(replicate_offsets);
if (typeid_cast<const ColumnUInt32 *>(data.get()))
return replicateNumber<UInt32>(replicate_offsets);
if (typeid_cast<const ColumnUInt64 *>(data.get()))
return replicateNumber<UInt64>(replicate_offsets);
if (typeid_cast<const ColumnUInt128 *>(data.get()))
return replicateNumber<UInt128>(replicate_offsets);
if (typeid_cast<const ColumnInt8 *>(data.get()))
return replicateNumber<Int8>(replicate_offsets);
if (typeid_cast<const ColumnInt16 *>(data.get()))
return replicateNumber<Int16>(replicate_offsets);
if (typeid_cast<const ColumnInt32 *>(data.get()))
return replicateNumber<Int32>(replicate_offsets);
if (typeid_cast<const ColumnInt64 *>(data.get()))
return replicateNumber<Int64>(replicate_offsets);
if (typeid_cast<const ColumnFloat32 *>(data.get()))
return replicateNumber<Float32>(replicate_offsets);
if (typeid_cast<const ColumnFloat64 *>(data.get()))
return replicateNumber<Float64>(replicate_offsets);
if (typeid_cast<const ColumnConst *>(data.get()))
return replicateConst(replicate_offsets);
if (typeid_cast<const ColumnNullable *>(data.get()))
return replicateNullable(replicate_offsets);
return replicateGeneric(replicate_offsets);
}


Expand Down Expand Up @@ -1048,4 +1087,48 @@ void ColumnArray::gather(ColumnGathererStream & gatherer)
gatherer.gather(*this);
}

bool ColumnArray::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool /* force_decode */)
{
RUNTIME_CHECK(raw_value.size() >= cursor + length);
insertFromDatumData(raw_value.c_str() + cursor, length);
return true;
}

void ColumnArray::insertFromDatumData(const char * data, size_t length)
{
RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little);

RUNTIME_CHECK(checkAndGetColumn<ColumnVector<Float32>>(&getData()));
RUNTIME_CHECK(getData().isFixedAndContiguous());

RUNTIME_CHECK(length >= sizeof(UInt32), length);
auto n = readLittleEndian<UInt32>(data);
data += sizeof(UInt32);

auto precise_data_size = n * sizeof(Float32);
RUNTIME_CHECK(length >= sizeof(UInt32) + precise_data_size, n, length);
insertData(data, precise_data_size);
}

size_t ColumnArray::encodeIntoDatumData(size_t element_idx, WriteBuffer & writer) const
{
RUNTIME_CHECK(boost::endian::order::native == boost::endian::order::little);

RUNTIME_CHECK(checkAndGetColumn<ColumnVector<Float32>>(&getData()));
RUNTIME_CHECK(getData().isFixedAndContiguous());

auto n = static_cast<UInt32>(sizeAt(element_idx));

writeIntBinary(n, writer);
size_t encoded_size = sizeof(UInt32);

auto data = getDataAt(element_idx);
RUNTIME_CHECK(data.size == n * sizeof(Float32));
writer.write(data.data, data.size);
encoded_size += data.size;

return encoded_size;
}


} // namespace DB
9 changes: 9 additions & 0 deletions dbms/src/Columns/ColumnArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ class ColumnArray final : public COWPtrHelper<IColumn, ColumnArray>
callback(data);
}

bool canBeInsideNullable() const override { return true; }

bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t /* length */, bool /* force_decode */)
override;

void insertFromDatumData(const char * data, size_t length) override;

size_t encodeIntoDatumData(size_t element_idx, WriteBuffer & writer) const;

private:
ColumnPtr data;
ColumnPtr offsets;
Expand Down
29 changes: 25 additions & 4 deletions dbms/src/Columns/ColumnNullable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,29 @@ void ColumnNullable::get(size_t n, Field & res) const
getNestedColumn().get(n, res);
}

StringRef ColumnNullable::getDataAt(size_t /*n*/) const
StringRef ColumnNullable::getDataAt(size_t n) const
{
throw Exception(fmt::format("Method getDataAt is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED);
if (likely(!isNullAt(n)))
return getNestedColumn().getDataAt(n);

throw Exception(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error code and error message should be updated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is by design

ErrorCodes::NOT_IMPLEMENTED,
"Method getDataAt is not supported for {} in case if value is NULL",
getName());
}

void ColumnNullable::insertData(const char * /*pos*/, size_t /*length*/)
void ColumnNullable::insertData(const char * pos, size_t length)
{
throw Exception(fmt::format("Method insertData is not supported for {}", getName()), ErrorCodes::NOT_IMPLEMENTED);
if (pos == nullptr)
{
getNestedColumn().insertDefault();
getNullMapData().push_back(1);
}
else
{
getNestedColumn().insertData(pos, length);
getNullMapData().push_back(0);
}
}

bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode)
Expand All @@ -222,6 +237,12 @@ bool ColumnNullable::decodeTiDBRowV2Datum(size_t cursor, const String & raw_valu
return true;
}

void ColumnNullable::insertFromDatumData(const char * cursor, size_t len)
{
getNestedColumn().insertFromDatumData(cursor, len);
getNullMapData().push_back(0);
}

StringRef ColumnNullable::serializeValueIntoArena(
size_t n,
Arena & arena,
Expand Down
4 changes: 3 additions & 1 deletion dbms/src/Columns/ColumnNullable.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ class ColumnNullable final : public COWPtrHelper<IColumn, ColumnNullable>
Field operator[](size_t n) const override;
void get(size_t n, Field & res) const override;
UInt64 get64(size_t n) const override { return nested_column->get64(n); }
StringRef getDataAt(size_t n) const override;
StringRef getDataAt(size_t) const override;
/// Will insert null value if pos=nullptr
void insertData(const char * pos, size_t length) override;
bool decodeTiDBRowV2Datum(size_t cursor, const String & raw_value, size_t length, bool force_decode) override;
void insertFromDatumData(const char *, size_t) override;
StringRef serializeValueIntoArena(
size_t n,
Arena & arena,
Expand Down
5 changes: 5 additions & 0 deletions dbms/src/Columns/IColumn.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ class IColumn : public COWPtr<IColumn>
throw Exception("Method decodeTiDBRowV2Datum is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}

virtual void insertFromDatumData(const char *, size_t)
{
throw Exception("Method insertFromDatumData is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}

/// Like getData, but has special behavior for columns that contain variable-length strings.
/// In this special case inserting data should be zero-ending (i.e. length is 1 byte greater than real string size).
virtual void insertDataWithTerminatingZero(const char * pos, size_t length) { insertData(pos, length); }
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/DataTypes/DataTypeArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DataTypeArray final : public IDataType

const char * getFamilyName() const override { return "Array"; }

bool canBeInsideNullable() const override { return false; }
bool canBeInsideNullable() const override { return true; }

TypeIndex getTypeId() const override { return TypeIndex::Array; }

Expand Down Expand Up @@ -98,7 +98,7 @@ class DataTypeArray final : public IDataType
bool haveSubtypes() const override { return true; }
bool cannotBeStoredInTables() const override { return nested->cannotBeStoredInTables(); }
bool textCanContainOnlyValidUTF8() const override { return nested->textCanContainOnlyValidUTF8(); }
bool isComparable() const override { return nested->isComparable(); };
bool isComparable() const override { return nested->isComparable(); }
bool canBeComparedWithCollation() const override { return nested->canBeComparedWithCollation(); }

bool isValueUnambiguouslyRepresentedInContiguousMemoryRegion() const override
Expand Down
7 changes: 7 additions & 0 deletions dbms/src/Debug/MockExecutor/AstToPB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ void literalFieldToTiPBExpr(const ColumnInfo & ci, const Field & val_field, tipb
encodeDAGInt64(val, ss);
break;
}
case TiDB::TypeTiDBVectorFloat32:
{
expr->set_tp(tipb::ExprType::TiDBVectorFloat32);
const auto & val = val_field.safeGet<Array>();
encodeDAGVectorFloat32(val, ss);
break;
}
default:
throw Exception(fmt::format(
"Type {} does not support literal in function unit test",
Expand Down
4 changes: 4 additions & 0 deletions dbms/src/Debug/dbgTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ struct BatchCtrl
throw Exception(
"Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagJson",
ErrorCodes::LOGICAL_ERROR);
case TiDB::CodecFlagVectorFloat32:
throw Exception(
"Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagVectorFloat32",
ErrorCodes::LOGICAL_ERROR);
case TiDB::CodecFlagMax:
throw Exception("Not implented yet: BatchCtrl::encodeDatum, TiDB::CodecFlagMax", ErrorCodes::LOGICAL_ERROR);
case TiDB::CodecFlagDuration:
Expand Down
83 changes: 83 additions & 0 deletions dbms/src/Flash/Coprocessor/ArrowColCodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Columns/ColumnArray.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Common/TiFlashException.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDecimal.h>
#include <DataTypes/DataTypeEnum.h>
#include <DataTypes/DataTypeMyDate.h>
Expand Down Expand Up @@ -292,6 +294,34 @@ void flashStringColToArrowCol(
}
}

template <bool is_nullable>
void flashArrayFloat32ColToArrowCol(
TiDBColumn & dag_column,
const IColumn * flash_col_untyped,
size_t start_index,
size_t end_index)
{
// We only unwrap the NULLABLE() part.
const IColumn * nested_col = getNestedCol(flash_col_untyped);
const auto * flash_col = checkAndGetColumn<ColumnArray>(nested_col);
for (size_t i = start_index; i < end_index; i++)
{
// todo check if we can convert flash_col to DAG col directly since the internal representation is almost the same
if constexpr (is_nullable)
{
if (flash_col_untyped->isNullAt(i))
{
dag_column.appendNull();
continue;
}
}

auto encoded_size = flash_col->encodeIntoDatumData(i, *dag_column.data);
RUNTIME_CHECK(encoded_size > 0);
dag_column.finishAppendVar(encoded_size);
}
}

template <bool is_nullable>
void flashBitColToArrowCol(
TiDBColumn & dag_column,
Expand Down Expand Up @@ -461,6 +491,20 @@ void flashColToArrowCol(
else
flashStringColToArrowCol<true>(dag_column, col, start_index, end_index);
break;
case TiDB::TypeTiDBVectorFloat32:
{
const auto * data_type = checkAndGetDataType<DataTypeArray>(type);
if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32)
throw TiFlashException(
Errors::Coprocessor::Internal,
"Type un-matched during arrow encode, target col type is array<float32> and source column type is {}",
type->getName());
if (tidb_column_info.hasNotNullFlag())
flashArrayFloat32ColToArrowCol<false>(dag_column, col, start_index, end_index);
else
flashArrayFloat32ColToArrowCol<true>(dag_column, col, start_index, end_index);
break;
}
case TiDB::TypeBit:
if (!checkDataType<DataTypeUInt64>(type))
throw TiFlashException(
Expand Down Expand Up @@ -525,6 +569,35 @@ const char * arrowStringColToFlashCol(
return pos + offsets[length];
}

const char * arrowArrayFloat32ColToFlashCol(
const char * pos,
UInt8,
UInt32 null_count,
const std::vector<UInt8> & null_bitmap,
const std::vector<UInt64> & offsets,
const ColumnWithTypeAndName & col,
const ColumnInfo &,
UInt32 length)
{
const auto * data_type = checkAndGetDataType<DataTypeArray>(&*col.type);
if (!data_type || data_type->getNestedType()->getTypeId() != TypeIndex::Float32)
throw TiFlashException(
Errors::Coprocessor::Internal,
"Type un-matched during arrow decode, target col type is array<float32> and source column type is {}",
col.type->getName());

for (UInt32 i = 0; i < length; i++)
{
if (checkNull(i, null_count, null_bitmap, col))
continue;

auto arrow_data_size = offsets[i + 1] - offsets[i];
const auto * base_offset = pos + offsets[i];
col.column->assumeMutable()->insertFromDatumData(base_offset, arrow_data_size);
}
return pos + offsets[length];
}

const char * arrowEnumColToFlashCol(
const char * pos,
UInt8,
Expand Down Expand Up @@ -819,6 +892,16 @@ const char * arrowColToFlashCol(
length);
case TiDB::TypeBit:
return arrowBitColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length);
case TiDB::TypeTiDBVectorFloat32:
return arrowArrayFloat32ColToFlashCol(
pos,
field_length,
null_count,
null_bitmap,
offsets,
flash_col,
col_info,
length);
case TiDB::TypeEnum:
return arrowEnumColToFlashCol(pos, field_length, null_count, null_bitmap, offsets, flash_col, col_info, length);
default:
Expand Down
11 changes: 11 additions & 0 deletions dbms/src/Flash/Coprocessor/DAGCodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ void encodeDAGDecimal(const Field & field, WriteBuffer & ss)
EncodeDecimal(field, ss);
}

void encodeDAGVectorFloat32(const Array & v, WriteBuffer & ss)
{
EncodeVectorFloat32(v, ss);
}

Int64 decodeDAGInt64(const String & s)
{
auto u = *(reinterpret_cast<const UInt64 *>(s.data()));
Expand Down Expand Up @@ -93,4 +98,10 @@ Field decodeDAGDecimal(const String & s)
return DecodeDecimal(cursor, s);
}

Field decodeDAGVectorFloat32(const String & s)
{
size_t cursor = 0;
return DecodeVectorFloat32(cursor, s);
}

} // namespace DB
Loading