Skip to content

Commit

Permalink
remove some duplicate code used by both DWRF and ORC (facebookincubat…
Browse files: browse the repository at this point in the history
  • Loading branch information
chenxu14 authored and rui-mo committed Jul 21, 2023
1 parent f55ef8d commit eea25ce
Show file tree
Hide file tree
Showing 16 changed files with 192 additions and 276 deletions.
123 changes: 50 additions & 73 deletions velox/dwio/dwrf/reader/ColumnReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,41 +1024,58 @@ class StringDictionaryColumnReader : public ColumnReader {

void ensureInitialized();

void initOrc(StripeStreams& stripe) {
void init(StripeStreams& stripe) {
auto format = stripe.format();
EncodingKey encodingKey{nodeType_->id, flatMapContext_.sequence};
RleVersion rleVersion =
convertRleVersion(stripe.getEncodingOrc(encodingKey).kind());
dictionaryCount = stripe.getEncodingOrc(encodingKey).dictionarysize();

const auto dataId = encodingKey.forKind(proto::orc::Stream_Kind_DATA);
bool dictVInts = stripe.getUseVInts(dataId);
dictIndex = createRleDecoder</*isSigned*/ false>(
stripe.getStream(dataId, true),
rleVersion,
memoryPool_,
dictVInts,
dwio::common::INT_BYTE_SIZE);

const auto lenId = encodingKey.forKind(proto::orc::Stream_Kind_LENGTH);
bool lenVInts = stripe.getUseVInts(lenId);
lengthDecoder = createRleDecoder</*isSigned*/ false>(
stripe.getStream(lenId, false),
rleVersion,
memoryPool_,
lenVInts,
dwio::common::INT_BYTE_SIZE);

blobStream = stripe.getStream(
encodingKey.forKind(proto::orc::Stream_Kind_DICTIONARY_DATA), false);
}

void initDwrf(StripeStreams& stripe) {
EncodingKey encodingKey{nodeType_->id, flatMapContext_.sequence};
RleVersion rleVersion =
convertRleVersion(stripe.getEncoding(encodingKey).kind());
dictionaryCount = stripe.getEncoding(encodingKey).dictionarysize();
RleVersion rleVersion;
DwrfStreamIdentifier dataId;
DwrfStreamIdentifier lenId;
DwrfStreamIdentifier dictionaryId;
if (format == DwrfFormat::kDwrf) {
rleVersion = convertRleVersion(stripe.getEncoding(encodingKey).kind());
dictionaryCount = stripe.getEncoding(encodingKey).dictionarysize();
dataId = encodingKey.forKind(proto::Stream_Kind_DATA);
lenId = encodingKey.forKind(proto::Stream_Kind_LENGTH);
dictionaryId = encodingKey.forKind(proto::Stream_Kind_DICTIONARY_DATA);

// handle in dictionary stream
std::unique_ptr<dwio::common::SeekableInputStream> inDictStream =
stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_IN_DICTIONARY), false);
if (inDictStream) {
inDictionaryReader =
createBooleanRleDecoder(std::move(inDictStream), encodingKey);

// stride dictionary only exists if in dictionary exists
strideDictStream = stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_STRIDE_DICTIONARY), true);
DWIO_ENSURE_NOT_NULL(strideDictStream, "Stride dictionary is missing");

indexStream_ = stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_ROW_INDEX), true);
DWIO_ENSURE_NOT_NULL(indexStream_, "String index is missing");

const auto strideDictLenId =
encodingKey.forKind(proto::Stream_Kind_STRIDE_DICTIONARY_LENGTH);
bool strideLenVInt = stripe.getUseVInts(strideDictLenId);
strideDictLengthDecoder = createRleDecoder</*isSigned*/ false>(
stripe.getStream(strideDictLenId, true),
rleVersion,
memoryPool_,
strideLenVInt,
dwio::common::INT_BYTE_SIZE);
}
} else {
VELOX_CHECK(format == DwrfFormat::kOrc);
rleVersion = convertRleVersion(stripe.getEncodingOrc(encodingKey).kind());
dictionaryCount = stripe.getEncodingOrc(encodingKey).dictionarysize();
dataId = encodingKey.forKind(proto::orc::Stream_Kind_DATA);
lenId = encodingKey.forKind(proto::orc::Stream_Kind_LENGTH);
dictionaryId =
encodingKey.forKind(proto::orc::Stream_Kind_DICTIONARY_DATA);
}

const auto dataId = encodingKey.forKind(proto::Stream_Kind_DATA);
bool dictVInts = stripe.getUseVInts(dataId);
dictIndex = createRleDecoder</*isSigned*/ false>(
stripe.getStream(dataId, true),
Expand All @@ -1067,7 +1084,6 @@ class StringDictionaryColumnReader : public ColumnReader {
dictVInts,
dwio::common::INT_BYTE_SIZE);

const auto lenId = encodingKey.forKind(proto::Stream_Kind_LENGTH);
bool lenVInts = stripe.getUseVInts(lenId);
lengthDecoder = createRleDecoder</*isSigned*/ false>(
stripe.getStream(lenId, false),
Expand All @@ -1076,46 +1092,7 @@ class StringDictionaryColumnReader : public ColumnReader {
lenVInts,
dwio::common::INT_BYTE_SIZE);

blobStream = stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_DICTIONARY_DATA), false);

// handle in dictionary stream
std::unique_ptr<dwio::common::SeekableInputStream> inDictStream =
stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_IN_DICTIONARY), false);
if (inDictStream) {
inDictionaryReader =
createBooleanRleDecoder(std::move(inDictStream), encodingKey);

// stride dictionary only exists if in dictionary exists
strideDictStream = stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_STRIDE_DICTIONARY), true);
DWIO_ENSURE_NOT_NULL(strideDictStream, "Stride dictionary is missing");

indexStream_ = stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_ROW_INDEX), true);
DWIO_ENSURE_NOT_NULL(indexStream_, "String index is missing");

const auto strideDictLenId =
encodingKey.forKind(proto::Stream_Kind_STRIDE_DICTIONARY_LENGTH);
bool strideLenVInt = stripe.getUseVInts(strideDictLenId);
strideDictLengthDecoder = createRleDecoder</*isSigned*/ false>(
stripe.getStream(strideDictLenId, true),
rleVersion,
memoryPool_,
strideLenVInt,
dwio::common::INT_BYTE_SIZE);
}
}

void init(StripeStreams& stripe) {
auto format = stripe.format();
if (format == DwrfFormat::kDwrf) {
initDwrf(stripe);
} else {
VELOX_CHECK(format == DwrfFormat::kOrc);
initOrc(stripe);
}
blobStream = stripe.getStream(dictionaryId, false);
}

public:
Expand Down
37 changes: 11 additions & 26 deletions velox/dwio/dwrf/reader/DwrfData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,41 +22,27 @@ namespace facebook::velox::dwrf {

void DwrfData::init(StripeStreams& stripe) {
auto format = stripe.format();
EncodingKey encodingKey{nodeType_->id, flatMapContext_.sequence};

DwrfStreamIdentifier presentStream;
DwrfStreamIdentifier rowIndexStream;
if (format == DwrfFormat::kDwrf) {
initDwrf(stripe);
presentStream = encodingKey.forKind(proto::Stream_Kind_PRESENT);
rowIndexStream = encodingKey.forKind(proto::Stream_Kind_ROW_INDEX);
} else {
VELOX_CHECK(format == DwrfFormat::kOrc);
initOrc(stripe);
}
}

void DwrfData::initDwrf(StripeStreams& stripe) {
EncodingKey encodingKey{nodeType_->id, flatMapContext_.sequence};

std::unique_ptr<dwio::common::SeekableInputStream> stream = stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_PRESENT),
streamLabels.label(),
false);
if (stream) {
notNullDecoder_ = createBooleanRleDecoder(std::move(stream), encodingKey);
presentStream = encodingKey.forKind(proto::orc::Stream_Kind_PRESENT);
rowIndexStream = encodingKey.forKind(proto::orc::Stream_Kind_ROW_INDEX);
}

// We always initialize indexStream_ because indices are needed as
// soon as there is a single filter that can trigger row group skips
// anywhere in the reader tree. This is not known at construct time
// because the first filter can come from a hash join or other run
// time pushdown.
indexStream_ = stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_ROW_INDEX),
streamLabels.label(),
false);
}

void DwrfData::initOrc(StripeStreams& stripe) {
EncodingKey encodingKey{nodeType_->id, flatMapContext_.sequence};

std::unique_ptr<dwio::common::SeekableInputStream> stream = stripe.getStream(
encodingKey.forKind(proto::orc::Stream_Kind_PRESENT), false);
std::unique_ptr<dwio::common::SeekableInputStream> stream =
stripe.getStream(presentStream, streamLabels.label(), false);
if (stream) {
notNullDecoder_ = createBooleanRleDecoder(std::move(stream), encodingKey);
}
Expand All @@ -66,8 +52,7 @@ void DwrfData::initOrc(StripeStreams& stripe) {
// anywhere in the reader tree. This is not known at construct time
// because the first filter can come from a hash join or other run
// time pushdown.
indexStream_ = stripe.getStream(
encodingKey.forKind(proto::orc::Stream_Kind_ROW_INDEX), false);
indexStream_ = stripe.getStream(rowIndexStream, false);
}

DwrfData::DwrfData(
Expand Down
2 changes: 0 additions & 2 deletions velox/dwio/dwrf/reader/DwrfData.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ class DwrfData : public dwio::common::FormatData {
}

void init(StripeStreams& stripe);
void initDwrf(StripeStreams& stripe);
void initOrc(StripeStreams& stripe);

memory::MemoryPool& memoryPool_;
const std::shared_ptr<const dwio::common::TypeWithId> nodeType_;
Expand Down
29 changes: 2 additions & 27 deletions velox/dwio/dwrf/reader/ReaderBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,44 +81,19 @@ class ReaderBase {
memory::MemoryPool& pool,
std::unique_ptr<dwio::common::BufferedInput> input,
std::unique_ptr<PostScript> ps,
const proto::Footer* footer,
std::unique_ptr<FooterWrapper> footer,
std::unique_ptr<StripeMetadataCache> cache,
std::unique_ptr<encryption::DecryptionHandler> handler = nullptr)
: pool_{pool},
postScript_{std::move(ps)},
footer_{std::make_unique<FooterWrapper>(footer)},
footer_{std::move(footer)},
cache_{std::move(cache)},
handler_{std::move(handler)},
input_{std::move(input)},
schema_{
std::dynamic_pointer_cast<const RowType>(convertType(*footer_))},
fileLength_{0},
psLength_{0} {
DWIO_ENSURE(footer_->getDwrfPtr()->GetArena());
DWIO_ENSURE_NOT_NULL(schema_, "invalid schema");
if (!handler_) {
handler_ = encryption::DecryptionHandler::create(*footer);
}
}

ReaderBase(
memory::MemoryPool& pool,
std::unique_ptr<dwio::common::BufferedInput> input,
std::unique_ptr<PostScript> ps,
const proto::orc::Footer* footer,
std::unique_ptr<StripeMetadataCache> cache,
std::unique_ptr<encryption::DecryptionHandler> handler = nullptr)
: pool_{pool},
postScript_{std::move(ps)},
footer_{std::make_unique<FooterWrapper>(footer)},
cache_{std::move(cache)},
handler_{std::move(handler)},
input_{std::move(input)},
schema_{
std::dynamic_pointer_cast<const RowType>(convertType(*footer_))},
fileLength_{0},
psLength_{0} {
DWIO_ENSURE(footer_->getOrcPtr()->GetArena());
DWIO_ENSURE_NOT_NULL(schema_, "invalid schema");
if (!handler_) {
handler_ = encryption::DecryptionHandler::create(*footer_);
Expand Down
42 changes: 8 additions & 34 deletions velox/dwio/dwrf/reader/SelectiveByteRleColumnReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,50 +24,24 @@ class SelectiveByteRleColumnReader
: public dwio::common::SelectiveByteRleColumnReader {
void init(DwrfParams& params, bool isBool) {
auto format = params.stripeStreams().format();
if (format == DwrfFormat::kDwrf) {
initDwrf(params, isBool);
} else {
VELOX_CHECK(format == DwrfFormat::kOrc);
initOrc(params, isBool);
}
}

void initDwrf(DwrfParams& params, bool isBool) {
EncodingKey encodingKey{nodeType_->id, params.flatMapContext().sequence};
auto& stripe = params.stripeStreams();
if (isBool) {
boolRle_ = createBooleanRleDecoder(
stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_DATA),
params.streamLabels().label(),
true),
encodingKey);

DwrfStreamIdentifier dataId;
if (format == DwrfFormat::kDwrf) {
dataId = encodingKey.forKind(proto::Stream_Kind_DATA);
} else {
byteRle_ = createByteRleDecoder(
stripe.getStream(
encodingKey.forKind(proto::Stream_Kind_DATA),
params.streamLabels().label(),
true),
encodingKey);
VELOX_CHECK(format == DwrfFormat::kOrc);
dataId = encodingKey.forKind(proto::orc::Stream_Kind_DATA);
}
}

void initOrc(DwrfParams& params, bool isBool) {
EncodingKey encodingKey{nodeType_->id, params.flatMapContext().sequence};
auto& stripe = params.stripeStreams();
if (isBool) {
boolRle_ = createBooleanRleDecoder(
stripe.getStream(
encodingKey.forKind(proto::orc::Stream_Kind_DATA),
params.streamLabels().label(),
true),
stripe.getStream(dataId, params.streamLabels().label(), true),
encodingKey);
} else {
byteRle_ = createByteRleDecoder(
stripe.getStream(
encodingKey.forKind(proto::orc::Stream_Kind_DATA),
params.streamLabels().label(),
true),
stripe.getStream(dataId, params.streamLabels().label(), true),
encodingKey);
}
}
Expand Down
34 changes: 15 additions & 19 deletions velox/dwio/dwrf/reader/SelectiveRepeatedColumnReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,25 @@ std::unique_ptr<dwio::common::IntDecoder</*isSigned*/ false>> makeLengthDecoder(
EncodingKey encodingKey{nodeType.id, params.flatMapContext().sequence};
auto& stripe = params.stripeStreams();
auto format = stripe.format();

RleVersion rleVersion;
DwrfStreamIdentifier lenId;
if (format == DwrfFormat::kDwrf) {
auto rleVersion = convertRleVersion(stripe.getEncoding(encodingKey).kind());
auto lenId = encodingKey.forKind(proto::Stream_Kind_LENGTH);
bool lenVints = stripe.getUseVInts(lenId);
return createRleDecoder</*isSigned*/ false>(
stripe.getStream(lenId, params.streamLabels().label(), true),
rleVersion,
pool,
lenVints,
dwio::common::INT_BYTE_SIZE);
rleVersion = convertRleVersion(stripe.getEncoding(encodingKey).kind());
lenId = encodingKey.forKind(proto::Stream_Kind_LENGTH);
} else {
VELOX_CHECK(format == DwrfFormat::kOrc);
auto rleVersion =
convertRleVersion(stripe.getEncodingOrc(encodingKey).kind());
auto lenId = encodingKey.forKind(proto::orc::Stream_Kind_LENGTH);
bool lenVints = stripe.getUseVInts(lenId);
return createRleDecoder</*isSigned*/ false>(
stripe.getStream(lenId, params.streamLabels().label(), true),
rleVersion,
pool,
lenVints,
dwio::common::INT_BYTE_SIZE);
rleVersion = convertRleVersion(stripe.getEncodingOrc(encodingKey).kind());
lenId = encodingKey.forKind(proto::orc::Stream_Kind_LENGTH);
}

bool lenVints = stripe.getUseVInts(lenId);
return createRleDecoder</*isSigned*/ false>(
stripe.getStream(lenId, params.streamLabels().label(), true),
rleVersion,
pool,
lenVints,
dwio::common::INT_BYTE_SIZE);
}
} // namespace

Expand Down
Loading

0 comments on commit eea25ce

Please sign in to comment.