From 5e9cf26a6a2ea9053a57befa11f9ea866ebb9e8f Mon Sep 17 00:00:00 2001 From: seidl Date: Mon, 26 Jun 2023 15:54:54 -0700 Subject: [PATCH 01/77] add DELTA_BINARY_PACKED decoder --- cpp/CMakeLists.txt | 1 + cpp/src/io/parquet/delta_binary.cuh | 272 +++++++++++++++++++ cpp/src/io/parquet/page_data.cu | 43 +-- cpp/src/io/parquet/page_decode.cuh | 4 + cpp/src/io/parquet/page_delta_decode.cu | 165 +++++++++++ cpp/src/io/parquet/page_string_decode.cu | 88 +----- cpp/src/io/parquet/page_string_utils.cuh | 110 ++++++++ cpp/src/io/parquet/parquet_gpu.hpp | 19 ++ cpp/src/io/parquet/reader_impl.cpp | 32 ++- cpp/src/io/parquet/reader_impl_preprocess.cu | 7 +- 10 files changed, 626 insertions(+), 115 deletions(-) create mode 100644 cpp/src/io/parquet/delta_binary.cuh create mode 100644 cpp/src/io/parquet/page_delta_decode.cu create mode 100644 cpp/src/io/parquet/page_string_utils.cuh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 0742d039092..f8e500ac906 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -392,6 +392,7 @@ add_library( src/io/parquet/chunk_dict.cu src/io/parquet/page_enc.cu src/io/parquet/page_hdr.cu + src/io/parquet/page_delta_decode.cu src/io/parquet/page_string_decode.cu src/io/parquet/reader.cpp src/io/parquet/reader_impl.cpp diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh new file mode 100644 index 00000000000..b02164a377f --- /dev/null +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "page_decode.cuh" + +namespace cudf::io::parquet::gpu { + +// DELTA_XXX encoding support +// +// DELTA_BINARY_PACKED is used for INT32 and INT64 data types. Encoding begins with a header +// containing a block size, number of mini-blocks in each block, total value count, and first +// value. The first three are ULEB128 variable length ints, and the last is a zigzag ULEB128 +// varint. +// -- the block size is a multiple of 128 +// -- the mini-block count is chosen so that each mini-block will contain a multiple of 32 values +// -- the value count includes the first value stored in the header +// +// It seems most Parquet encoders will stick with a block size of 128, and 4 mini-blocks of 32 +// elements each. arrow-rs will use a block size of 256 for 64-bit ints. +// +// Following the header are the data blocks. Each block is further divided into mini-blocks, with +// each mini-block having its own encoding bitwidth. Each block begins with a header containing a +// zigzag ULEB128 encoded minimum delta value, followed by an array of uint8 bitwidths, one entry +// per mini-block. While encoding, the lowest delta value is subtracted from all the deltas in the +// block to ensure that all encoded values are positive. The deltas for each mini-block are bit +// packed using the same encoding as the RLE/Bit-Packing Hybrid encoder. 
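+//
+// As a worked example (illustrative values, not taken from a real file): encoding the
+// sequence 7, 5, 3, 1, ... with block size 128 and 4 mini-blocks produces the header
+//   block size 128 (ULEB128 0x80 0x01) | mini-block count 4 (0x04) |
+//   total value count (ULEB128) | first value 7 (zigzag ULEB128 0x0e)
+// every delta is -2, so each block stores min delta -2 (zigzag 0x03) followed by a
+// bit-width array of {0, 0, 0, 0}; since delta - min_delta == 0 everywhere, the
+// mini-blocks carry no packed data at all.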
+//
+// DELTA_BYTE_ARRAY encoding (incremental encoding or front compression), is used for BYTE_ARRAY
+// columns. For each element in a sequence of strings, a prefix length from the preceding string
+// and a suffix is stored. The prefix lengths are DELTA_BINARY_PACKED encoded. The suffixes are
+// encoded with DELTA_LENGTH_BYTE_ARRAY encoding, which is a DELTA_BINARY_PACKED list of suffix
+// lengths, followed by the concatenated suffix data.
+
+// TODO: The delta encodings use ULEB128 integers, but for now we're only
+// using max 64 bits. Need to see what the performance impact is of using
+// __int128_t rather than int64_t.
+using uleb128_t   = uint64_t;
+using zigzag128_t = int64_t;
+
+/**
+ * @brief Read a ULEB128 varint integer
+ *
+ * @param[in,out] cur The current data position, updated after the read
+ * @param[in] end The end data position
+ *
+ * @return The value read
+ */
+inline __device__ uleb128_t get_uleb128(uint8_t const*& cur, uint8_t const* end)
+{
+  uleb128_t v = 0, l = 0, c;
+  while (cur < end) {
+    c = *cur++;
+    v |= (c & 0x7f) << l;
+    l += 7;
+    if ((c & 0x80) == 0) { return v; }
+  }
+  return v;
+}
+
+/**
+ * @brief Read a ULEB128 zig-zag encoded varint integer
+ *
+ * @param[in,out] cur The current data position, updated after the read
+ * @param[in] end The end data position
+ *
+ * @return The value read
+ */
+inline __device__ zigzag128_t get_zz128(uint8_t const*& cur, uint8_t const* end)
+{
+  uleb128_t u = get_uleb128(cur, end);
+  return static_cast<zigzag128_t>((u >> 1u) ^ -static_cast<zigzag128_t>(u & 1));
+}
+
+struct delta_binary_decoder {
+  uint8_t const* block_start;  // start of data, but updated as data is read
+  uint8_t const* block_end;    // end of data
+  uleb128_t block_size;        // usually 128, must be multiple of 128
+  uleb128_t mini_block_count;  // usually 4, chosen such that block_size/mini_block_count is a
+                               // multiple of 32
+  uleb128_t value_count;       // total values encoded in the block
+  zigzag128_t last_value;      // last value decoded, initialized to first_value from header
+
+  uint32_t values_per_mb;      // block_size / mini_block_count, must be multiple of 32
+  uint32_t current_value_idx;  // current value index, initialized to 0 at start of block
+
+  zigzag128_t cur_min_delta;     // min delta for the block
+  uint32_t cur_mb;               // index of the current mini-block within the block
+  uint8_t const* cur_mb_start;   // pointer to the start of the current mini-block data
+  uint8_t const* cur_bitwidths;  // pointer to the bitwidth array in the block
+
+  uleb128_t value[non_zero_buffer_size];  // circular buffer of delta values
+
+  // returns the number of values encoded in the block data. when is_decode is true,
+  // account for the first value in the header. otherwise just count the values encoded
+  // in the mini-block data.
+  constexpr uint32_t num_encoded_values(bool is_decode)
+  {
+    return value_count == 0 ? 0 : is_decode ? value_count : value_count - 1;
+  }
+
+  // read mini-block header into state object. should only be called from init_binary_block or
+  // setup_next_mini_block. header format is:
+  //
+  // | min delta (int) | bit-width array (1 byte * mini_block_count) |
+  //
+  // on exit db->cur_mb is 0 and db->cur_mb_start points to the first mini-block of data, or
+  // nullptr if out of data.
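+  //
+  // continuing the worked example from the header comment above: a block whose deltas are
+  // all -2 yields cur_min_delta = -2, cur_bitwidths = {0, 0, 0, 0}, and cur_mb_start
+  // pointing just past the 5-byte block header (1 zigzag byte for the min delta plus
+  // mini_block_count bitwidth bytes).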
+ inline __device__ void init_mini_block(bool is_decode) + { + cur_mb = 0; + cur_mb_start = nullptr; + + if (current_value_idx < num_encoded_values(is_decode)) { + auto d_start = block_start; + cur_min_delta = get_zz128(d_start, block_end); + cur_bitwidths = d_start; + + d_start += mini_block_count; + cur_mb_start = d_start; + } + } + + // read delta binary header into state object. should be called on thread 0. header format is: + // + // | block size (uint) | mini-block count (uint) | value count (uint) | first value (int) | + // + // also initializes the first mini-block before exit + inline __device__ void init_binary_block(uint8_t const* d_start, uint8_t const* d_end) + { + block_end = d_end; + block_size = get_uleb128(d_start, d_end); + mini_block_count = get_uleb128(d_start, d_end); + value_count = get_uleb128(d_start, d_end); + last_value = get_zz128(d_start, d_end); + + current_value_idx = 0; + values_per_mb = block_size / mini_block_count; + + // init the first mini-block + block_start = d_start; + init_mini_block(false); + } + + // skip to the start of the next mini-block. should only be called on thread 0. + // calls init_binary_block if currently on the last mini-block in a block. + inline __device__ void setup_next_mini_block(bool is_decode) + { + if (current_value_idx >= num_encoded_values(is_decode)) { return; } + + current_value_idx += values_per_mb; + + // just set pointer to start of next mini_block + if (cur_mb < mini_block_count - 1) { + cur_mb_start += cur_bitwidths[cur_mb] * values_per_mb / 8; + cur_mb++; + } + // out of mini-blocks, start a new block + else { + block_start = cur_mb_start + cur_bitwidths[cur_mb] * values_per_mb / 8; + init_mini_block(is_decode); + } + } + + // decode the current mini-batch of deltas, and convert to values. + // called by all threads in a warp, currently only one warp supported. + inline __device__ void calc_mini_block_values(int lane_id) + { + using cudf::detail::warp_size; + if (current_value_idx >= value_count) { return; } + + // need to save first value from header on first pass + if (current_value_idx == 0) { + if (lane_id == 0) { + current_value_idx++; + value[0] = last_value; + } + __syncwarp(); + } + + uint32_t const mb_bits = cur_bitwidths[cur_mb]; + + // need to do in multiple passes if values_per_mb != 32 + uint32_t const num_pass = values_per_mb / warp_size; + + auto d_start = cur_mb_start; + + for (int i = 0; i < num_pass; i++) { + // position at end of the current mini-block since the following calculates + // negative indexes + d_start += (warp_size * mb_bits) / 8; + + // unpack deltas. modified from version in gpuDecodeDictionaryIndices(), but + // that one only unpacks up to bitwidths of 24. simplified some since this + // will always do batches of 32. also replaced branching with a loop. 
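+      //
+      // d_start was advanced to the end of this batch of 32 values above, so each lane
+      // works backwards from it: lane L reads its mb_bits-wide delta starting at bit
+      // offset (L - 32) * mb_bits, assembling the value one byte at a time.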
+ int64_t delta = 0; + if (lane_id + current_value_idx < value_count) { + int32_t ofs = (lane_id - warp_size) * mb_bits; + uint8_t const* p = d_start + (ofs >> 3); + ofs &= 7; + if (p < block_end) { + uint32_t c = 8 - ofs; // 0 - 7 bits + delta = (*p++) >> ofs; + + while (c < mb_bits && p < block_end) { + delta |= (*p++) << c; + c += 8; + } + delta &= (1 << mb_bits) - 1; + } + } + + // add min delta to get true delta + delta += cur_min_delta; + + // do inclusive scan to get value - first_value at each position + __shared__ cub::WarpScan::TempStorage temp_storage; + cub::WarpScan(temp_storage).InclusiveSum(delta, delta); + + // now add first value from header or last value from previous block to get true value + delta += last_value; + value[rolling_index(current_value_idx + warp_size * i + lane_id)] = delta; + + // save value from last lane in warp. this will become the 'first value' added to the + // deltas calculated in the next iteration (or invocation). + if (lane_id == 31) { last_value = delta; } + __syncwarp(); + } + } + + inline __device__ void skip_values(int skip) + { + int const t = threadIdx.x; + int const lane_id = t & 0x1f; + + while (current_value_idx < skip && current_value_idx < num_encoded_values(true)) { + if (t < 32) { + calc_mini_block_values(lane_id); + if (lane_id == 0) { setup_next_mini_block(true); } + } + __syncthreads(); + } + } + + inline __device__ void decode_batch() + { + int const t = threadIdx.x; + int const lane_id = t & 0x1f; + + // unpack deltas and save in db->value + calc_mini_block_values(lane_id); + + // set up for next mini-block + if (lane_id == 0) { setup_next_mini_block(true); } + } +}; + +} // namespace cudf::io::parquet::gpu diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index e49378485fc..b93cbb6c2c5 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -35,8 +35,8 @@ namespace { * @param[in] src_pos Source position * @param[in] dstv Pointer to row output data (string descriptor or 32-bit hash) */ -inline __device__ void gpuOutputString(volatile page_state_s* s, - volatile page_state_buffers_s* sb, +inline __device__ void gpuOutputString(page_state_s volatile* s, + page_state_buffers_s volatile* sb, int src_pos, void* dstv) { @@ -62,7 +62,7 @@ inline __device__ void gpuOutputString(volatile page_state_s* s, * @param[in] src_pos Source position * @param[in] dst Pointer to row output data */ -inline __device__ void gpuOutputBoolean(volatile page_state_buffers_s* sb, +inline __device__ void gpuOutputBoolean(page_state_buffers_s volatile* sb, int src_pos, uint8_t* dst) { @@ -137,8 +137,8 @@ inline __device__ void gpuStoreOutput(uint2* dst, * @param[in] src_pos Source position * @param[out] dst Pointer to row output data */ -inline __device__ void gpuOutputInt96Timestamp(volatile page_state_s* s, - volatile page_state_buffers_s* sb, +inline __device__ void gpuOutputInt96Timestamp(page_state_s volatile* s, + page_state_buffers_s volatile* sb, int src_pos, int64_t* dst) { @@ -210,8 +210,8 @@ inline __device__ void gpuOutputInt96Timestamp(volatile page_state_s* s, * @param[in] src_pos Source position * @param[in] dst Pointer to row output data */ -inline __device__ void gpuOutputInt64Timestamp(volatile page_state_s* s, - volatile page_state_buffers_s* sb, +inline __device__ void gpuOutputInt64Timestamp(page_state_s volatile* s, + page_state_buffers_s volatile* sb, int src_pos, int64_t* dst) { @@ -292,8 +292,8 @@ __device__ void gpuOutputByteArrayAsInt(char const* ptr, int32_t len, T* dst) * 
@param[in] dst Pointer to row output data */ template -__device__ void gpuOutputFixedLenByteArrayAsInt(volatile page_state_s* s, - volatile page_state_buffers_s* sb, +__device__ void gpuOutputFixedLenByteArrayAsInt(page_state_s volatile* s, + page_state_buffers_s volatile* sb, int src_pos, T* dst) { @@ -327,8 +327,8 @@ __device__ void gpuOutputFixedLenByteArrayAsInt(volatile page_state_s* s, * @param[in] dst Pointer to row output data */ template -inline __device__ void gpuOutputFast(volatile page_state_s* s, - volatile page_state_buffers_s* sb, +inline __device__ void gpuOutputFast(page_state_s volatile* s, + page_state_buffers_s volatile* sb, int src_pos, T* dst) { @@ -358,7 +358,7 @@ inline __device__ void gpuOutputFast(volatile page_state_s* s, * @param[in] len Length of element */ static __device__ void gpuOutputGeneric( - volatile page_state_s* s, volatile page_state_buffers_s* sb, int src_pos, uint8_t* dst8, int len) + page_state_s volatile* s, page_state_buffers_s volatile* sb, int src_pos, uint8_t* dst8, int len) { uint8_t const* dict; uint32_t dict_pos, dict_size = s->dict_size; @@ -422,7 +422,7 @@ __device__ size_type gpuDecodeTotalPageStringSize(page_state_s* s, int t) } else if ((s->col.data_type & 7) == BYTE_ARRAY) { str_len = gpuInitStringDescriptors(s, nullptr, target_pos, t); } - if (!t) { *(volatile int32_t*)&s->dict_pos = target_pos; } + if (!t) { *(int32_t volatile*)&s->dict_pos = target_pos; } return str_len; } @@ -736,6 +736,17 @@ __global__ void __launch_bounds__(preprocess_block_size) } } +// skips strings and delta encodings +struct catch_all_filter { + device_span chunks; + + __device__ inline bool operator()(PageInfo const& page) + { + return !(is_string_col(page, chunks) || page.encoding == Encoding::DELTA_BINARY_PACKED || + page.encoding == Encoding::DELTA_BYTE_ARRAY); + } +}; + /** * @brief Kernel for computing the column data stored in the pages * @@ -764,7 +775,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( [[maybe_unused]] null_count_back_copier _{s, t}; if (!setupLocalPageInfo( - s, &pages[page_idx], chunks, min_row, num_rows, non_string_filter{chunks}, true)) { + s, &pages[page_idx], chunks, min_row, num_rows, catch_all_filter{chunks}, true)) { return; } @@ -814,7 +825,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( } else if ((s->col.data_type & 7) == BYTE_ARRAY) { gpuInitStringDescriptors(s, sb, src_target_pos, t & 0x1f); } - if (t == 32) { *(volatile int32_t*)&s->dict_pos = src_target_pos; } + if (t == 32) { *(int32_t volatile*)&s->dict_pos = src_target_pos; } } else { // WARP1..WARP3: Decode values int const dtype = s->col.data_type & 7; @@ -901,7 +912,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( } } - if (t == out_thread0) { *(volatile int32_t*)&s->src_pos = target_pos; } + if (t == out_thread0) { *(int32_t volatile*)&s->src_pos = target_pos; } } __syncthreads(); } diff --git a/cpp/src/io/parquet/page_decode.cuh b/cpp/src/io/parquet/page_decode.cuh index 4469ec59b7a..eba3698730c 100644 --- a/cpp/src/io/parquet/page_decode.cuh +++ b/cpp/src/io/parquet/page_decode.cuh @@ -1248,6 +1248,10 @@ inline __device__ bool setupLocalPageInfo(page_state_s* const s, if ((s->col.data_type & 7) == BOOLEAN) { s->dict_run = s->dict_size * 2 + 1; } break; case Encoding::RLE: s->dict_run = 0; break; + case Encoding::DELTA_BINARY_PACKED: + // nothing to do, just don't error + break; + default: s->error = 1; // Unsupported encoding break; diff --git 
a/cpp/src/io/parquet/page_delta_decode.cu b/cpp/src/io/parquet/page_delta_decode.cu new file mode 100644 index 00000000000..3017783cd44 --- /dev/null +++ b/cpp/src/io/parquet/page_delta_decode.cu @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "delta_binary.cuh" +#include "page_string_utils.cuh" +#include "parquet_gpu.hpp" + +#include + +#include +#include + +namespace cudf::io::parquet::gpu { + +namespace { + +// functor for setupLocalPageInfo +struct delta_binary_filter { + __device__ inline bool operator()(PageInfo const& page) + { + return page.encoding == Encoding::DELTA_BINARY_PACKED; + } +}; + +// Decode page data that is DELTA_BINARY_PACKED encoded. This encoding is +// only used for int32 and int64 physical types (and appears to only be used +// with V2 page headers; see https://www.mail-archive.com/dev@parquet.apache.org/msg11826.html). +// this kernel only needs 96 threads (3 warps)(for now). +template +__global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( + PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) +{ + __shared__ __align__(16) delta_binary_decoder db_state; + __shared__ __align__(16) page_state_buffers_s state_buffers; + __shared__ __align__(16) page_state_s state_g; + + page_state_s* const s = &state_g; + page_state_buffers_s* const sb = &state_buffers; + int const page_idx = blockIdx.x; + int const t = threadIdx.x; + int const lane_id = t & 0x1f; + auto* const db = &db_state; + [[maybe_unused]] null_count_back_copier _{s, t}; + + if (!setupLocalPageInfo( + s, &pages[page_idx], chunks, min_row, num_rows, delta_binary_filter{}, true)) { + return; + } + + bool const has_repetition = s->col.max_level[level_type::REPETITION] > 0; + + // copying logic from gpuDecodePageData. + PageNestingDecodeInfo const* nesting_info_base = s->nesting_info; + + __shared__ level_t rep[non_zero_buffer_size]; // circular buffer of repetition level values + __shared__ level_t def[non_zero_buffer_size]; // circular buffer of definition level values + + // skipped_leaf_values will always be 0 for flat hierarchies. + uint32_t const skipped_leaf_values = s->page.skipped_leaf_values; + + // initialize delta state + if (t == 0) { db->init_binary_block(s->data_start, s->data_end); } + __syncthreads(); + + auto const batch_size = db->values_per_mb; + + // if skipped_leaf_values is non-zero, then we need to decode up to the first mini-block + // that has a value we need. + if (skipped_leaf_values > 0) { db->skip_values(skipped_leaf_values); } + + while (!s->error && (s->input_value_count < s->num_input_values || s->src_pos < s->nz_count)) { + uint32_t target_pos; + uint32_t const src_pos = s->src_pos; + + if (t < 64) { // warp0..1 + target_pos = min(src_pos + 2 * batch_size, s->nz_count + batch_size); + } else { // warp2... + target_pos = min(s->nz_count, src_pos + batch_size); + } + __syncthreads(); + + // warp0 will decode the rep/def levels, warp1 will unpack a mini-batch of deltas. 
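+    // in other words, the three warps form a small pipeline:
+    //   warp0: rep/def level data  -> validity info and s->nz_idx
+    //   warp1: bit-packed deltas   -> prefix-summed values in db->value
+    //   warp2: db->value           -> output column, one batch behind warps 0/1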
+ // warp2 waits one cycle for warps 0/1 to produce a batch, and then stuffs values + // into the proper location in the output. + if (t < 32) { + // warp 0 + // decode repetition and definition levels. + // - update validity vectors + // - updates offsets (for nested columns) + // - produces non-NULL value indices in s->nz_idx for subsequent decoding + gpuDecodeLevels(s, sb, target_pos, rep, def, t); + } else if (t < 64) { + // warp 1 + db->decode_batch(); + + } else if (t < 96 && src_pos < target_pos) { + // warp 2 + // nesting level that is storing actual leaf values + int const leaf_level_index = s->col.max_nesting_depth - 1; + + // process the mini-block in batches of 32 + for (uint32_t sp = src_pos + lane_id; sp < src_pos + batch_size; sp += 32) { + // the position in the output column/buffer + int32_t dst_pos = sb->nz_idx[rolling_index(sp)]; + + // handle skip_rows here. flat hierarchies can just skip up to first_row. + if (!has_repetition) { dst_pos -= s->first_row; } + + // place value for this thread + if (dst_pos >= 0 && sp < target_pos) { + void* const dst = nesting_info_base[leaf_level_index].data_out + dst_pos * s->dtype_len; + if (s->dtype_len == 8) { + *static_cast(dst) = db->value[rolling_index(sp + skipped_leaf_values)]; + } else if (s->dtype_len == 4) { + *static_cast(dst) = db->value[rolling_index(sp + skipped_leaf_values)]; + } + } + } + + if (lane_id == 0) { s->src_pos = src_pos + batch_size; } + } + __syncthreads(); + } +} + +} // anonymous namespace + +/** + * @copydoc cudf::io::parquet::gpu::DecodeDeltaBinary + */ +void __host__ DecodeDeltaBinary(cudf::detail::hostdevice_vector& pages, + cudf::detail::hostdevice_vector const& chunks, + size_t num_rows, + size_t min_row, + int level_type_size, + rmm::cuda_stream_view stream) +{ + CUDF_EXPECTS(pages.size() > 0, "There is no page to decode"); + + dim3 dim_block(96, 1); + dim3 dim_grid(pages.size(), 1); // 1 threadblock per page + + if (level_type_size == 1) { + gpuDecodeDeltaBinary + <<>>(pages.device_ptr(), chunks, min_row, num_rows); + } else { + gpuDecodeDeltaBinary + <<>>(pages.device_ptr(), chunks, min_row, num_rows); + } +} + +} // namespace cudf::io::parquet::gpu diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index 9173d408192..20686ba7c84 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -15,6 +15,7 @@ */ #include "page_decode.cuh" +#include "page_string_utils.cuh" #include #include @@ -26,93 +27,6 @@ namespace gpu { namespace { -// stole this from cudf/strings/detail/gather.cuh. modified to run on a single string on one warp. -// copies from src to dst in 16B chunks per thread. -__device__ void wideStrcpy(uint8_t* dst, uint8_t const* src, size_t len, uint32_t lane_id) -{ - using cudf::detail::warp_size; - using cudf::strings::detail::load_uint4; - - constexpr size_t out_datatype_size = sizeof(uint4); - constexpr size_t in_datatype_size = sizeof(uint); - - auto const alignment_offset = reinterpret_cast(dst) % out_datatype_size; - uint4* out_chars_aligned = reinterpret_cast(dst - alignment_offset); - auto const in_start = src; - - // Both `out_start_aligned` and `out_end_aligned` are indices into `dst`. - // `out_start_aligned` is the first 16B aligned memory location after `dst + 4`. - // `out_end_aligned` is the last 16B aligned memory location before `len - 4`. Characters - // between `[out_start_aligned, out_end_aligned)` will be copied using uint4. 
- // `dst + 4` and `len - 4` are used instead of `dst` and `len` to avoid - // `load_uint4` reading beyond string boundaries. - // use signed int since out_end_aligned can be negative. - int64_t out_start_aligned = (in_datatype_size + alignment_offset + out_datatype_size - 1) / - out_datatype_size * out_datatype_size - - alignment_offset; - int64_t out_end_aligned = - (len - in_datatype_size + alignment_offset) / out_datatype_size * out_datatype_size - - alignment_offset; - - for (int64_t ichar = out_start_aligned + lane_id * out_datatype_size; ichar < out_end_aligned; - ichar += warp_size * out_datatype_size) { - *(out_chars_aligned + (ichar + alignment_offset) / out_datatype_size) = - load_uint4((const char*)in_start + ichar); - } - - // Tail logic: copy characters of the current string outside - // `[out_start_aligned, out_end_aligned)`. - if (out_end_aligned <= out_start_aligned) { - // In this case, `[out_start_aligned, out_end_aligned)` is an empty set, and we copy the - // entire string. - for (int64_t ichar = lane_id; ichar < len; ichar += warp_size) { - dst[ichar] = in_start[ichar]; - } - } else { - // Copy characters in range `[0, out_start_aligned)`. - if (lane_id < out_start_aligned) { dst[lane_id] = in_start[lane_id]; } - // Copy characters in range `[out_end_aligned, len)`. - int64_t ichar = out_end_aligned + lane_id; - if (ichar < len) { dst[ichar] = in_start[ichar]; } - } -} - -/** - * @brief char-parallel string copy. - */ -__device__ void ll_strcpy(uint8_t* dst, uint8_t const* src, size_t len, uint32_t lane_id) -{ - using cudf::detail::warp_size; - if (len > 64) { - wideStrcpy(dst, src, len, lane_id); - } else { - for (int i = lane_id; i < len; i += warp_size) { - dst[i] = src[i]; - } - } -} - -/** - * @brief Perform exclusive scan on an array of any length using a single block of threads. - */ -template -__device__ void block_excl_sum(size_type* arr, size_type length, size_type initial_value) -{ - using block_scan = cub::BlockScan; - __shared__ typename block_scan::TempStorage scan_storage; - int const t = threadIdx.x; - - // do a series of block sums, storing results in arr as we go - for (int pos = 0; pos < length; pos += block_size) { - int const tidx = pos + t; - size_type tval = tidx < length ? arr[tidx] : 0; - size_type block_sum; - block_scan(scan_storage).ExclusiveScan(tval, tval, initial_value, cub::Sum(), block_sum); - if (tidx < length) { arr[tidx] = tval; } - initial_value += block_sum; - } -} - /** * @brief Compute the start and end page value bounds for this page * diff --git a/cpp/src/io/parquet/page_string_utils.cuh b/cpp/src/io/parquet/page_string_utils.cuh new file mode 100644 index 00000000000..fb36c09052c --- /dev/null +++ b/cpp/src/io/parquet/page_string_utils.cuh @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf::io::parquet::gpu { + +// stole this from cudf/strings/detail/gather.cuh. 
modified to run on a single string on one warp. +// copies from src to dst in 16B chunks per thread. +inline __device__ void wideStrcpy(uint8_t* dst, uint8_t const* src, size_t len, uint32_t lane_id) +{ + using cudf::detail::warp_size; + using cudf::strings::detail::load_uint4; + + constexpr size_t out_datatype_size = sizeof(uint4); + constexpr size_t in_datatype_size = sizeof(uint); + + auto const alignment_offset = reinterpret_cast(dst) % out_datatype_size; + uint4* out_chars_aligned = reinterpret_cast(dst - alignment_offset); + auto const in_start = src; + + // Both `out_start_aligned` and `out_end_aligned` are indices into `dst`. + // `out_start_aligned` is the first 16B aligned memory location after `dst + 4`. + // `out_end_aligned` is the last 16B aligned memory location before `len - 4`. Characters + // between `[out_start_aligned, out_end_aligned)` will be copied using uint4. + // `dst + 4` and `len - 4` are used instead of `dst` and `len` to avoid + // `load_uint4` reading beyond string boundaries. + // use signed int since out_end_aligned can be negative. + int64_t out_start_aligned = (in_datatype_size + alignment_offset + out_datatype_size - 1) / + out_datatype_size * out_datatype_size - + alignment_offset; + int64_t out_end_aligned = + (len - in_datatype_size + alignment_offset) / out_datatype_size * out_datatype_size - + alignment_offset; + + for (int64_t ichar = out_start_aligned + lane_id * out_datatype_size; ichar < out_end_aligned; + ichar += warp_size * out_datatype_size) { + *(out_chars_aligned + (ichar + alignment_offset) / out_datatype_size) = + load_uint4((const char*)in_start + ichar); + } + + // Tail logic: copy characters of the current string outside + // `[out_start_aligned, out_end_aligned)`. + if (out_end_aligned <= out_start_aligned) { + // In this case, `[out_start_aligned, out_end_aligned)` is an empty set, and we copy the + // entire string. + for (int64_t ichar = lane_id; ichar < len; ichar += warp_size) { + dst[ichar] = in_start[ichar]; + } + } else { + // Copy characters in range `[0, out_start_aligned)`. + if (lane_id < out_start_aligned) { dst[lane_id] = in_start[lane_id]; } + // Copy characters in range `[out_end_aligned, len)`. + int64_t ichar = out_end_aligned + lane_id; + if (ichar < len) { dst[ichar] = in_start[ichar]; } + } +} + +/** + * @brief char-parallel string copy. + */ +inline __device__ void ll_strcpy(uint8_t* dst, uint8_t const* src, size_t len, uint32_t lane_id) +{ + using cudf::detail::warp_size; + if (len > 64) { + wideStrcpy(dst, src, len, lane_id); + } else { + for (int i = lane_id; i < len; i += warp_size) { + dst[i] = src[i]; + } + } +} + +/** + * @brief Perform exclusive scan for offsets array. Called for each page. + */ +template +__device__ void block_excl_sum(size_type* arr, size_type length, size_type initial_value) +{ + using block_scan = cub::BlockScan; + __shared__ typename block_scan::TempStorage scan_storage; + int const t = threadIdx.x; + + // do a series of block sums, storing results in arr as we go + for (int pos = 0; pos < length; pos += block_size) { + int const tidx = pos + t; + size_type tval = tidx < length ? 
arr[tidx] : 0;
+    size_type block_sum;
+    block_scan(scan_storage).ExclusiveScan(tval, tval, initial_value, cub::Sum(), block_sum);
+    if (tidx < length) { arr[tidx] = tval; }
+    initial_value += block_sum;
+  }
+}
+
+}  // namespace cudf::io::parquet::gpu
diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp
index 25d2885b7da..7690dcae611 100644
--- a/cpp/src/io/parquet/parquet_gpu.hpp
+++ b/cpp/src/io/parquet/parquet_gpu.hpp
@@ -549,6 +549,25 @@ void DecodeStringPageData(cudf::detail::hostdevice_vector<PageInfo>& pages,
                           int level_type_size,
                           rmm::cuda_stream_view stream);
 
+/**
+ * @brief Launches kernel for reading the DELTA_BINARY_PACKED column data stored in the pages
+ *
+ * The page data will be written to the output pointed to in the page's
+ * associated column chunk.
+ *
+ * @param[in,out] pages All pages to be decoded
+ * @param[in] chunks All chunks to be decoded
+ * @param[in] num_rows Total number of rows to read
+ * @param[in] min_row Minimum number of rows to read
+ * @param[in] stream CUDA stream to use, default 0
+ */
+void DecodeDeltaBinary(cudf::detail::hostdevice_vector<PageInfo>& pages,
+                       cudf::detail::hostdevice_vector<ColumnChunkDesc> const& chunks,
+                       size_t num_rows,
+                       size_t min_row,
+                       int level_type_size,
+                       rmm::cuda_stream_view stream);
+
 /**
  * @brief Launches kernel for initializing encoder row group fragments
  *
diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp
index 0237bf820b0..44d5e0d4fba 100644
--- a/cpp/src/io/parquet/reader_impl.cpp
+++ b/cpp/src/io/parquet/reader_impl.cpp
@@ -25,7 +25,7 @@ namespace cudf::io::detail::parquet {
 
 namespace {
 
-int constexpr NUM_DECODERS = 2;  // how many decode kernels are there to run
+int constexpr NUM_DECODERS = 3;  // how many decode kernels are there to run
 int constexpr APPROX_NUM_THREADS = 4;  // guesstimate from DaveB
 int constexpr STREAM_POOL_SIZE   = NUM_DECODERS * APPROX_NUM_THREADS;
 
@@ -176,16 +176,30 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows)
   chunk_nested_data.host_to_device_async(_stream);
   _stream.synchronize();
 
-  auto stream1 = get_stream_pool().get_stream();
-  gpu::DecodePageData(pages, chunks, num_rows, skip_rows, _file_itm_data.level_type_size, stream1);
+  bool const has_delta_binary = std::any_of(pages.begin(), pages.end(), [](auto& page) {
+    return page.encoding == Encoding::DELTA_BINARY_PACKED;
+  });
+
+  auto const level_type_size = _file_itm_data.level_type_size;
+
+  // launch the catch-all page decoder
+  std::vector<rmm::cuda_stream_view> streams;
+  streams.push_back(get_stream_pool().get_stream());
+  gpu::DecodePageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back());
+
+  // and then the specializations
   if (has_strings) {
-    auto stream2 = get_stream_pool().get_stream();
-    chunk_nested_str_data.host_to_device_async(stream2);
-    gpu::DecodeStringPageData(
-      pages, chunks, num_rows, skip_rows, _file_itm_data.level_type_size, stream2);
-    stream2.synchronize();
+    streams.push_back(get_stream_pool().get_stream());
+    chunk_nested_str_data.host_to_device_async(streams.back());
+    gpu::DecodeStringPageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back());
   }
-  stream1.synchronize();
+  if (has_delta_binary) {
+    streams.push_back(get_stream_pool().get_stream());
+    gpu::DecodeDeltaBinary(pages, chunks, num_rows, skip_rows, level_type_size, streams.back());
+  }
+
+  // synchronize the streams
+  std::for_each(streams.begin(), streams.end(), [](auto& stream) { stream.synchronize(); });
 
   pages.device_to_host_async(_stream);
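The rewritten launcher above fans each enabled decode kernel out to its own stream from a shared pool, then joins them all before copying results back. A minimal host-side sketch of that fork-join pattern, using only RMM's stream pool (decode_a/decode_b are hypothetical stand-ins for the per-encoding decoders):

```cpp
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/cuda_stream_view.hpp>

#include <vector>

void decode_a(rmm::cuda_stream_view stream);  // catch-all decoder (stand-in)
void decode_b(rmm::cuda_stream_view stream);  // specialized decoder (stand-in)

void decode_all(bool need_b, rmm::cuda_stream_pool& pool)
{
  std::vector<rmm::cuda_stream_view> streams;
  streams.push_back(pool.get_stream());
  decode_a(streams.back());  // always runs
  if (need_b) {
    streams.push_back(pool.get_stream());
    decode_b(streams.back());  // launched concurrently on its own stream
  }
  // join: block until every launched decoder has finished
  for (auto& s : streams) { s.synchronize(); }
}
```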
page_nesting.device_to_host_async(_stream); diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 8c3bdabe6b4..d3ff9651f8a 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -323,7 +323,8 @@ constexpr bool is_supported_encoding(Encoding enc) case Encoding::PLAIN: case Encoding::PLAIN_DICTIONARY: case Encoding::RLE: - case Encoding::RLE_DICTIONARY: return true; + case Encoding::RLE_DICTIONARY: + case Encoding::DELTA_BINARY_PACKED: return true; default: return false; } } @@ -729,8 +730,8 @@ std::pair>> reader::impl::create_and_read_co auto& chunks = _file_itm_data.chunks; // Descriptors for all the chunks that make up the selected columns - const auto num_input_columns = _input_columns.size(); - const auto num_chunks = row_groups_info.size() * num_input_columns; + auto const num_input_columns = _input_columns.size(); + auto const num_chunks = row_groups_info.size() * num_input_columns; chunks = cudf::detail::hostdevice_vector(0, num_chunks, _stream); // Association between each column chunk and its source From 9326321fe3c42bf1bba5d0298a0f0f9ec2009684 Mon Sep 17 00:00:00 2001 From: seidl Date: Mon, 26 Jun 2023 17:51:10 -0700 Subject: [PATCH 02/77] start merging in changes from #13622 --- cpp/src/io/parquet/page_data.cu | 22 ++++++++++++++++++++++ cpp/src/io/parquet/page_hdr.cu | 16 ++++++++++++++++ cpp/src/io/parquet/parquet_gpu.hpp | 18 ++++++++++++++++++ cpp/src/io/parquet/reader_impl.cpp | 12 +++++------- 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index b93cbb6c2c5..1edc32de158 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -20,6 +20,9 @@ #include +#include +#include + namespace cudf { namespace io { namespace parquet { @@ -920,6 +923,25 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( } // anonymous namespace +uint32_t GetKernelMasks(cudf::detail::hostdevice_vector& pages, + rmm::cuda_stream_view stream) +{ + // determine which kernels to invoke + // FIXME: when running on device I get and 'invalid device function' error +#if 0 + auto mask_iter = thrust::make_transform_iterator( + pages.d_begin(), [] __device__(auto const& p) { return p.kernel_mask; }); + auto const kernel_mask = thrust::reduce( + rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); +#else + auto mask_iter = + thrust::make_transform_iterator(pages.begin(), [](auto const& p) { return p.kernel_mask; }); + auto const kernel_mask = + thrust::reduce(mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); +#endif + return kernel_mask; +} + /** * @copydoc cudf::io::parquet::gpu::ComputePageSizes */ diff --git a/cpp/src/io/parquet/page_hdr.cu b/cpp/src/io/parquet/page_hdr.cu index 16886d91fc9..bc00e48a69c 100644 --- a/cpp/src/io/parquet/page_hdr.cu +++ b/cpp/src/io/parquet/page_hdr.cu @@ -154,6 +154,20 @@ __device__ void skip_struct_field(byte_stream_s* bs, int field_type) } while (rep_cnt || struct_depth); } +__device__ uint32_t get_kernel_mask(gpu::PageInfo const& page, gpu::ColumnChunkDesc const& chunk) +{ + if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { return 0; } + + // non-string, non-nested, non-dict, non-boolean types + if (page.encoding == Encoding::DELTA_BINARY_PACKED) { + return KERNEL_MASK_DELTA_BINARY; + } else if (is_string_col(chunk)) { + return KERNEL_MASK_STRING; + } + + return KERNEL_MASK_GENERAL; +} + /** * 
@brief Functor to set value to 32 bit integer read from byte stream * @@ -370,6 +384,7 @@ __global__ void __launch_bounds__(128) bs->page.skipped_values = -1; bs->page.skipped_leaf_values = 0; bs->page.str_bytes = 0; + bs->page.kernel_mask = 0; } num_values = bs->ck.num_values; page_info = bs->ck.page_info; @@ -420,6 +435,7 @@ __global__ void __launch_bounds__(128) } bs->page.page_data = const_cast(bs->cur); bs->cur += bs->page.compressed_page_size; + bs->page.kernel_mask = get_kernel_mask(bs->page, bs->ck); } else { bs->cur = bs->end; } diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 7690dcae611..92bcd947b4b 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -86,6 +86,16 @@ enum level_type { NUM_LEVEL_TYPES }; +enum kernel_mask_bits { + KERNEL_MASK_GENERAL = (1 << 0), + KERNEL_MASK_STRING = (1 << 1), + KERNEL_MASK_DELTA_BINARY = (1 << 2) + // KERNEL_MASK_FIXED_WIDTH_DICT, + // KERNEL_MASK_STRINGS, + // KERNEL_NESTED_ + // etc +}; + /** * @brief Nesting information specifically needed by the decode and preprocessing * kernels. @@ -203,6 +213,8 @@ struct PageInfo { // level decode buffers uint8_t* lvl_decode_buf[level_type::NUM_LEVEL_TYPES]; + + uint32_t kernel_mask; }; /** @@ -454,6 +466,12 @@ void BuildStringDictionaryIndex(ColumnChunkDesc* chunks, int32_t num_chunks, rmm::cuda_stream_view stream); +/** + * @brief Get OR'd sum of page kernel masks. + */ +uint32_t GetKernelMasks(cudf::detail::hostdevice_vector& pages, + rmm::cuda_stream_view stream); + /** * @brief Compute page output size information. * diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 44d5e0d4fba..1b6302fa92a 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -56,14 +56,16 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) return cursum + _metadata->get_output_nesting_depth(chunk.src_col_schema); }); + // figure out which kernels to run + auto const kernel_mask = GetKernelMasks(pages, _stream); + // Check to see if there are any string columns present. If so, then we need to get size info // for each string page. This size info will be used to pre-allocate memory for the column, // allowing the page decoder to write string data directly to the column buffer, rather than // doing a gather operation later on. // TODO: This step is somewhat redundant if size info has already been calculated (nested schema, // chunked reader). 
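The mask-based dispatch introduced above ORs each page's kernel_mask into a single word (via GetKernelMasks / thrust::reduce) and then tests bits on the host to decide which decode kernels to launch. A self-contained sketch of the idea (enum values mirror kernel_mask_bits in the patch; the host loop stands in for the device-side reduce):

```cpp
#include <cstdint>
#include <vector>

// bit values mirror kernel_mask_bits in parquet_gpu.hpp
enum : uint32_t {
  KERNEL_MASK_GENERAL      = 1u << 0,
  KERNEL_MASK_STRING       = 1u << 1,
  KERNEL_MASK_DELTA_BINARY = 1u << 2,
};

// host-side stand-in for the device-side thrust::reduce over page masks
uint32_t or_all_masks(std::vector<uint32_t> const& page_masks)
{
  uint32_t mask = 0;
  for (auto m : page_masks) { mask |= m; }
  return mask;
}

int main()
{
  auto const mask = or_all_masks({KERNEL_MASK_GENERAL, KERNEL_MASK_DELTA_BINARY});
  bool const has_strings = (mask & KERNEL_MASK_STRING) != 0;        // false: skip string decoder
  bool const has_delta   = (mask & KERNEL_MASK_DELTA_BINARY) != 0;  // true: launch delta decoder
  return (!has_strings && has_delta) ? 0 : 1;
}
```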
- auto const has_strings = std::any_of(chunks.begin(), chunks.end(), gpu::is_string_col); - + auto const has_strings = (kernel_mask & gpu::KERNEL_MASK_STRING) != 0; std::vector col_sizes(_input_columns.size(), 0L); if (has_strings) { gpu::ComputePageStringSizes( @@ -176,10 +178,6 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) chunk_nested_data.host_to_device_async(_stream); _stream.synchronize(); - bool const has_delta_binary = std::any_of(pages.begin(), pages.end(), [](auto& page) { - return page.encoding == Encoding::DELTA_BINARY_PACKED; - }); - auto const level_type_size = _file_itm_data.level_type_size; // launch the catch-all page decoder @@ -193,7 +191,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) chunk_nested_str_data.host_to_device_async(streams.back()); gpu::DecodeStringPageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); } - if (has_delta_binary) { + if ((kernel_mask & gpu::KERNEL_MASK_DELTA_BINARY) != 0) { streams.push_back(get_stream_pool().get_stream()); gpu::DecodeDeltaBinary(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); } From ee7511d2826c3a400d0a9040fec3f7de62ee2a65 Mon Sep 17 00:00:00 2001 From: seidl Date: Tue, 27 Jun 2023 11:58:17 -0700 Subject: [PATCH 03/77] get reduce working on device --- cpp/src/io/parquet/page_data.cu | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 1edc32de158..6a0ba62e134 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -927,18 +927,12 @@ uint32_t GetKernelMasks(cudf::detail::hostdevice_vector& pages, rmm::cuda_stream_view stream) { // determine which kernels to invoke - // FIXME: when running on device I get and 'invalid device function' error -#if 0 + // TODO: if lambda doesn't also have the __host__ decorator an 'invalid device function' exception + // is sometimes thrown auto mask_iter = thrust::make_transform_iterator( - pages.d_begin(), [] __device__(auto const& p) { return p.kernel_mask; }); + pages.d_begin(), [] __host__ __device__(PageInfo const& p) { return p.kernel_mask; }); auto const kernel_mask = thrust::reduce( rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); -#else - auto mask_iter = - thrust::make_transform_iterator(pages.begin(), [](auto const& p) { return p.kernel_mask; }); - auto const kernel_mask = - thrust::reduce(mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); -#endif return kernel_mask; } From 2cafe62701828e76ca2a36394b4333cbad92d9ac Mon Sep 17 00:00:00 2001 From: seidl Date: Tue, 27 Jun 2023 12:06:01 -0700 Subject: [PATCH 04/77] use functor for transform iterator --- cpp/src/io/parquet/page_data.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 6a0ba62e134..e8c6d2dd32a 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -921,19 +921,19 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( } } +struct mask_tform { + __device__ uint32_t operator()(PageInfo const& p) { return p.kernel_mask; } +}; + } // anonymous namespace uint32_t GetKernelMasks(cudf::detail::hostdevice_vector& pages, rmm::cuda_stream_view stream) { // determine which kernels to invoke - // TODO: if lambda doesn't also have the __host__ decorator an 'invalid device function' exception - // is sometimes thrown - 
auto mask_iter = thrust::make_transform_iterator( - pages.d_begin(), [] __host__ __device__(PageInfo const& p) { return p.kernel_mask; }); - auto const kernel_mask = thrust::reduce( + auto mask_iter = thrust::make_transform_iterator(pages.d_begin(), mask_tform{}); + return thrust::reduce( rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); - return kernel_mask; } /** From 219ff0bf198c0f249fc9537c0136de28505a416e Mon Sep 17 00:00:00 2001 From: seidl Date: Tue, 27 Jun 2023 12:32:41 -0700 Subject: [PATCH 05/77] change filter functors to use kernel_mask --- cpp/src/io/parquet/page_data.cu | 13 +------------ cpp/src/io/parquet/page_decode.cuh | 18 ++++-------------- cpp/src/io/parquet/page_delta_decode.cu | 17 +++++++---------- cpp/src/io/parquet/page_string_decode.cu | 8 +++----- 4 files changed, 15 insertions(+), 41 deletions(-) diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index e8c6d2dd32a..648a286e392 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -739,17 +739,6 @@ __global__ void __launch_bounds__(preprocess_block_size) } } -// skips strings and delta encodings -struct catch_all_filter { - device_span chunks; - - __device__ inline bool operator()(PageInfo const& page) - { - return !(is_string_col(page, chunks) || page.encoding == Encoding::DELTA_BINARY_PACKED || - page.encoding == Encoding::DELTA_BYTE_ARRAY); - } -}; - /** * @brief Kernel for computing the column data stored in the pages * @@ -778,7 +767,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( [[maybe_unused]] null_count_back_copier _{s, t}; if (!setupLocalPageInfo( - s, &pages[page_idx], chunks, min_row, num_rows, catch_all_filter{chunks}, true)) { + s, &pages[page_idx], chunks, min_row, num_rows, mask_filter{KERNEL_MASK_GENERAL}, true)) { return; } diff --git a/cpp/src/io/parquet/page_decode.cuh b/cpp/src/io/parquet/page_decode.cuh index eba3698730c..8e8c15fefe5 100644 --- a/cpp/src/io/parquet/page_decode.cuh +++ b/cpp/src/io/parquet/page_decode.cuh @@ -934,21 +934,11 @@ struct all_types_filter { }; /** - * @brief Functor for setupLocalPageInfo that returns true if this is not a string column. + * @brief Functor for setupLocalPageInfo that takes a mask of allowed types. */ -struct non_string_filter { - device_span chunks; - - __device__ inline bool operator()(PageInfo const& page) { return !is_string_col(page, chunks); } -}; - -/** - * @brief Functor for setupLocalPageInfo that returns true if this is a string column. - */ -struct string_filter { - device_span chunks; - - __device__ inline bool operator()(PageInfo const& page) { return is_string_col(page, chunks); } +struct mask_filter { + int mask; + __device__ inline bool operator()(PageInfo const& page) { return (page.kernel_mask & mask) != 0; } }; /** diff --git a/cpp/src/io/parquet/page_delta_decode.cu b/cpp/src/io/parquet/page_delta_decode.cu index 3017783cd44..9d345571495 100644 --- a/cpp/src/io/parquet/page_delta_decode.cu +++ b/cpp/src/io/parquet/page_delta_decode.cu @@ -27,14 +27,6 @@ namespace cudf::io::parquet::gpu { namespace { -// functor for setupLocalPageInfo -struct delta_binary_filter { - __device__ inline bool operator()(PageInfo const& page) - { - return page.encoding == Encoding::DELTA_BINARY_PACKED; - } -}; - // Decode page data that is DELTA_BINARY_PACKED encoded. 
This encoding is // only used for int32 and int64 physical types (and appears to only be used // with V2 page headers; see https://www.mail-archive.com/dev@parquet.apache.org/msg11826.html). @@ -55,8 +47,13 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( auto* const db = &db_state; [[maybe_unused]] null_count_back_copier _{s, t}; - if (!setupLocalPageInfo( - s, &pages[page_idx], chunks, min_row, num_rows, delta_binary_filter{}, true)) { + if (!setupLocalPageInfo(s, + &pages[page_idx], + chunks, + min_row, + num_rows, + mask_filter{KERNEL_MASK_DELTA_BINARY}, + true)) { return; } diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index 20686ba7c84..96b4f8f5842 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -463,9 +463,6 @@ __global__ void __launch_bounds__(preprocess_block_size) gpuComputePageStringSiz { __shared__ __align__(16) page_state_s state_g; - // only count if it's a string column - if (not is_string_col(pages[blockIdx.x], chunks)) { return; } - page_state_s* const s = &state_g; int const page_idx = blockIdx.x; int const t = threadIdx.x; @@ -483,7 +480,8 @@ __global__ void __launch_bounds__(preprocess_block_size) gpuComputePageStringSiz rle_stream decoders[level_type::NUM_LEVEL_TYPES] = {{def_runs}, {rep_runs}}; // setup page info - if (!setupLocalPageInfo(s, pp, chunks, min_row, num_rows, string_filter{chunks}, false)) { + if (!setupLocalPageInfo( + s, pp, chunks, min_row, num_rows, mask_filter{KERNEL_MASK_STRING}, false)) { return; } @@ -580,7 +578,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodeStringPageData( [[maybe_unused]] null_count_back_copier _{s, t}; if (!setupLocalPageInfo( - s, &pages[page_idx], chunks, min_row, num_rows, string_filter{chunks}, true)) { + s, &pages[page_idx], chunks, min_row, num_rows, mask_filter{KERNEL_MASK_STRING}, true)) { return; } From 0e181a801c7437f3a6113ba200fc238b4ac37ec4 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 28 Jun 2023 15:47:01 -0700 Subject: [PATCH 06/77] pull in changes from #13622 --- cpp/CMakeLists.txt | 1 + cpp/src/io/parquet/decode_preprocess.cu | 417 +++++++++++++++++++++ cpp/src/io/parquet/delta_binary.cuh | 14 +- cpp/src/io/parquet/page_data.cu | 441 +++-------------------- cpp/src/io/parquet/page_decode.cuh | 120 +++--- cpp/src/io/parquet/page_delta_decode.cu | 36 +- cpp/src/io/parquet/page_string_decode.cu | 73 ++-- cpp/src/io/parquet/parquet_gpu.hpp | 6 + cpp/src/io/parquet/reader_impl.cpp | 14 +- cpp/src/io/parquet/rle_stream.cuh | 36 +- 10 files changed, 630 insertions(+), 528 deletions(-) create mode 100644 cpp/src/io/parquet/decode_preprocess.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index f8e500ac906..810ebb8198e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -388,6 +388,7 @@ add_library( src/io/orc/writer_impl.cu src/io/parquet/compact_protocol_reader.cpp src/io/parquet/compact_protocol_writer.cpp + src/io/parquet/decode_preprocess.cu src/io/parquet/page_data.cu src/io/parquet/chunk_dict.cu src/io/parquet/page_enc.cu diff --git a/cpp/src/io/parquet/decode_preprocess.cu b/cpp/src/io/parquet/decode_preprocess.cu new file mode 100644 index 00000000000..439422d3554 --- /dev/null +++ b/cpp/src/io/parquet/decode_preprocess.cu @@ -0,0 +1,417 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "page_decode.cuh" + +#include + +#include + +#include +#include + +namespace cudf { +namespace io { +namespace parquet { +namespace gpu { + +namespace { + +// # of threads we're decoding with +constexpr int preprocess_block_size = 512; + +// the required number of runs in shared memory we will need to provide the +// rle_stream object +constexpr int rle_run_buffer_size = rle_stream_required_run_buffer_size(preprocess_block_size); + +// the size of the rolling batch buffer +constexpr int rolling_buf_size = LEVEL_DECODE_BUF_SIZE; + +using unused_state_buf = page_state_buffers_s<0, 0, 0>; + +/** + * + * This function expects the dictionary position to be at 0 and will traverse + * the entire thing. + * + * Operates on a single warp only. Expects t < 32 + * + * @param s The local page info + * @param t Thread index + */ +__device__ size_type gpuDecodeTotalPageStringSize(page_state_s* s, int t) +{ + size_type target_pos = s->num_input_values; + size_type str_len = 0; + if (s->dict_base) { + auto const [new_target_pos, len] = + gpuDecodeDictionaryIndices(s, nullptr, target_pos, t); + target_pos = new_target_pos; + str_len = len; + } else if ((s->col.data_type & 7) == BYTE_ARRAY) { + str_len = gpuInitStringDescriptors(s, nullptr, target_pos, t); + } + if (!t) { *(int32_t volatile*)&s->dict_pos = target_pos; } + return str_len; +} + +/** + * @brief Update output column sizes for every nesting level based on a batch + * of incoming decoded definition and repetition level values. + * + * If bounds_set is true, computes skipped_values and skipped_leaf_values for the + * page to indicate where we need to skip to based on min/max row. + * + * Operates at the block level. 
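+ *
+ * (Each batch of level values is scanned with cub::BlockScan to count new rows and new leaf
+ * values, and reduced with cub::BlockReduce to accumulate the per-depth batch sizes.)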
+ * + * @param s The local page info + * @param target_value_count The target value count to process up to + * @param rep Repetition level buffer + * @param def Definition level buffer + * @param t Thread index + * @param bounds_set A boolean indicating whether or not min/max row bounds have been set + */ +template +static __device__ void gpuUpdatePageSizes(page_state_s* s, + int target_value_count, + level_t const* const rep, + level_t const* const def, + int t, + bool bounds_set) +{ + // max nesting depth of the column + int const max_depth = s->col.max_nesting_depth; + + constexpr int num_warps = preprocess_block_size / 32; + constexpr int max_batch_size = num_warps * 32; + + using block_reduce = cub::BlockReduce; + using block_scan = cub::BlockScan; + __shared__ union { + typename block_reduce::TempStorage reduce_storage; + typename block_scan::TempStorage scan_storage; + } temp_storage; + + // how many input level values we've processed in the page so far + int value_count = s->input_value_count; + // how many rows we've processed in the page so far + int row_count = s->input_row_count; + // how many leaf values we've processed in the page so far + int leaf_count = s->input_leaf_count; + // whether or not we need to continue checking for the first row + bool skipped_values_set = s->page.skipped_values >= 0; + + while (value_count < target_value_count) { + int const batch_size = min(max_batch_size, target_value_count - value_count); + + // start/end depth + int start_depth, end_depth, d; + get_nesting_bounds( + start_depth, end_depth, d, s, rep, def, value_count, value_count + batch_size, t); + + // is this thread within row bounds? in the non skip_rows/num_rows case this will always + // be true. + int in_row_bounds = 1; + + // if we are in the skip_rows/num_rows case, we need to check against these limits + if (bounds_set) { + // get absolute thread row index + int const is_new_row = start_depth == 0; + int thread_row_count, block_row_count; + block_scan(temp_storage.scan_storage) + .InclusiveSum(is_new_row, thread_row_count, block_row_count); + __syncthreads(); + + // get absolute thread leaf index + int const is_new_leaf = (d >= s->nesting_info[max_depth - 1].max_def_level); + int thread_leaf_count, block_leaf_count; + block_scan(temp_storage.scan_storage) + .InclusiveSum(is_new_leaf, thread_leaf_count, block_leaf_count); + __syncthreads(); + + // if this thread is in row bounds + int const row_index = (thread_row_count + row_count) - 1; + in_row_bounds = + (row_index >= s->row_index_lower_bound) && (row_index < (s->first_row + s->num_rows)); + + // if we have not set skipped values yet, see if we found the first in-bounds row + if (!skipped_values_set) { + int local_count, global_count; + block_scan(temp_storage.scan_storage) + .InclusiveSum(in_row_bounds, local_count, global_count); + __syncthreads(); + + // we found it + if (global_count > 0) { + // this is the thread that represents the first row. + if (local_count == 1 && in_row_bounds) { + s->page.skipped_values = value_count + t; + s->page.skipped_leaf_values = + leaf_count + (is_new_leaf ? 
thread_leaf_count - 1 : thread_leaf_count); + } + skipped_values_set = true; + } + } + + row_count += block_row_count; + leaf_count += block_leaf_count; + } + + // increment value counts across all nesting depths + for (int s_idx = 0; s_idx < max_depth; s_idx++) { + int const in_nesting_bounds = (s_idx >= start_depth && s_idx <= end_depth && in_row_bounds); + int const count = block_reduce(temp_storage.reduce_storage).Sum(in_nesting_bounds); + __syncthreads(); + if (!t) { + PageNestingInfo* pni = &s->page.nesting[s_idx]; + pni->batch_size += count; + } + } + + value_count += batch_size; + } + + // update final outputs + if (!t) { + s->input_value_count = value_count; + + // only used in the skip_rows/num_rows case + s->input_leaf_count = leaf_count; + s->input_row_count = row_count; + } +} + +/** + * @brief Kernel for computing per-page column size information for all nesting levels. + * + * This function will write out the size field for each level of nesting. + * + * @param pages List of pages + * @param chunks List of column chunks + * @param min_row Row index to start reading at + * @param num_rows Maximum number of rows to read. Pass as INT_MAX to guarantee reading all rows + * @param is_base_pass Whether or not this is the base pass. We first have to compute + * the full size information of every page before we come through in a second (trim) pass + * to determine what subset of rows in this page we should be reading + * @param compute_string_sizes Whether or not we should be computing string sizes + * (PageInfo::str_bytes) as part of the pass + */ +template +__global__ void __launch_bounds__(preprocess_block_size) + gpuComputePageSizes(PageInfo* pages, + device_span chunks, + size_t min_row, + size_t num_rows, + bool is_base_pass, + bool compute_string_sizes) +{ + __shared__ __align__(16) page_state_s state_g; + + page_state_s* const s = &state_g; + int page_idx = blockIdx.x; + int t = threadIdx.x; + PageInfo* pp = &pages[page_idx]; + + // whether or not we have repetition levels (lists) + bool has_repetition = chunks[pp->chunk_idx].max_level[level_type::REPETITION] > 0; + + // the level stream decoders + __shared__ rle_run def_runs[rle_run_buffer_size]; + __shared__ rle_run rep_runs[rle_run_buffer_size]; + rle_stream decoders[level_type::NUM_LEVEL_TYPES] = {{def_runs}, + {rep_runs}}; + + // setup page info + if (!setupLocalPageInfo(s, pp, chunks, min_row, num_rows, all_types_filter{}, false)) { return; } + + // initialize the stream decoders (requires values computed in setupLocalPageInfo) + // the size of the rolling batch buffer + int const max_batch_size = rolling_buf_size; + level_t* rep = reinterpret_cast(pp->lvl_decode_buf[level_type::REPETITION]); + level_t* def = reinterpret_cast(pp->lvl_decode_buf[level_type::DEFINITION]); + decoders[level_type::DEFINITION].init(s->col.level_bits[level_type::DEFINITION], + s->abs_lvl_start[level_type::DEFINITION], + s->abs_lvl_end[level_type::DEFINITION], + max_batch_size, + def, + s->page.num_input_values); + if (has_repetition) { + decoders[level_type::REPETITION].init(s->col.level_bits[level_type::REPETITION], + s->abs_lvl_start[level_type::REPETITION], + s->abs_lvl_end[level_type::REPETITION], + max_batch_size, + rep, + s->page.num_input_values); + } + __syncthreads(); + + if (!t) { + s->page.skipped_values = -1; + s->page.skipped_leaf_values = 0; + s->page.str_bytes = 0; + s->input_row_count = 0; + s->input_value_count = 0; + + // in the base pass, we're computing the number of rows, make sure we visit absolutely + // everything + if 
(is_base_pass) { + s->first_row = 0; + s->num_rows = INT_MAX; + s->row_index_lower_bound = -1; + } + } + + // we only need to preprocess hierarchies with repetition in them (ie, hierarchies + // containing lists anywhere within). + compute_string_sizes = + compute_string_sizes && ((s->col.data_type & 7) == BYTE_ARRAY && s->dtype_len != 4); + + // early out optimizations: + + // - if this is a flat hierarchy (no lists) and is not a string column. in this case we don't need + // to do the expensive work of traversing the level data to determine sizes. we can just compute + // it directly. + if (!has_repetition && !compute_string_sizes) { + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + if (is_base_pass) { pp->nesting[thread_depth].size = pp->num_input_values; } + pp->nesting[thread_depth].batch_size = pp->num_input_values; + } + depth += blockDim.x; + } + return; + } + + // in the trim pass, for anything with lists, we only need to fully process bounding pages (those + // at the beginning or the end of the row bounds) + if (!is_base_pass && !is_bounds_page(s, min_row, num_rows, has_repetition)) { + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + // if we are not a bounding page (as checked above) then we are either + // returning all rows/values from this page, or 0 of them + pp->nesting[thread_depth].batch_size = + (s->num_rows == 0 && !is_page_contained(s, min_row, num_rows)) + ? 0 + : pp->nesting[thread_depth].size; + } + depth += blockDim.x; + } + return; + } + + // zero sizes + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + s->page.nesting[thread_depth].batch_size = 0; + } + depth += blockDim.x; + } + __syncthreads(); + + // the core loop. decode batches of level stream data using rle_stream objects + // and pass the results to gpuUpdatePageSizes + int processed = 0; + while (processed < s->page.num_input_values) { + // TODO: it would not take much more work to make it so that we could run both of these + // decodes concurrently. there are a couple of shared variables internally that would have to + // get dealt with but that's about it. + if (has_repetition) { + decoders[level_type::REPETITION].decode_next(t); + __syncthreads(); + } + // the # of rep/def levels will always be the same size + processed += decoders[level_type::DEFINITION].decode_next(t); + __syncthreads(); + + // update page sizes + gpuUpdatePageSizes(s, processed, rep, def, t, !is_base_pass); + __syncthreads(); + } + + // retrieve total string size. 
+ // TODO: make this block-based instead of just 1 warp + if (compute_string_sizes) { + if (t < 32) { s->page.str_bytes = gpuDecodeTotalPageStringSize(s, t); } + } + + // update output results: + // - real number of rows for the whole page + // - nesting sizes for the whole page + // - skipped value information for trimmed pages + // - string bytes + if (is_base_pass) { + // nesting level 0 is the root column, so the size is also the # of rows + if (!t) { pp->num_rows = s->page.nesting[0].batch_size; } + + // store off this batch size as the "full" size + int depth = 0; + while (depth < s->page.num_output_nesting_levels) { + auto const thread_depth = depth + t; + if (thread_depth < s->page.num_output_nesting_levels) { + pp->nesting[thread_depth].size = pp->nesting[thread_depth].batch_size; + } + depth += blockDim.x; + } + } + + if (!t) { + pp->skipped_values = s->page.skipped_values; + pp->skipped_leaf_values = s->page.skipped_leaf_values; + pp->str_bytes = s->page.str_bytes; + } +} + +} // anonymous namespace + +/** + * @copydoc cudf::io::parquet::gpu::ComputePageSizes + */ +void ComputePageSizes(cudf::detail::hostdevice_vector& pages, + cudf::detail::hostdevice_vector const& chunks, + size_t min_row, + size_t num_rows, + bool compute_num_rows, + bool compute_string_sizes, + int level_type_size, + rmm::cuda_stream_view stream) +{ + dim3 dim_block(preprocess_block_size, 1); + dim3 dim_grid(pages.size(), 1); // 1 threadblock per page + + // computes: + // PageNestingInfo::size for each level of nesting, for each page. + // This computes the size for the entire page, not taking row bounds into account. + // If uses_custom_row_bounds is set to true, we have to do a second pass later that "trims" + // the starting and ending read values to account for these bounds. + if (level_type_size == 1) { + gpuComputePageSizes<<>>( + pages.device_ptr(), chunks, min_row, num_rows, compute_num_rows, compute_string_sizes); + } else { + gpuComputePageSizes<<>>( + pages.device_ptr(), chunks, min_row, num_rows, compute_num_rows, compute_string_sizes); + } +} + +} // namespace gpu +} // namespace parquet +} // namespace io +} // namespace cudf diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index b02164a377f..8e33cc33a0c 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -52,6 +52,8 @@ namespace cudf::io::parquet::gpu { using uleb128_t = uint64_t; using zigzag128_t = int64_t; +constexpr int delta_rolling_buf_size = 256; + /** * @brief Read a ULEB128 varint integer * @@ -103,7 +105,7 @@ struct delta_binary_decoder { uint8_t const* cur_mb_start; // pointer to the start of the current mini-block data uint8_t const* cur_bitwidths; // pointer to the bitwidth array in the block - uleb128_t value[non_zero_buffer_size]; // circular buffer of delta values + uleb128_t value[delta_rolling_buf_size]; // circular buffer of delta values // returns the number of values encoded in the block data. when is_decode is true, // account for the first value in the header. otherwise just count the values encoded @@ -207,7 +209,7 @@ struct delta_binary_decoder { // unpack deltas. modified from version in gpuDecodeDictionaryIndices(), but // that one only unpacks up to bitwidths of 24. simplified some since this // will always do batches of 32. also replaced branching with a loop. 
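The type changes in the hunk just below are a correctness fix, not a cleanup: `*p++` promotes to a 32-bit `int`, so shifting it by `c >= 32` (and shifting a plain `1` by large `mb_bits`) overflows once bitwidths exceed what the old 24-bit-max unpacker handled. A minimal host-side sketch of the widened unpack — illustrative only, assuming little-endian bit packing as in the mini-blocks, not the kernel code itself:

#include <cassert>
#include <cstddef>
#include <cstdint>

using zigzag128_t = int64_t;

// Extract one bit-packed value of width mb_bits (1..64) starting at absolute
// bit offset bit_ofs within the byte buffer [d, d + len).
zigzag128_t unpack_one(uint8_t const* d, size_t len, size_t bit_ofs, int mb_bits)
{
  uint8_t const* p = d + (bit_ofs >> 3);       // first byte holding our bits
  int const ofs    = static_cast<int>(bit_ofs & 7);  // bit position within that byte
  zigzag128_t v    = (*p++) >> ofs;
  // widen each byte *before* shifting; without the cast the shift happens in
  // 32-bit int and overflows for the larger bitwidths
  for (int c = 8 - ofs; c < mb_bits && p < d + len; c += 8) {
    v |= static_cast<zigzag128_t>(*p++) << c;
  }
  // the mask's 1 must be widened for the same reason
  if (mb_bits < 64) { v &= (static_cast<zigzag128_t>(1) << mb_bits) - 1; }
  return v;
}

int main()
{
  uint8_t const buf[] = {0xDE, 0xBC, 0x0A};  // 0x0ABCDE, little-endian
  assert(unpack_one(buf, sizeof(buf), 0, 20) == 0xABCDE);
  return 0;
}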
- int64_t delta = 0; + zigzag128_t delta = 0; if (lane_id + current_value_idx < value_count) { int32_t ofs = (lane_id - warp_size) * mb_bits; uint8_t const* p = d_start + (ofs >> 3); @@ -217,10 +219,10 @@ struct delta_binary_decoder { delta = (*p++) >> ofs; while (c < mb_bits && p < block_end) { - delta |= (*p++) << c; + delta |= static_cast(*p++) << c; c += 8; } - delta &= (1 << mb_bits) - 1; + delta &= (static_cast(1) << mb_bits) - 1; } } @@ -233,7 +235,9 @@ struct delta_binary_decoder { // now add first value from header or last value from previous block to get true value delta += last_value; - value[rolling_index(current_value_idx + warp_size * i + lane_id)] = delta; + int const value_idx = + rolling_index(current_value_idx + warp_size * i + lane_id); + value[value_idx] = delta; // save value from last lane in warp. this will become the 'first value' added to the // deltas calculated in the next iteration (or invocation). diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 648a286e392..512364d7c1a 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -30,6 +30,9 @@ namespace gpu { namespace { +constexpr int decode_block_size = 128; +constexpr int rolling_buf_size = decode_block_size * 2; + /** * @brief Output a string descriptor * @@ -38,8 +41,9 @@ namespace { * @param[in] src_pos Source position * @param[in] dstv Pointer to row output data (string descriptor or 32-bit hash) */ +template inline __device__ void gpuOutputString(page_state_s volatile* s, - page_state_buffers_s volatile* sb, + state_buf volatile* sb, int src_pos, void* dstv) { @@ -65,11 +69,10 @@ inline __device__ void gpuOutputString(page_state_s volatile* s, * @param[in] src_pos Source position * @param[in] dst Pointer to row output data */ -inline __device__ void gpuOutputBoolean(page_state_buffers_s volatile* sb, - int src_pos, - uint8_t* dst) +template +inline __device__ void gpuOutputBoolean(state_buf volatile* sb, int src_pos, uint8_t* dst) { - *dst = sb->dict_idx[rolling_index(src_pos)]; + *dst = sb->dict_idx[rolling_index(src_pos)]; } /** @@ -140,8 +143,9 @@ inline __device__ void gpuStoreOutput(uint2* dst, * @param[in] src_pos Source position * @param[out] dst Pointer to row output data */ +template inline __device__ void gpuOutputInt96Timestamp(page_state_s volatile* s, - page_state_buffers_s volatile* sb, + state_buf volatile* sb, int src_pos, int64_t* dst) { @@ -152,8 +156,9 @@ inline __device__ void gpuOutputInt96Timestamp(page_state_s volatile* s, if (s->dict_base) { // Dictionary - dict_pos = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0; - src8 = s->dict_base; + dict_pos = + (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0; + src8 = s->dict_base; } else { // Plain dict_pos = src_pos; @@ -213,8 +218,9 @@ inline __device__ void gpuOutputInt96Timestamp(page_state_s volatile* s, * @param[in] src_pos Source position * @param[in] dst Pointer to row output data */ +template inline __device__ void gpuOutputInt64Timestamp(page_state_s volatile* s, - page_state_buffers_s volatile* sb, + state_buf volatile* sb, int src_pos, int64_t* dst) { @@ -224,8 +230,9 @@ inline __device__ void gpuOutputInt64Timestamp(page_state_s volatile* s, if (s->dict_base) { // Dictionary - dict_pos = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0; - src8 = s->dict_base; + dict_pos = + (s->dict_bits > 0) ? 
sb->dict_idx[rolling_index(src_pos)] : 0; + src8 = s->dict_base; } else { // Plain dict_pos = src_pos; @@ -294,16 +301,18 @@ __device__ void gpuOutputByteArrayAsInt(char const* ptr, int32_t len, T* dst) * @param[in] src_pos Source position * @param[in] dst Pointer to row output data */ -template +template __device__ void gpuOutputFixedLenByteArrayAsInt(page_state_s volatile* s, - page_state_buffers_s volatile* sb, + state_buf volatile* sb, int src_pos, T* dst) { uint32_t const dtype_len_in = s->dtype_len_in; uint8_t const* data = s->dict_base ? s->dict_base : s->data_start; uint32_t const pos = - (s->dict_base ? ((s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0) : src_pos) * + (s->dict_base + ? ((s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0) + : src_pos) * dtype_len_in; uint32_t const dict_size = s->dict_size; @@ -329,9 +338,9 @@ __device__ void gpuOutputFixedLenByteArrayAsInt(page_state_s volatile* s, * @param[in] src_pos Source position * @param[in] dst Pointer to row output data */ -template +template inline __device__ void gpuOutputFast(page_state_s volatile* s, - page_state_buffers_s volatile* sb, + state_buf volatile* sb, int src_pos, T* dst) { @@ -340,8 +349,9 @@ inline __device__ void gpuOutputFast(page_state_s volatile* s, if (s->dict_base) { // Dictionary - dict_pos = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0; - dict = s->dict_base; + dict_pos = + (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0; + dict = s->dict_base; } else { // Plain dict_pos = src_pos; @@ -360,16 +370,18 @@ inline __device__ void gpuOutputFast(page_state_s volatile* s, * @param[in] dst8 Pointer to row output data * @param[in] len Length of element */ +template static __device__ void gpuOutputGeneric( - page_state_s volatile* s, page_state_buffers_s volatile* sb, int src_pos, uint8_t* dst8, int len) + page_state_s volatile* s, state_buf volatile* sb, int src_pos, uint8_t* dst8, int len) { uint8_t const* dict; uint32_t dict_pos, dict_size = s->dict_size; if (s->dict_base) { // Dictionary - dict_pos = (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0; - dict = s->dict_base; + dict_pos = + (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] : 0; + dict = s->dict_base; } else { // Plain dict_pos = src_pos; @@ -404,341 +416,6 @@ static __device__ void gpuOutputGeneric( } } -/** - * - * This function expects the dictionary position to be at 0 and will traverse - * the entire thing. - * - * Operates on a single warp only. Expects t < 32 - * - * @param s The local page info - * @param t Thread index - */ -__device__ size_type gpuDecodeTotalPageStringSize(page_state_s* s, int t) -{ - size_type target_pos = s->num_input_values; - size_type str_len = 0; - if (s->dict_base) { - auto const [new_target_pos, len] = gpuDecodeDictionaryIndices(s, nullptr, target_pos, t); - target_pos = new_target_pos; - str_len = len; - } else if ((s->col.data_type & 7) == BYTE_ARRAY) { - str_len = gpuInitStringDescriptors(s, nullptr, target_pos, t); - } - if (!t) { *(int32_t volatile*)&s->dict_pos = target_pos; } - return str_len; -} - -/** - * @brief Update output column sizes for every nesting level based on a batch - * of incoming decoded definition and repetition level values. - * - * If bounds_set is true, computes skipped_values and skipped_leaf_values for the - * page to indicate where we need to skip to based on min/max row. - * - * Operates at the block level. 
- * - * @param s The local page info - * @param target_value_count The target value count to process up to - * @param rep Repetition level buffer - * @param def Definition level buffer - * @param t Thread index - * @param bounds_set A boolean indicating whether or not min/max row bounds have been set - */ -template -static __device__ void gpuUpdatePageSizes(page_state_s* s, - int target_value_count, - level_t const* const rep, - level_t const* const def, - int t, - bool bounds_set) -{ - // max nesting depth of the column - int const max_depth = s->col.max_nesting_depth; - - constexpr int num_warps = preprocess_block_size / 32; - constexpr int max_batch_size = num_warps * 32; - - using block_reduce = cub::BlockReduce; - using block_scan = cub::BlockScan; - __shared__ union { - typename block_reduce::TempStorage reduce_storage; - typename block_scan::TempStorage scan_storage; - } temp_storage; - - // how many input level values we've processed in the page so far - int value_count = s->input_value_count; - // how many rows we've processed in the page so far - int row_count = s->input_row_count; - // how many leaf values we've processed in the page so far - int leaf_count = s->input_leaf_count; - // whether or not we need to continue checking for the first row - bool skipped_values_set = s->page.skipped_values >= 0; - - while (value_count < target_value_count) { - int const batch_size = min(max_batch_size, target_value_count - value_count); - - // start/end depth - int start_depth, end_depth, d; - get_nesting_bounds( - start_depth, end_depth, d, s, rep, def, value_count, value_count + batch_size, t); - - // is this thread within row bounds? in the non skip_rows/num_rows case this will always - // be true. - int in_row_bounds = 1; - - // if we are in the skip_rows/num_rows case, we need to check against these limits - if (bounds_set) { - // get absolute thread row index - int const is_new_row = start_depth == 0; - int thread_row_count, block_row_count; - block_scan(temp_storage.scan_storage) - .InclusiveSum(is_new_row, thread_row_count, block_row_count); - __syncthreads(); - - // get absolute thread leaf index - int const is_new_leaf = (d >= s->nesting_info[max_depth - 1].max_def_level); - int thread_leaf_count, block_leaf_count; - block_scan(temp_storage.scan_storage) - .InclusiveSum(is_new_leaf, thread_leaf_count, block_leaf_count); - __syncthreads(); - - // if this thread is in row bounds - int const row_index = (thread_row_count + row_count) - 1; - in_row_bounds = - (row_index >= s->row_index_lower_bound) && (row_index < (s->first_row + s->num_rows)); - - // if we have not set skipped values yet, see if we found the first in-bounds row - if (!skipped_values_set) { - int local_count, global_count; - block_scan(temp_storage.scan_storage) - .InclusiveSum(in_row_bounds, local_count, global_count); - __syncthreads(); - - // we found it - if (global_count > 0) { - // this is the thread that represents the first row. - if (local_count == 1 && in_row_bounds) { - s->page.skipped_values = value_count + t; - s->page.skipped_leaf_values = - leaf_count + (is_new_leaf ? 
thread_leaf_count - 1 : thread_leaf_count); - } - skipped_values_set = true; - } - } - - row_count += block_row_count; - leaf_count += block_leaf_count; - } - - // increment value counts across all nesting depths - for (int s_idx = 0; s_idx < max_depth; s_idx++) { - int const in_nesting_bounds = (s_idx >= start_depth && s_idx <= end_depth && in_row_bounds); - int const count = block_reduce(temp_storage.reduce_storage).Sum(in_nesting_bounds); - __syncthreads(); - if (!t) { - PageNestingInfo* pni = &s->page.nesting[s_idx]; - pni->batch_size += count; - } - } - - value_count += batch_size; - } - - // update final outputs - if (!t) { - s->input_value_count = value_count; - - // only used in the skip_rows/num_rows case - s->input_leaf_count = leaf_count; - s->input_row_count = row_count; - } -} - -/** - * @brief Kernel for computing per-page column size information for all nesting levels. - * - * This function will write out the size field for each level of nesting. - * - * @param pages List of pages - * @param chunks List of column chunks - * @param min_row Row index to start reading at - * @param num_rows Maximum number of rows to read. Pass as INT_MAX to guarantee reading all rows - * @param is_base_pass Whether or not this is the base pass. We first have to compute - * the full size information of every page before we come through in a second (trim) pass - * to determine what subset of rows in this page we should be reading - * @param compute_string_sizes Whether or not we should be computing string sizes - * (PageInfo::str_bytes) as part of the pass - */ -template -__global__ void __launch_bounds__(preprocess_block_size) - gpuComputePageSizes(PageInfo* pages, - device_span chunks, - size_t min_row, - size_t num_rows, - bool is_base_pass, - bool compute_string_sizes) -{ - __shared__ __align__(16) page_state_s state_g; - - page_state_s* const s = &state_g; - int page_idx = blockIdx.x; - int t = threadIdx.x; - PageInfo* pp = &pages[page_idx]; - - // whether or not we have repetition levels (lists) - bool has_repetition = chunks[pp->chunk_idx].max_level[level_type::REPETITION] > 0; - - // the level stream decoders - __shared__ rle_run def_runs[run_buffer_size]; - __shared__ rle_run rep_runs[run_buffer_size]; - rle_stream decoders[level_type::NUM_LEVEL_TYPES] = {{def_runs}, {rep_runs}}; - - // setup page info - if (!setupLocalPageInfo(s, pp, chunks, min_row, num_rows, all_types_filter{}, false)) { return; } - - // initialize the stream decoders (requires values computed in setupLocalPageInfo) - int const max_batch_size = lvl_buf_size; - level_t* rep = reinterpret_cast(pp->lvl_decode_buf[level_type::REPETITION]); - level_t* def = reinterpret_cast(pp->lvl_decode_buf[level_type::DEFINITION]); - decoders[level_type::DEFINITION].init(s->col.level_bits[level_type::DEFINITION], - s->abs_lvl_start[level_type::DEFINITION], - s->abs_lvl_end[level_type::DEFINITION], - max_batch_size, - def, - s->page.num_input_values); - if (has_repetition) { - decoders[level_type::REPETITION].init(s->col.level_bits[level_type::REPETITION], - s->abs_lvl_start[level_type::REPETITION], - s->abs_lvl_end[level_type::REPETITION], - max_batch_size, - rep, - s->page.num_input_values); - } - __syncthreads(); - - if (!t) { - s->page.skipped_values = -1; - s->page.skipped_leaf_values = 0; - s->page.str_bytes = 0; - s->input_row_count = 0; - s->input_value_count = 0; - - // in the base pass, we're computing the number of rows, make sure we visit absolutely - // everything - if (is_base_pass) { - s->first_row = 0; - s->num_rows = 
INT_MAX; - s->row_index_lower_bound = -1; - } - } - - // we only need to preprocess hierarchies with repetition in them (ie, hierarchies - // containing lists anywhere within). - compute_string_sizes = - compute_string_sizes && ((s->col.data_type & 7) == BYTE_ARRAY && s->dtype_len != 4); - - // early out optimizations: - - // - if this is a flat hierarchy (no lists) and is not a string column. in this case we don't need - // to do the expensive work of traversing the level data to determine sizes. we can just compute - // it directly. - if (!has_repetition && !compute_string_sizes) { - int depth = 0; - while (depth < s->page.num_output_nesting_levels) { - auto const thread_depth = depth + t; - if (thread_depth < s->page.num_output_nesting_levels) { - if (is_base_pass) { pp->nesting[thread_depth].size = pp->num_input_values; } - pp->nesting[thread_depth].batch_size = pp->num_input_values; - } - depth += blockDim.x; - } - return; - } - - // in the trim pass, for anything with lists, we only need to fully process bounding pages (those - // at the beginning or the end of the row bounds) - if (!is_base_pass && !is_bounds_page(s, min_row, num_rows, has_repetition)) { - int depth = 0; - while (depth < s->page.num_output_nesting_levels) { - auto const thread_depth = depth + t; - if (thread_depth < s->page.num_output_nesting_levels) { - // if we are not a bounding page (as checked above) then we are either - // returning all rows/values from this page, or 0 of them - pp->nesting[thread_depth].batch_size = - (s->num_rows == 0 && !is_page_contained(s, min_row, num_rows)) - ? 0 - : pp->nesting[thread_depth].size; - } - depth += blockDim.x; - } - return; - } - - // zero sizes - int depth = 0; - while (depth < s->page.num_output_nesting_levels) { - auto const thread_depth = depth + t; - if (thread_depth < s->page.num_output_nesting_levels) { - s->page.nesting[thread_depth].batch_size = 0; - } - depth += blockDim.x; - } - __syncthreads(); - - // the core loop. decode batches of level stream data using rle_stream objects - // and pass the results to gpuUpdatePageSizes - int processed = 0; - while (processed < s->page.num_input_values) { - // TODO: it would not take much more work to make it so that we could run both of these - // decodes concurrently. there are a couple of shared variables internally that would have to - // get dealt with but that's about it. - if (has_repetition) { - decoders[level_type::REPETITION].decode_next(t); - __syncthreads(); - } - // the # of rep/def levels will always be the same size - processed += decoders[level_type::DEFINITION].decode_next(t); - __syncthreads(); - - // update page sizes - gpuUpdatePageSizes(s, processed, rep, def, t, !is_base_pass); - __syncthreads(); - } - - // retrieve total string size. 
- // TODO: make this block-based instead of just 1 warp - if (compute_string_sizes) { - if (t < 32) { s->page.str_bytes = gpuDecodeTotalPageStringSize(s, t); } - } - - // update output results: - // - real number of rows for the whole page - // - nesting sizes for the whole page - // - skipped value information for trimmed pages - // - string bytes - if (is_base_pass) { - // nesting level 0 is the root column, so the size is also the # of rows - if (!t) { pp->num_rows = s->page.nesting[0].batch_size; } - - // store off this batch size as the "full" size - int depth = 0; - while (depth < s->page.num_output_nesting_levels) { - auto const thread_depth = depth + t; - if (thread_depth < s->page.num_output_nesting_levels) { - pp->nesting[thread_depth].size = pp->nesting[thread_depth].batch_size; - } - depth += blockDim.x; - } - } - - if (!t) { - pp->skipped_values = s->page.skipped_values; - pp->skipped_leaf_values = s->page.skipped_leaf_values; - pp->str_bytes = s->page.str_bytes; - } -} - /** * @brief Kernel for computing the column data stored in the pages * @@ -757,12 +434,14 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) { __shared__ __align__(16) page_state_s state_g; - __shared__ __align__(16) page_state_buffers_s state_buffers; + __shared__ __align__(16) + page_state_buffers_s + state_buffers; - page_state_s* const s = &state_g; - page_state_buffers_s* const sb = &state_buffers; - int page_idx = blockIdx.x; - int t = threadIdx.x; + page_state_s* const s = &state_g; + auto* const sb = &state_buffers; + int page_idx = blockIdx.x; + int t = threadIdx.x; int out_thread0; [[maybe_unused]] null_count_back_copier _{s, t}; @@ -782,8 +461,8 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( PageNestingDecodeInfo* nesting_info_base = s->nesting_info; - __shared__ level_t rep[non_zero_buffer_size]; // circular buffer of repetition level values - __shared__ level_t def[non_zero_buffer_size]; // circular buffer of definition level values + __shared__ level_t rep[rolling_buf_size]; // circular buffer of repetition level values + __shared__ level_t def[rolling_buf_size]; // circular buffer of definition level values // skipped_leaf_values will always be 0 for flat hierarchies. uint32_t skipped_leaf_values = s->page.skipped_leaf_values; @@ -824,7 +503,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( src_pos += t - out_thread0; // the position in the output column/buffer - int dst_pos = sb->nz_idx[rolling_index(src_pos)]; + int dst_pos = sb->nz_idx[rolling_index(src_pos)]; // for the flat hierarchy case we will be reading from the beginning of the value stream, // regardless of the value of first_row. so adjust our destination offset accordingly. @@ -925,36 +604,6 @@ uint32_t GetKernelMasks(cudf::detail::hostdevice_vector& pages, rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or{}); } -/** - * @copydoc cudf::io::parquet::gpu::ComputePageSizes - */ -void ComputePageSizes(cudf::detail::hostdevice_vector& pages, - cudf::detail::hostdevice_vector const& chunks, - size_t min_row, - size_t num_rows, - bool compute_num_rows, - bool compute_string_sizes, - int level_type_size, - rmm::cuda_stream_view stream) -{ - dim3 dim_block(preprocess_block_size, 1); - dim3 dim_grid(pages.size(), 1); // 1 threadblock per page - - // computes: - // PageNestingInfo::size for each level of nesting, for each page. 
- // This computes the size for the entire page, not taking row bounds into account. - // If uses_custom_row_bounds is set to true, we have to do a second pass later that "trims" - // the starting and ending read values to account for these bounds. - if (level_type_size == 1) { - gpuComputePageSizes<<>>( - pages.device_ptr(), chunks, min_row, num_rows, compute_num_rows, compute_string_sizes); - } else { - gpuComputePageSizes - <<>>( - pages.device_ptr(), chunks, min_row, num_rows, compute_num_rows, compute_string_sizes); - } -} - /** * @copydoc cudf::io::parquet::gpu::DecodePageData */ @@ -971,10 +620,10 @@ void __host__ DecodePageData(cudf::detail::hostdevice_vector& pages, dim3 dim_grid(pages.size(), 1); // 1 threadblock per page if (level_type_size == 1) { - gpuDecodePageData + gpuDecodePageData <<>>(pages.device_ptr(), chunks, min_row, num_rows); } else { - gpuDecodePageData + gpuDecodePageData <<>>(pages.device_ptr(), chunks, min_row, num_rows); } } diff --git a/cpp/src/io/parquet/page_decode.cuh b/cpp/src/io/parquet/page_decode.cuh index 8e8c15fefe5..91e19bd380d 100644 --- a/cpp/src/io/parquet/page_decode.cuh +++ b/cpp/src/io/parquet/page_decode.cuh @@ -25,17 +25,6 @@ namespace cudf::io::parquet::gpu { -constexpr int preprocess_block_size = num_rle_stream_decode_threads; // 512 -constexpr int decode_block_size = 128; -constexpr int non_zero_buffer_size = decode_block_size * 2; - -constexpr int rolling_index(int index) { return index & (non_zero_buffer_size - 1); } -template -constexpr int rolling_lvl_index(int index) -{ - return index % lvl_buf_size; -} - struct page_state_s { uint8_t const* data_start; uint8_t const* data_end; @@ -83,10 +72,15 @@ struct page_state_s { // buffers only used in the decode kernel. separated from page_state_s to keep // shared memory usage in other kernels (eg, gpuComputePageSizes) down. +template struct page_state_buffers_s { - uint32_t nz_idx[non_zero_buffer_size]; // circular buffer of non-null value positions - uint32_t dict_idx[non_zero_buffer_size]; // Dictionary index, boolean, or string offset values - uint32_t str_len[non_zero_buffer_size]; // String length for plain encoding of strings + static constexpr int nz_buf_size = _nz_buf_size; + static constexpr int dict_buf_size = _dict_buf_size; + static constexpr int str_buf_size = _str_buf_size; + + uint32_t nz_idx[nz_buf_size]; // circular buffer of non-null value positions + uint32_t dict_idx[dict_buf_size]; // Dictionary index, boolean, or string offset values + uint32_t str_len[str_buf_size]; // String length for plain encoding of strings }; // Copies null counts back to `nesting_decode` at the end of scope @@ -175,8 +169,10 @@ inline __device__ bool is_page_contained(page_state_s* const s, size_t start_row * * @return A pair containing a pointer to the string and its length */ -inline __device__ cuda::std::pair gpuGetStringData( - page_state_s volatile* s, page_state_buffers_s volatile* sb, int src_pos) +template +inline __device__ cuda::std::pair gpuGetStringData(page_state_s volatile* s, + state_buf volatile* sb, + int src_pos) { char const* ptr = nullptr; size_t len = 0; @@ -184,7 +180,9 @@ inline __device__ cuda::std::pair gpuGetStringData( if (s->dict_base) { // String dictionary uint32_t dict_pos = - (s->dict_bits > 0) ? sb->dict_idx[rolling_index(src_pos)] * sizeof(string_index_pair) : 0; + (s->dict_bits > 0) + ? 
sb->dict_idx[rolling_index(src_pos)] * sizeof(string_index_pair) + : 0; if (dict_pos < (uint32_t)s->dict_size) { auto const* src = reinterpret_cast(s->dict_base + dict_pos); ptr = src->first; @@ -192,10 +190,10 @@ inline __device__ cuda::std::pair gpuGetStringData( } } else { // Plain encoding - uint32_t dict_pos = sb->dict_idx[rolling_index(src_pos)]; + uint32_t dict_pos = sb->dict_idx[rolling_index(src_pos)]; if (dict_pos <= (uint32_t)s->dict_size) { ptr = reinterpret_cast(s->data_start + dict_pos); - len = sb->str_len[rolling_index(src_pos)]; + len = sb->str_len[rolling_index(src_pos)]; } } @@ -216,12 +214,9 @@ inline __device__ cuda::std::pair gpuGetStringData( * decodes strings beyond target_pos, the total length of strings returned will include these * additional values. */ -template +template __device__ cuda::std::pair gpuDecodeDictionaryIndices( - page_state_s volatile* s, - [[maybe_unused]] page_state_buffers_s volatile* sb, - int target_pos, - int t) + page_state_s volatile* s, [[maybe_unused]] state_buf volatile* sb, int target_pos, int t) { uint8_t const* end = s->data_end; int dict_bits = s->dict_bits; @@ -297,7 +292,9 @@ __device__ cuda::std::pair gpuDecodeDictionaryIndices( } // if we're not computing sizes, store off the dictionary index - if constexpr (!sizes_only) { sb->dict_idx[rolling_index(pos + t)] = dict_idx; } + if constexpr (!sizes_only) { + sb->dict_idx[rolling_index(pos + t)] = dict_idx; + } } // if we're computing sizes, add the length(s) @@ -333,8 +330,9 @@ __device__ cuda::std::pair gpuDecodeDictionaryIndices( * * @return The new output position */ +template inline __device__ int gpuDecodeRleBooleans(page_state_s volatile* s, - page_state_buffers_s volatile* sb, + state_buf volatile* sb, int target_pos, int t) { @@ -383,7 +381,7 @@ inline __device__ int gpuDecodeRleBooleans(page_state_s volatile* s, } else { dict_idx = s->dict_val; } - sb->dict_idx[rolling_index(pos + t)] = dict_idx; + sb->dict_idx[rolling_index(pos + t)] = dict_idx; } pos += batch_len; } @@ -401,9 +399,9 @@ inline __device__ int gpuDecodeRleBooleans(page_state_s volatile* s, * * @return Total length of strings processed */ -template +template __device__ size_type gpuInitStringDescriptors(page_state_s volatile* s, - [[maybe_unused]] page_state_buffers_s volatile* sb, + [[maybe_unused]] state_buf volatile* sb, int target_pos, int t) { @@ -426,8 +424,8 @@ __device__ size_type gpuInitStringDescriptors(page_state_s volatile* s, len = 0; } if constexpr (!sizes_only) { - sb->dict_idx[rolling_index(pos)] = k; - sb->str_len[rolling_index(pos)] = len; + sb->dict_idx[rolling_index(pos)] = k; + sb->str_len[rolling_index(pos)] = len; } k += len; total_len += len; @@ -448,7 +446,7 @@ __device__ size_type gpuInitStringDescriptors(page_state_s volatile* s, * @param[in] t Warp0 thread ID (0..31) * @param[in] lvl The level type we are decoding - DEFINITION or REPETITION */ -template +template __device__ void gpuDecodeStream( level_t* output, page_state_s* s, int32_t target_count, int t, level_type lvl) { @@ -515,8 +513,8 @@ __device__ void gpuDecodeStream( level_run -= batch_len * 2; } if (t < batch_len) { - int idx = value_count + t; - output[rolling_index(idx)] = level_val; + int idx = value_count + t; + output[rolling_index(idx)] = level_val; } batch_coded_count += batch_len; value_count += batch_len; @@ -540,21 +538,22 @@ __device__ void gpuDecodeStream( * @param[in] valid_mask The validity mask to be stored * @param[in] value_count # of bits in the validity mask */ -inline __device__ void 
store_validity(PageNestingDecodeInfo* nesting_info, +inline __device__ void store_validity(int valid_map_offset, + bitmask_type* valid_map, uint32_t valid_mask, int32_t value_count) { - int word_offset = nesting_info->valid_map_offset / 32; - int bit_offset = nesting_info->valid_map_offset % 32; + int word_offset = valid_map_offset / 32; + int bit_offset = valid_map_offset % 32; // if we fit entirely in the output word if (bit_offset + value_count <= 32) { auto relevant_mask = static_cast((static_cast(1) << value_count) - 1); if (relevant_mask == ~0) { - nesting_info->valid_map[word_offset] = valid_mask; + valid_map[word_offset] = valid_mask; } else { - atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset)); - atomicOr(nesting_info->valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset); + atomicAnd(valid_map + word_offset, ~(relevant_mask << bit_offset)); + atomicOr(valid_map + word_offset, (valid_mask & relevant_mask) << bit_offset); } } // we're going to spill over into the next word. @@ -568,17 +567,15 @@ inline __device__ void store_validity(PageNestingDecodeInfo* nesting_info, // first word. strip bits_left bits off the beginning and store that uint32_t relevant_mask = ((1 << bits_left) - 1); uint32_t mask_word0 = valid_mask & relevant_mask; - atomicAnd(nesting_info->valid_map + word_offset, ~(relevant_mask << bit_offset)); - atomicOr(nesting_info->valid_map + word_offset, mask_word0 << bit_offset); + atomicAnd(valid_map + word_offset, ~(relevant_mask << bit_offset)); + atomicOr(valid_map + word_offset, mask_word0 << bit_offset); // second word. strip the remainder of the bits off the end and store that relevant_mask = ((1 << (value_count - bits_left)) - 1); uint32_t mask_word1 = valid_mask & (relevant_mask << bits_left); - atomicAnd(nesting_info->valid_map + word_offset + 1, ~(relevant_mask)); - atomicOr(nesting_info->valid_map + word_offset + 1, mask_word1 >> bits_left); + atomicAnd(valid_map + word_offset + 1, ~(relevant_mask)); + atomicOr(valid_map + word_offset + 1, mask_word1 >> bits_left); } - - nesting_info->valid_map_offset += value_count; } /** @@ -596,7 +593,7 @@ inline __device__ void store_validity(PageNestingDecodeInfo* nesting_info, * @param[in] target_input_value_count The desired # of input level values we want to process * @param[in] t Thread index */ -template +template inline __device__ void get_nesting_bounds(int& start_depth, int& end_depth, int& d, @@ -611,7 +608,7 @@ inline __device__ void get_nesting_bounds(int& start_depth, end_depth = -1; d = -1; if (input_value_count + t < target_input_value_count) { - int const index = rolling_lvl_index(input_value_count + t); + int const index = rolling_index(input_value_count + t); d = static_cast(def[index]); // if we have repetition (there are list columns involved) we have to // bound what nesting levels we apply values to @@ -640,10 +637,10 @@ inline __device__ void get_nesting_bounds(int& start_depth, * @param[in] def Definition level buffer * @param[in] t Thread index */ -template +template __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count, page_state_s* s, - page_state_buffers_s* sb, + state_buf* sb, level_t const* const rep, level_t const* const def, int t) @@ -663,7 +660,7 @@ __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value // determine the nesting bounds for this thread (the range of nesting depths we // will generate new value indices and validity bits for) int start_depth, end_depth, d; - 
get_nesting_bounds( + get_nesting_bounds( start_depth, end_depth, d, s, rep, def, input_value_count, target_input_value_count, t); // 4 interesting things to track: @@ -730,7 +727,7 @@ __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value int const src_pos = nesting_info->valid_count + thread_valid_count; int const dst_pos = nesting_info->value_count + thread_value_count; // nz_idx is a mapping of src buffer indices to destination buffer indices - sb->nz_idx[rolling_index(src_pos)] = dst_pos; + sb->nz_idx[rolling_index(src_pos)] = dst_pos; } // compute warp and thread value counts for the -next- nesting level. we need to @@ -775,8 +772,11 @@ __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value if (!t) { if (nesting_info->valid_map != nullptr && warp_valid_mask_bit_count > 0) { uint32_t const warp_output_valid_mask = warp_valid_mask >> first_thread_in_write_range; - store_validity(nesting_info, warp_output_valid_mask, warp_valid_mask_bit_count); - + store_validity(nesting_info->valid_map_offset, + nesting_info->valid_map, + warp_output_valid_mask, + warp_valid_mask_bit_count); + nesting_info->valid_map_offset += warp_valid_mask_bit_count; nesting_info->null_count += warp_valid_mask_bit_count - __popc(warp_output_valid_mask); } nesting_info->valid_count += warp_valid_count; @@ -819,9 +819,9 @@ __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value * @param[in] def Definition level buffer * @param[in] t Thread index */ -template +template __device__ void gpuDecodeLevels(page_state_s* s, - page_state_buffers_s* sb, + state_buf* sb, int32_t target_leaf_count, level_t* const rep, level_t* const def, @@ -833,8 +833,10 @@ __device__ void gpuDecodeLevels(page_state_s* s, int cur_leaf_count = target_leaf_count; while (!s->error && s->nz_count < target_leaf_count && s->input_value_count < s->num_input_values) { - if (has_repetition) { gpuDecodeStream(rep, s, cur_leaf_count, t, level_type::REPETITION); } - gpuDecodeStream(def, s, cur_leaf_count, t, level_type::DEFINITION); + if (has_repetition) { + gpuDecodeStream(rep, s, cur_leaf_count, t, level_type::REPETITION); + } + gpuDecodeStream(def, s, cur_leaf_count, t, level_type::DEFINITION); __syncwarp(); // because the rep and def streams are encoded separately, we cannot request an exact @@ -845,7 +847,7 @@ __device__ void gpuDecodeLevels(page_state_s* s, : s->lvl_count[level_type::DEFINITION]; // process what we got back - gpuUpdateValidityOffsetsAndRowIndices( + gpuUpdateValidityOffsetsAndRowIndices( actual_leaf_count, s, sb, rep, def, t); cur_leaf_count = actual_leaf_count + batch_size; __syncwarp(); diff --git a/cpp/src/io/parquet/page_delta_decode.cu b/cpp/src/io/parquet/page_delta_decode.cu index 9d345571495..e8fafe4445f 100644 --- a/cpp/src/io/parquet/page_delta_decode.cu +++ b/cpp/src/io/parquet/page_delta_decode.cu @@ -31,20 +31,20 @@ namespace { // only used for int32 and int64 physical types (and appears to only be used // with V2 page headers; see https://www.mail-archive.com/dev@parquet.apache.org/msg11826.html). // this kernel only needs 96 threads (3 warps)(for now). 
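For orientation before the kernel body below: the 96 threads form three cooperating warps that hand work off through the shared page state. A skeletal, compile-only rendering of the role split — the function and its bodies are hypothetical placeholders; the real kernel coordinates batches via target_pos and s->dict_pos rather than calling anything like this:

// Shape of the per-warp division of labor in the delta kernel (sketch only).
void delta_decode_roles(int t)
{
  if (t < 32) {
    // warp 0: decode definition/repetition levels and produce the
    // nz_idx mapping from input positions to output rows
  } else if (t < 64) {
    // warp 1: unpack the current delta mini-block into db->value[]
  } else {
    // remaining warp (t < 96): copy finished values from db->value[]
    // to the output column, adjusting for first_row on flat schemas
  }
}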
-template +template __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) { __shared__ __align__(16) delta_binary_decoder db_state; - __shared__ __align__(16) page_state_buffers_s state_buffers; __shared__ __align__(16) page_state_s state_g; - - page_state_s* const s = &state_g; - page_state_buffers_s* const sb = &state_buffers; - int const page_idx = blockIdx.x; - int const t = threadIdx.x; - int const lane_id = t & 0x1f; - auto* const db = &db_state; + __shared__ __align__(16) page_state_buffers_s state_buffers; + + page_state_s* const s = &state_g; + auto* const sb = &state_buffers; + int const page_idx = blockIdx.x; + int const t = threadIdx.x; + int const lane_id = t & 0x1f; + auto* const db = &db_state; [[maybe_unused]] null_count_back_copier _{s, t}; if (!setupLocalPageInfo(s, @@ -62,8 +62,8 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( // copying logic from gpuDecodePageData. PageNestingDecodeInfo const* nesting_info_base = s->nesting_info; - __shared__ level_t rep[non_zero_buffer_size]; // circular buffer of repetition level values - __shared__ level_t def[non_zero_buffer_size]; // circular buffer of definition level values + __shared__ level_t rep[delta_rolling_buf_size]; // circular buffer of repetition level values + __shared__ level_t def[delta_rolling_buf_size]; // circular buffer of definition level values // skipped_leaf_values will always be 0 for flat hierarchies. uint32_t const skipped_leaf_values = s->page.skipped_leaf_values; @@ -98,7 +98,7 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( // - update validity vectors // - updates offsets (for nested columns) // - produces non-NULL value indices in s->nz_idx for subsequent decoding - gpuDecodeLevels(s, sb, target_pos, rep, def, t); + gpuDecodeLevels(s, sb, target_pos, rep, def, t); } else if (t < 64) { // warp 1 db->decode_batch(); @@ -111,7 +111,7 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( // process the mini-block in batches of 32 for (uint32_t sp = src_pos + lane_id; sp < src_pos + batch_size; sp += 32) { // the position in the output column/buffer - int32_t dst_pos = sb->nz_idx[rolling_index(sp)]; + int32_t dst_pos = sb->nz_idx[rolling_index(sp)]; // handle skip_rows here. flat hierarchies can just skip up to first_row. 
if (!has_repetition) { dst_pos -= s->first_row; } @@ -120,9 +120,11 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( if (dst_pos >= 0 && sp < target_pos) { void* const dst = nesting_info_base[leaf_level_index].data_out + dst_pos * s->dtype_len; if (s->dtype_len == 8) { - *static_cast(dst) = db->value[rolling_index(sp + skipped_leaf_values)]; + *static_cast(dst) = + db->value[rolling_index(sp + skipped_leaf_values)]; } else if (s->dtype_len == 4) { - *static_cast(dst) = db->value[rolling_index(sp + skipped_leaf_values)]; + *static_cast(dst) = + db->value[rolling_index(sp + skipped_leaf_values)]; } } } @@ -151,10 +153,10 @@ void __host__ DecodeDeltaBinary(cudf::detail::hostdevice_vector& pages dim3 dim_grid(pages.size(), 1); // 1 threadblock per page if (level_type_size == 1) { - gpuDecodeDeltaBinary + gpuDecodeDeltaBinary <<>>(pages.device_ptr(), chunks, min_row, num_rows); } else { - gpuDecodeDeltaBinary + gpuDecodeDeltaBinary <<>>(pages.device_ptr(), chunks, min_row, num_rows); } } diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index 96b4f8f5842..8ab52f32226 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -27,6 +27,11 @@ namespace gpu { namespace { +constexpr int preprocess_block_size = 512; +constexpr int decode_block_size = 128; +constexpr int rolling_buf_size = decode_block_size * 2; +constexpr int preproc_buf_size = LEVEL_DECODE_BUF_SIZE; + /** * @brief Compute the start and end page value bounds for this page * @@ -40,16 +45,16 @@ namespace { * @param has_repetition True if the schema is nested * @param decoders Definition and repetition level decoders * @return pair containing start and end value indexes - * @tparam lvl_buf_size Size of the buffer used when decoding repetition and definition levels + * @tparam rle_buf_size Size of the buffer used when decoding repetition and definition levels * @tparam level_t Type used to store decoded repetition and definition levels */ -template +template __device__ thrust::pair page_bounds(page_state_s* const s, size_t min_row, size_t num_rows, bool is_bounds_pg, bool has_repetition, - rle_stream* decoders) + rle_stream* decoders) { using block_reduce = cub::BlockReduce; using block_scan = cub::BlockScan; @@ -78,13 +83,12 @@ __device__ thrust::pair page_bounds(page_state_s* const s, auto const col = &s->col; // initialize the stream decoders (requires values computed in setupLocalPageInfo) - int const max_batch_size = lvl_buf_size; - auto const def_decode = reinterpret_cast(pp->lvl_decode_buf[level_type::DEFINITION]); - auto const rep_decode = reinterpret_cast(pp->lvl_decode_buf[level_type::REPETITION]); + auto const def_decode = reinterpret_cast(pp->lvl_decode_buf[level_type::DEFINITION]); + auto const rep_decode = reinterpret_cast(pp->lvl_decode_buf[level_type::REPETITION]); decoders[level_type::DEFINITION].init(s->col.level_bits[level_type::DEFINITION], s->abs_lvl_start[level_type::DEFINITION], s->abs_lvl_end[level_type::DEFINITION], - max_batch_size, + preproc_buf_size, def_decode, s->page.num_input_values); // only need repetition if this is a bounds page. 
otherwise all we need is def level info @@ -93,7 +97,7 @@ __device__ thrust::pair page_bounds(page_state_s* const s, decoders[level_type::REPETITION].init(s->col.level_bits[level_type::REPETITION], s->abs_lvl_start[level_type::REPETITION], s->abs_lvl_end[level_type::REPETITION], - max_batch_size, + preproc_buf_size, rep_decode, s->page.num_input_values); } @@ -152,7 +156,7 @@ __device__ thrust::pair page_bounds(page_state_s* const s, // do something with the level data while (start_val < processed) { int idx_t = start_val + t; - int idx = rolling_lvl_index(idx_t); + int idx = rolling_index(idx_t); // get absolute thread row index int is_new_row = idx_t < processed && (!has_repetition || rep_decode[idx] == 0); @@ -250,7 +254,7 @@ __device__ thrust::pair page_bounds(page_state_s* const s, while (start_val < processed) { int idx_t = start_val + t; if (idx_t < processed) { - int idx = rolling_lvl_index(idx_t); + int idx = rolling_index(idx_t); if (def_decode[idx] < max_def) { num_nulls++; } } start_val += preprocess_block_size; @@ -454,10 +458,9 @@ __device__ size_t totalPlainEntriesSize(uint8_t const* data, * @param chunks All chunks to be decoded * @param min_rows crop all rows below min_row * @param num_rows Maximum number of rows to read - * @tparam lvl_buf_size Size of the buffer used when decoding repetition and definition levels * @tparam level_t Type used to store decoded repetition and definition levels */ -template +template __global__ void __launch_bounds__(preprocess_block_size) gpuComputePageStringSizes( PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) { @@ -474,10 +477,15 @@ __global__ void __launch_bounds__(preprocess_block_size) gpuComputePageStringSiz // whether or not we have repetition levels (lists) bool const has_repetition = chunks[pp->chunk_idx].max_level[level_type::REPETITION] > 0; + // the required number of runs in shared memory we will need to provide the + // rle_stream object + constexpr int rle_run_buffer_size = rle_stream_required_run_buffer_size(preprocess_block_size); + // the level stream decoders - __shared__ rle_run def_runs[run_buffer_size]; - __shared__ rle_run rep_runs[run_buffer_size]; - rle_stream decoders[level_type::NUM_LEVEL_TYPES] = {{def_runs}, {rep_runs}}; + __shared__ rle_run def_runs[rle_run_buffer_size]; + __shared__ rle_run rep_runs[rle_run_buffer_size]; + rle_stream decoders[level_type::NUM_LEVEL_TYPES] = {{def_runs}, + {rep_runs}}; // setup page info if (!setupLocalPageInfo( @@ -499,7 +507,7 @@ __global__ void __launch_bounds__(preprocess_block_size) gpuComputePageStringSiz // find start/end value indices auto const [start_value, end_value] = - page_bounds(s, min_row, num_rows, is_bounds_pg, has_repetition, decoders); + page_bounds(s, min_row, num_rows, is_bounds_pg, has_repetition, decoders); // need to save num_nulls and num_valids calculated in page_bounds in this page if (t == 0) { @@ -560,21 +568,22 @@ __global__ void __launch_bounds__(preprocess_block_size) gpuComputePageStringSiz * @param chunks List of column chunks * @param min_row Row index to start reading at * @param num_rows Maximum number of rows to read - * @tparam lvl_buf_size Size of the buffer used when decoding repetition and definition levels * @tparam level_t Type used to store decoded repetition and definition levels */ -template +template __global__ void __launch_bounds__(decode_block_size) gpuDecodeStringPageData( PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) { __shared__ __align__(16) page_state_s state_g; - __shared__ 
__align__(16) page_state_buffers_s state_buffers; __shared__ __align__(4) size_type last_offset; + __shared__ __align__(16) + page_state_buffers_s + state_buffers; - page_state_s* const s = &state_g; - page_state_buffers_s* const sb = &state_buffers; - int const page_idx = blockIdx.x; - int const t = threadIdx.x; + page_state_s* const s = &state_g; + auto* const sb = &state_buffers; + int const page_idx = blockIdx.x; + int const t = threadIdx.x; [[maybe_unused]] null_count_back_copier _{s, t}; if (!setupLocalPageInfo( @@ -592,8 +601,8 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodeStringPageData( int const leaf_level_index = s->col.max_nesting_depth - 1; PageNestingDecodeInfo* const nesting_info_base = s->nesting_info; - __shared__ level_t rep[lvl_buf_size]; // circular buffer of repetition level values - __shared__ level_t def[lvl_buf_size]; // circular buffer of definition level values + __shared__ level_t rep[rolling_buf_size]; // circular buffer of repetition level values + __shared__ level_t def[rolling_buf_size]; // circular buffer of definition level values // skipped_leaf_values will always be 0 for flat hierarchies. uint32_t skipped_leaf_values = s->page.skipped_leaf_values; @@ -614,7 +623,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodeStringPageData( // - update validity vectors // - updates offsets (for nested columns) // - produces non-NULL value indices in s->nz_idx for subsequent decoding - gpuDecodeLevels(s, sb, target_pos, rep, def, t); + gpuDecodeLevels(s, sb, target_pos, rep, def, t); } else if (t < out_thread0) { // skipped_leaf_values will always be 0 for flat hierarchies. uint32_t src_target_pos = target_pos + skipped_leaf_values; @@ -633,7 +642,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodeStringPageData( src_pos += t - out_thread0; // the position in the output column/buffer - int dst_pos = sb->nz_idx[rolling_index(src_pos)]; + int dst_pos = sb->nz_idx[rolling_index(src_pos)]; // for the flat hierarchy case we will be reading from the beginning of the value stream, // regardless of the value of first_row. so adjust our destination offset accordingly. 
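The rolling_index<N>() calls threaded through these hunks use the templated helper this patch adds to parquet_gpu.hpp (see the parquet_gpu.hpp hunk further down); it is plain modular indexing into a circular buffer. A quick host-side check of the wrap-around behavior, assuming the usual power-of-two buffer sizes:

#include <cassert>

// same definition the patch adds to parquet_gpu.hpp
template <int rolling_size>
constexpr int rolling_index(int index)
{
  return index % rolling_size;
}

int main()
{
  constexpr int rolling_buf_size = 256;  // decode_block_size * 2
  static_assert(rolling_index<rolling_buf_size>(255) == 255);  // in range: identity
  static_assert(rolling_index<rolling_buf_size>(256) == 0);    // wraps to slot 0
  static_assert(rolling_index<rolling_buf_size>(300) == 44);   // 300 - 256
  // with power-of-two sizes the compiler reduces % to a bit-mask
  return 0;
}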
@@ -656,7 +665,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodeStringPageData( if (me < warp_size) { for (int i = 0; i < decode_block_size - out_thread0; i += warp_size) { - dst_pos = sb->nz_idx[rolling_index(src_pos + i)]; + dst_pos = sb->nz_idx[rolling_index(src_pos + i)]; if (!has_repetition) { dst_pos -= s->first_row; } auto [ptr, len] = src_pos + i < target_pos && dst_pos >= 0 @@ -739,10 +748,10 @@ void ComputePageStringSizes(cudf::detail::hostdevice_vector& pages, dim3 dim_block(preprocess_block_size, 1); dim3 dim_grid(pages.size(), 1); // 1 threadblock per page if (level_type_size == 1) { - gpuComputePageStringSizes + gpuComputePageStringSizes <<>>(pages.device_ptr(), chunks, min_row, num_rows); } else { - gpuComputePageStringSizes + gpuComputePageStringSizes <<>>(pages.device_ptr(), chunks, min_row, num_rows); } } @@ -763,10 +772,10 @@ void __host__ DecodeStringPageData(cudf::detail::hostdevice_vector& pa dim3 dim_grid(pages.size(), 1); // 1 threadblock per page if (level_type_size == 1) { - gpuDecodeStringPageData + gpuDecodeStringPageData <<>>(pages.device_ptr(), chunks, min_row, num_rows); } else { - gpuDecodeStringPageData + gpuDecodeStringPageData <<>>(pages.device_ptr(), chunks, min_row, num_rows); } } diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 92bcd947b4b..798bc3cb6e2 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -48,6 +48,12 @@ constexpr size_type MAX_DICT_SIZE = (1 << MAX_DICT_BITS) - 1; // level decode buffer size. constexpr int LEVEL_DECODE_BUF_SIZE = 2048; +template +constexpr int rolling_index(int index) +{ + return index % rolling_size; +} + /** * @brief Struct representing an input column in the file. */ diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 1b6302fa92a..8759bc260c8 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -180,22 +180,28 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) auto const level_type_size = _file_itm_data.level_type_size; - // launch the catch-all page decoder + // vector of launched streams std::vector streams; - streams.push_back(get_stream_pool().get_stream()); - gpu::DecodePageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); - // and then the specializations + // launch string decoder if (has_strings) { streams.push_back(get_stream_pool().get_stream()); chunk_nested_str_data.host_to_device_async(streams.back()); gpu::DecodeStringPageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); } + + // launch delta binary decoder if ((kernel_mask & gpu::KERNEL_MASK_DELTA_BINARY) != 0) { streams.push_back(get_stream_pool().get_stream()); gpu::DecodeDeltaBinary(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); } + // launch the catch-all page decoder + if ((kernel_mask & gpu::KERNEL_MASK_GENERAL) != 0) { + streams.push_back(get_stream_pool().get_stream()); + gpu::DecodePageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); + } + // synchronize the streams std::for_each(streams.begin(), streams.end(), [](auto& stream) { stream.synchronize(); }); diff --git a/cpp/src/io/parquet/rle_stream.cuh b/cpp/src/io/parquet/rle_stream.cuh index 473db660238..9d2e8baa1cc 100644 --- a/cpp/src/io/parquet/rle_stream.cuh +++ b/cpp/src/io/parquet/rle_stream.cuh @@ -22,16 +22,11 @@ namespace cudf::io::parquet::gpu { -// TODO: consider if these should be 
template parameters to rle_stream -constexpr int num_rle_stream_decode_threads = 512; -// the -1 here is for the look-ahead warp that fills in the list of runs to be decoded -// in an overlapped manner. so if we had 16 total warps: -// - warp 0 would be filling in batches of runs to be processed -// - warps 1-15 would be decoding the previous batch of runs generated -constexpr int num_rle_stream_decode_warps = - (num_rle_stream_decode_threads / cudf::detail::warp_size) - 1; -constexpr int run_buffer_size = (num_rle_stream_decode_warps * 2); -constexpr int rolling_run_index(int index) { return index % run_buffer_size; } +constexpr int rle_stream_required_run_buffer_size(int num_threads) +{ + int num_rle_stream_decode_warps = (num_threads / cudf::detail::warp_size) - 1; + return (num_rle_stream_decode_warps * 2); +} /** * @brief Read a 32-bit varint integer @@ -144,8 +139,19 @@ struct rle_run { }; // a stream of rle_runs -template +template struct rle_stream { + // TODO: consider if these should be template parameters to rle_stream + static constexpr int num_rle_stream_decode_threads = decode_threads; + // the -1 here is for the look-ahead warp that fills in the list of runs to be decoded + // in an overlapped manner. so if we had 16 total warps: + // - warp 0 would be filling in batches of runs to be processed + // - warps 1-15 would be decoding the previous batch of runs generated + static constexpr int num_rle_stream_decode_warps = + (num_rle_stream_decode_threads / cudf::detail::warp_size) - 1; + + static constexpr int run_buffer_size = rle_stream_required_run_buffer_size(decode_threads); + int level_bits; uint8_t const* start; uint8_t const* cur; @@ -210,7 +216,7 @@ struct rle_stream { // generate runs until we either run out of warps to decode them with, or // we cross the output limit. while (run_count < num_rle_stream_decode_warps && output_pos < max_count && cur < end) { - auto& run = runs[rolling_run_index(run_index)]; + auto& run = runs[rolling_index(run_index)]; // Encoding::RLE @@ -256,13 +262,13 @@ struct rle_stream { // if we've reached the value output limit on the last run if (output_pos >= max_count) { // first, see if we've spilled over - auto const& src = runs[rolling_run_index(run_index - 1)]; + auto const& src = runs[rolling_index(run_index - 1)]; int const spill_count = output_pos - max_count; // a spill has occurred in the current run. spill the extra values over into the beginning of // the next run. if (spill_count > 0) { - auto& spill_run = runs[rolling_run_index(run_index)]; + auto& spill_run = runs[rolling_index(run_index)]; spill_run = src; spill_run.output_pos = 0; spill_run.remaining = spill_count; @@ -330,7 +336,7 @@ struct rle_stream { // repetition levels for one of the list benchmarks decodes in ~3ms total, while the // definition levels take ~11ms - the difference is entirely due to long runs in the // definition levels. 
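The sizing arithmetic behind this refactor is easy to sanity-check on the host: one warp fills the run list while the remaining warps decode, and the buffer is doubled so the fill warp can stay a batch ahead. A small check mirroring rle_stream_required_run_buffer_size() above, assuming the CUDA warp size of 32:

#include <cassert>

constexpr int warp_size = 32;

// mirrors rle_stream_required_run_buffer_size() from this patch
constexpr int required_run_buffer_size(int num_threads)
{
  return ((num_threads / warp_size) - 1) * 2;  // -1 for the fill warp, x2 for double buffering
}

int main()
{
  static_assert(required_run_buffer_size(512) == 30);  // the 512-thread preprocess kernels
  static_assert(required_run_buffer_size(128) == 6);   // a hypothetical 128-thread caller
  return 0;
}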
-      auto& run  = runs[rolling_run_index(run_start + warp_decode_id)];
+      auto& run  = runs[rolling_index<run_buffer_size>(run_start + warp_decode_id)];
       auto batch = run.next_batch(output + run.output_pos,
                                   min(run.remaining, (output_count - run.output_pos)));
       batch.decode(end, level_bits, warp_lane, warp_decode_id);

From b58e55ca0542e241fb089453e68c299f52dec39f Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 28 Jun 2023 16:50:53 -0700
Subject: [PATCH 07/77] use less shared memory for delta binary decoder

---
 cpp/src/io/parquet/delta_binary.cuh | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh
index 8e33cc33a0c..9142405c8a4 100644
--- a/cpp/src/io/parquet/delta_binary.cuh
+++ b/cpp/src/io/parquet/delta_binary.cuh
@@ -52,7 +52,8 @@ namespace cudf::io::parquet::gpu {
 using uleb128_t = uint64_t;
 using zigzag128_t = int64_t;

-constexpr int delta_rolling_buf_size = 256;
+// we decode one mini-block at a time. max mini-block size seen is 64.
+constexpr int delta_rolling_buf_size = 128;

 /**
  * @brief Read a ULEB128 varint integer

From 2c5e087612689f8ff5e5af7e98cfe9787be1ee67 Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 28 Jun 2023 16:55:20 -0700
Subject: [PATCH 08/77] spelling

---
 cpp/src/io/parquet/delta_binary.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh
index 9142405c8a4..887dd94995a 100644
--- a/cpp/src/io/parquet/delta_binary.cuh
+++ b/cpp/src/io/parquet/delta_binary.cuh
@@ -47,7 +47,7 @@ namespace cudf::io::parquet::gpu {
 // lengths, followed by the concatenated suffix data.

 // TODO: The delta encodings use ULEB128 integers, but for now we're only
-// using max 64 bits. Need to see what the performance impact is of useing
+// using max 64 bits. Need to see what the performance impact is of using
 // __int128_t rather than int64_t. 
using uleb128_t = uint64_t; using zigzag128_t = int64_t; From 996893e8f4b38466642eb108378dd310b72dc8af Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 29 Jun 2023 11:28:46 -0700 Subject: [PATCH 09/77] change encoding to unsupported type --- .../tests/data/parquet/delta_encoding.parquet | Bin 577 -> 577 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/python/cudf/cudf/tests/data/parquet/delta_encoding.parquet b/python/cudf/cudf/tests/data/parquet/delta_encoding.parquet index e129ced34f3b570ba0ae966277f2111f8f539465..29565bef4d2e79033e2631a46eebedd3292db6b7 100644 GIT binary patch delta 28 icmX@ea*#zhz%j^BlucAlR4E2XF^DpW@@y2=V*&tHoCNOx delta 28 icmX@ea*#zhz%j^BlucAlR4E2XF^DpWa%~jWV*&tHk_7Dl From f1f74dc9daa462a68880cade7928fa50f9e663cb Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 29 Jun 2023 12:30:12 -0700 Subject: [PATCH 10/77] add python test of delta parser --- python/cudf/cudf/tests/test_parquet.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 3a35a0088ff..a6d19fe39af 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -1306,6 +1306,29 @@ def test_parquet_reader_v2(tmpdir, simple_pdf): assert_eq(cudf.read_parquet(pdf_fname), simple_pdf) +def test_delta_binary(tmpdir): + nrows = 100000 + # Create a pandas dataframe with random data of mixed types + test_pdf = pd.DataFrame( + { + "col_int32": np.random.randint(0, nrows, nrows).astype("int32"), + "col_int64": np.random.randint( + -0x10000000000, 0x10000000000, nrows + ).astype("int64"), + }, + ) + pdf_fname = tmpdir.join("pdfv2.parquet") + test_pdf.to_parquet( + pdf_fname, + version="2.6", + column_encoding="DELTA_BINARY_PACKED", + data_page_version="2.0", + engine="pyarrow", + use_dictionary=False, + ) + assert_eq(cudf.read_parquet(pdf_fname), test_pdf) + + @pytest.mark.parametrize( "data", [ From 639b8abb593d38dc1e31f7ee43f3c235e23a5980 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 29 Jun 2023 13:50:19 -0700 Subject: [PATCH 11/77] test delta with nulls --- python/cudf/cudf/tests/test_parquet.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index a6d19fe39af..6e354645148 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -1306,17 +1306,24 @@ def test_parquet_reader_v2(tmpdir, simple_pdf): assert_eq(cudf.read_parquet(pdf_fname), simple_pdf) -def test_delta_binary(tmpdir): +@pytest.mark.parametrize("add_nulls", [True, False]) +def test_delta_binary(add_nulls, tmpdir): nrows = 100000 # Create a pandas dataframe with random data of mixed types test_pdf = pd.DataFrame( { - "col_int32": np.random.randint(0, nrows, nrows).astype("int32"), - "col_int64": np.random.randint( - -0x10000000000, 0x10000000000, nrows - ).astype("int64"), + "col_int32": pd.Series( + np.random.randint(0, 0x7FFFFFFF, nrows), dtype="Int32" + ), + "col_int64": pd.Series( + np.random.randint(0, 0x7FFFFFFFFFFFFFFF, nrows), dtype="Int64" + ), }, ) + if add_nulls: + for i in range(0, int(nrows / 4)): + test_pdf.iloc[np.random.randint(0, nrows), 0] = pd.NA + test_pdf.iloc[np.random.randint(0, nrows), 1] = pd.NA pdf_fname = tmpdir.join("pdfv2.parquet") test_pdf.to_parquet( pdf_fname, @@ -1326,7 +1333,9 @@ def test_delta_binary(tmpdir): engine="pyarrow", use_dictionary=False, ) - assert_eq(cudf.read_parquet(pdf_fname), test_pdf) + 
cdf = cudf.read_parquet(pdf_fname) + pcdf = cudf.from_pandas(test_pdf) + assert_eq(cdf, pcdf) @pytest.mark.parametrize( From c1bbb840f342ae4c2bd0f0fd0efad55ccf0bf2a5 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 29 Jun 2023 17:01:44 -0700 Subject: [PATCH 12/77] add comments to skip_values and decode_batch --- cpp/src/io/parquet/delta_binary.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index 887dd94995a..be5256b242d 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -247,6 +247,8 @@ struct delta_binary_decoder { } } + // decodes and skips values until the block containing the value after `skip` is reached. + // called by all threads in a thread block. inline __device__ void skip_values(int skip) { int const t = threadIdx.x; @@ -261,6 +263,8 @@ struct delta_binary_decoder { } } + // decodes the current mini block and stores the values obtained. should only be called by + // a single warp. inline __device__ void decode_batch() { int const t = threadIdx.x; From 8e66a080d920fedc9cb39a276f2c5b08a93b4ffc Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 30 Jun 2023 13:44:06 -0700 Subject: [PATCH 13/77] revert east volatile changes --- cpp/src/io/parquet/page_data.cu | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 512364d7c1a..cb92c0807f8 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -42,8 +42,8 @@ constexpr int rolling_buf_size = decode_block_size * 2; * @param[in] dstv Pointer to row output data (string descriptor or 32-bit hash) */ template -inline __device__ void gpuOutputString(page_state_s volatile* s, - state_buf volatile* sb, +inline __device__ void gpuOutputString(volatile page_state_s* s, + volatile state_buf* sb, int src_pos, void* dstv) { @@ -70,7 +70,7 @@ inline __device__ void gpuOutputString(page_state_s volatile* s, * @param[in] dst Pointer to row output data */ template -inline __device__ void gpuOutputBoolean(state_buf volatile* sb, int src_pos, uint8_t* dst) +inline __device__ void gpuOutputBoolean(volatile state_buf* sb, int src_pos, uint8_t* dst) { *dst = sb->dict_idx[rolling_index(src_pos)]; } @@ -144,8 +144,8 @@ inline __device__ void gpuStoreOutput(uint2* dst, * @param[out] dst Pointer to row output data */ template -inline __device__ void gpuOutputInt96Timestamp(page_state_s volatile* s, - state_buf volatile* sb, +inline __device__ void gpuOutputInt96Timestamp(volatile page_state_s* s, + volatile state_buf* sb, int src_pos, int64_t* dst) { @@ -219,8 +219,8 @@ inline __device__ void gpuOutputInt96Timestamp(page_state_s volatile* s, * @param[in] dst Pointer to row output data */ template -inline __device__ void gpuOutputInt64Timestamp(page_state_s volatile* s, - state_buf volatile* sb, +inline __device__ void gpuOutputInt64Timestamp(volatile page_state_s* s, + volatile state_buf* sb, int src_pos, int64_t* dst) { @@ -302,8 +302,8 @@ __device__ void gpuOutputByteArrayAsInt(char const* ptr, int32_t len, T* dst) * @param[in] dst Pointer to row output data */ template -__device__ void gpuOutputFixedLenByteArrayAsInt(page_state_s volatile* s, - state_buf volatile* sb, +__device__ void gpuOutputFixedLenByteArrayAsInt(volatile page_state_s* s, + volatile state_buf* sb, int src_pos, T* dst) { @@ -339,8 +339,8 @@ __device__ void gpuOutputFixedLenByteArrayAsInt(page_state_s volatile* s, * @param[in] 
dst Pointer to row output data */ template -inline __device__ void gpuOutputFast(page_state_s volatile* s, - state_buf volatile* sb, +inline __device__ void gpuOutputFast(volatile page_state_s* s, + volatile state_buf* sb, int src_pos, T* dst) { @@ -372,7 +372,7 @@ inline __device__ void gpuOutputFast(page_state_s volatile* s, */ template static __device__ void gpuOutputGeneric( - page_state_s volatile* s, state_buf volatile* sb, int src_pos, uint8_t* dst8, int len) + volatile page_state_s* s, volatile state_buf* sb, int src_pos, uint8_t* dst8, int len) { uint8_t const* dict; uint32_t dict_pos, dict_size = s->dict_size; @@ -496,7 +496,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( } else if ((s->col.data_type & 7) == BYTE_ARRAY) { gpuInitStringDescriptors(s, sb, src_target_pos, t & 0x1f); } - if (t == 32) { *(int32_t volatile*)&s->dict_pos = src_target_pos; } + if (t == 32) { *(volatile int32_t*)&s->dict_pos = src_target_pos; } } else { // WARP1..WARP3: Decode values int const dtype = s->col.data_type & 7; @@ -583,7 +583,7 @@ __global__ void __launch_bounds__(decode_block_size) gpuDecodePageData( } } - if (t == out_thread0) { *(int32_t volatile*)&s->src_pos = target_pos; } + if (t == out_thread0) { *(volatile int32_t*)&s->src_pos = target_pos; } } __syncthreads(); } From 7b09c4f55f48533be262112e8f2b798450b6256a Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Fri, 7 Jul 2023 12:37:30 -0700 Subject: [PATCH 14/77] update doc string --- cpp/src/io/parquet/page_string_utils.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/page_string_utils.cuh b/cpp/src/io/parquet/page_string_utils.cuh index fb36c09052c..3be52642f3a 100644 --- a/cpp/src/io/parquet/page_string_utils.cuh +++ b/cpp/src/io/parquet/page_string_utils.cuh @@ -87,7 +87,7 @@ inline __device__ void ll_strcpy(uint8_t* dst, uint8_t const* src, size_t len, u } /** - * @brief Perform exclusive scan for offsets array. Called for each page. + * @brief Perform exclusive scan on an array of any length using a single block of threads. 
*/ template __device__ void block_excl_sum(size_type* arr, size_type length, size_type initial_value) From 9b636c7534991abffe600da4d36746ea3228ea4e Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 19 Jul 2023 09:43:54 -0700 Subject: [PATCH 15/77] fix for header location --- cpp/src/io/parquet/decode_preprocess.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/decode_preprocess.cu b/cpp/src/io/parquet/decode_preprocess.cu index 439422d3554..03510cc693e 100644 --- a/cpp/src/io/parquet/decode_preprocess.cu +++ b/cpp/src/io/parquet/decode_preprocess.cu @@ -18,7 +18,7 @@ #include -#include +#include #include #include From ceb22ab04a13777bdc42ce4e09955223cf9bf371 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 19 Jul 2023 10:13:42 -0700 Subject: [PATCH 16/77] fix some short-circuit logic --- cpp/src/io/parquet/page_string_decode.cu | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index 8ab52f32226..568457bc17c 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -123,7 +123,13 @@ __device__ thrust::pair page_bounds(page_state_s* const s, int row_fudge = -1; // short circuit for no nulls - if (max_def == 0 && !has_repetition) { return {begin_row, end_row}; } + if (max_def == 0 && !has_repetition) { + if (t == 0) { + pp->num_nulls = 0; + pp->num_valids = end_row - begin_row; + } + return {begin_row, end_row}; + } int row_count = 0; int leaf_count = 0; From ea49a23bc97489de0075b19681d93e23a0c03ce1 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Tue, 1 Aug 2023 21:25:22 -0700 Subject: [PATCH 17/77] rename function --- cpp/src/io/parquet/page_data.cu | 4 ++-- cpp/src/io/parquet/parquet_gpu.hpp | 8 ++++++-- cpp/src/io/parquet/reader_impl.cpp | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cpp/src/io/parquet/page_data.cu b/cpp/src/io/parquet/page_data.cu index 51cb810d9c4..ab08368362e 100644 --- a/cpp/src/io/parquet/page_data.cu +++ b/cpp/src/io/parquet/page_data.cu @@ -596,8 +596,8 @@ struct mask_tform { } // anonymous namespace -uint32_t GetKernelMasks(cudf::detail::hostdevice_vector& pages, - rmm::cuda_stream_view stream) +uint32_t SumPageKernelMasks(cudf::detail::hostdevice_vector& pages, + rmm::cuda_stream_view stream) { // determine which kernels to invoke auto mask_iter = thrust::make_transform_iterator(pages.d_begin(), mask_tform{}); diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 90274921595..2b9887487ba 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -475,9 +475,13 @@ void BuildStringDictionaryIndex(ColumnChunkDesc* chunks, /** * @brief Get OR'd sum of page kernel masks. + * + * @param[in] pages List of pages to aggregate + * @param[in] stream CUDA stream to use + * @return Bitwise OR of all page `kernel_mask` values */ -uint32_t GetKernelMasks(cudf::detail::hostdevice_vector& pages, - rmm::cuda_stream_view stream); +uint32_t SumPageKernelMasks(cudf::detail::hostdevice_vector& pages, + rmm::cuda_stream_view stream); /** * @brief Compute page output size information. 
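For reference, `SumPageKernelMasks` simply OR-reduces the per-page `kernel_mask` values; the reduction itself falls outside the hunks shown here. A minimal sketch of the missing body, assuming thrust's `bit_or` functor together with the `mask_tform` transform iterator from page_data.cu above (not necessarily the exact library code):

    // OR together every page's kernel_mask so the reader knows which
    // specialized decode kernels it must launch for this file
    auto mask_iter = thrust::make_transform_iterator(pages.d_begin(), mask_tform{});
    return thrust::reduce(
      rmm::exec_policy(stream), mask_iter, mask_iter + pages.size(), 0U, thrust::bit_or<uint32_t>{});

Each set bit in the result then gates exactly one kernel launch in `decode_page_data`, as the reader_impl.cpp diff below shows.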
diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 1623e0662d4..87ea404e3e7 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -59,7 +59,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) }); // figure out which kernels to run - auto const kernel_mask = GetKernelMasks(pages, _stream); + auto const kernel_mask = SumPageKernelMasks(pages, _stream); // Check to see if there are any string columns present. If so, then we need to get size info // for each string page. This size info will be used to pre-allocate memory for the column, From 3a9f186fd984d6eea70934997c1d67dfadc869fe Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Tue, 1 Aug 2023 21:37:35 -0700 Subject: [PATCH 18/77] clean up kernel_mask_for_page() --- cpp/src/io/parquet/page_hdr.cu | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cpp/src/io/parquet/page_hdr.cu b/cpp/src/io/parquet/page_hdr.cu index 2fd2c05aeee..84daed56677 100644 --- a/cpp/src/io/parquet/page_hdr.cu +++ b/cpp/src/io/parquet/page_hdr.cu @@ -154,17 +154,24 @@ __device__ void skip_struct_field(byte_stream_s* bs, int field_type) } while (rep_cnt || struct_depth); } -__device__ uint32_t get_kernel_mask(gpu::PageInfo const& page, gpu::ColumnChunkDesc const& chunk) +/** + * @brief Determine which decode kernel to run for the given page. + * + * @param page The page to decode + * @param chunk Column chunk the page belongs to + * @return `kernel_mask_bits` value for the given page +*/ +__device__ uint32_t kernel_mask_for_page(gpu::PageInfo const& page, gpu::ColumnChunkDesc const& chunk) { if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { return 0; } - // non-string, non-nested, non-dict, non-boolean types if (page.encoding == Encoding::DELTA_BINARY_PACKED) { return KERNEL_MASK_DELTA_BINARY; } else if (is_string_col(chunk)) { return KERNEL_MASK_STRING; } + // non-string, non-delta return KERNEL_MASK_GENERAL; } @@ -435,7 +442,7 @@ __global__ void __launch_bounds__(128) } bs->page.page_data = const_cast(bs->cur); bs->cur += bs->page.compressed_page_size; - bs->page.kernel_mask = get_kernel_mask(bs->page, bs->ck); + bs->page.kernel_mask = kernel_mask_for_page(bs->page, bs->ck); } else { bs->cur = bs->end; } From d7671d7148df173fea756913c9db4aa567579548 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Tue, 1 Aug 2023 21:43:21 -0700 Subject: [PATCH 19/77] remove TODO --- cpp/src/io/parquet/rle_stream.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/io/parquet/rle_stream.cuh b/cpp/src/io/parquet/rle_stream.cuh index 9d2e8baa1cc..42e1833008a 100644 --- a/cpp/src/io/parquet/rle_stream.cuh +++ b/cpp/src/io/parquet/rle_stream.cuh @@ -141,7 +141,6 @@ struct rle_run { // a stream of rle_runs template struct rle_stream { - // TODO: consider if these should be template parameters to rle_stream static constexpr int num_rle_stream_decode_threads = decode_threads; // the -1 here is for the look-ahead warp that fills in the list of runs to be decoded // in an overlapped manner. 
so if we had 16 total warps: From 7084d89cfab89017237e23fdb9c2c5a1bbed9f7a Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Tue, 1 Aug 2023 22:08:23 -0700 Subject: [PATCH 20/77] add some documentation to kernel_mask_bits --- cpp/src/io/parquet/parquet_gpu.hpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 2b9887487ba..ad41f5b21b0 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -93,14 +93,15 @@ enum level_type { NUM_LEVEL_TYPES }; +/** + * @brief Enum of mask bits for the PageInfo kernel_mask + * + * Used to control which decode kernels to run. +*/ enum kernel_mask_bits { - KERNEL_MASK_GENERAL = (1 << 0), - KERNEL_MASK_STRING = (1 << 1), - KERNEL_MASK_DELTA_BINARY = (1 << 2) - // KERNEL_MASK_FIXED_WIDTH_DICT, - // KERNEL_MASK_STRINGS, - // KERNEL_NESTED_ - // etc + KERNEL_MASK_GENERAL = (1 << 0), // Run catch-all decode kernel + KERNEL_MASK_STRING = (1 << 1), // Run decode kernel for string data + KERNEL_MASK_DELTA_BINARY = (1 << 2) // Run decode kernel for DELTA_BINARY_PACKED data }; /** From 9ebeca9c6b7a23d454363402315800d1c5207a74 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 2 Aug 2023 17:22:00 -0700 Subject: [PATCH 21/77] use rand_dataframe() to produce test data --- python/cudf/cudf/tests/test_parquet.py | 33 ++++++++++++++++---------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 907d36ed143..ca3541c4c31 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -1289,21 +1289,28 @@ def test_parquet_reader_v2(tmpdir, simple_pdf): @pytest.mark.parametrize("add_nulls", [True, False]) def test_delta_binary(add_nulls, tmpdir): nrows = 100000 + null_frequency = 0.25 if add_nulls else 0 + # Create a pandas dataframe with random data of mixed types - test_pdf = pd.DataFrame( - { - "col_int32": pd.Series( - np.random.randint(0, 0x7FFFFFFF, nrows), dtype="Int32" - ), - "col_int64": pd.Series( - np.random.randint(0, 0x7FFFFFFFFFFFFFFF, nrows), dtype="Int64" - ), - }, + arrow_table = dg.rand_dataframe( + dtypes_meta=[ + { + "dtype": "int32", + "null_frequency": null_frequency, + "cardinality": nrows, + }, + { + "dtype": "int64", + "null_frequency": null_frequency, + "cardinality": nrows, + }, + ], + rows=nrows, + use_threads=False, ) - if add_nulls: - for i in range(0, int(nrows / 4)): - test_pdf.iloc[np.random.randint(0, nrows), 0] = pd.NA - test_pdf.iloc[np.random.randint(0, nrows), 1] = pd.NA + # Roundabout conversion to pandas to preserve nulls/data types + cudf_table = cudf.DataFrame.from_arrow(arrow_table) + test_pdf = cudf_table.to_pandas(nullable=True) pdf_fname = tmpdir.join("pdfv2.parquet") test_pdf.to_parquet( pdf_fname, From 761393f5478979934c860c93ade7b1b497c378dc Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 3 Aug 2023 17:34:44 -0700 Subject: [PATCH 22/77] formatting --- cpp/src/io/parquet/page_hdr.cu | 5 +++-- cpp/src/io/parquet/parquet_gpu.hpp | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/io/parquet/page_hdr.cu b/cpp/src/io/parquet/page_hdr.cu index 84daed56677..f99702a540f 100644 --- a/cpp/src/io/parquet/page_hdr.cu +++ b/cpp/src/io/parquet/page_hdr.cu @@ -160,8 +160,9 @@ __device__ void skip_struct_field(byte_stream_s* bs, int field_type) * @param page The page to decode * @param chunk Column chunk the page belongs to * @return `kernel_mask_bits` 
value for the given page -*/ -__device__ uint32_t kernel_mask_for_page(gpu::PageInfo const& page, gpu::ColumnChunkDesc const& chunk) + */ +__device__ uint32_t kernel_mask_for_page(gpu::PageInfo const& page, + gpu::ColumnChunkDesc const& chunk) { if (page.flags & PAGEINFO_FLAGS_DICTIONARY) { return 0; } diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index ad41f5b21b0..d0fbe397291 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -97,7 +97,7 @@ enum level_type { * @brief Enum of mask bits for the PageInfo kernel_mask * * Used to control which decode kernels to run. -*/ + */ enum kernel_mask_bits { KERNEL_MASK_GENERAL = (1 << 0), // Run catch-all decode kernel KERNEL_MASK_STRING = (1 << 1), // Run decode kernel for string data @@ -476,7 +476,7 @@ void BuildStringDictionaryIndex(ColumnChunkDesc* chunks, /** * @brief Get OR'd sum of page kernel masks. - * + * * @param[in] pages List of pages to aggregate * @param[in] stream CUDA stream to use * @return Bitwise OR of all page `kernel_mask` values From 484603273ef9553a2205d14ff6317c295341c308 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Fri, 4 Aug 2023 15:41:34 -0700 Subject: [PATCH 23/77] implement suggestion from review Co-authored-by: Vukasin Milovanovic --- cpp/src/io/parquet/page_string_decode.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index 568457bc17c..561677106ae 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -45,8 +45,8 @@ constexpr int preproc_buf_size = LEVEL_DECODE_BUF_SIZE; * @param has_repetition True if the schema is nested * @param decoders Definition and repetition level decoders * @return pair containing start and end value indexes - * @tparam rle_buf_size Size of the buffer used when decoding repetition and definition levels * @tparam level_t Type used to store decoded repetition and definition levels + * @tparam rle_buf_size Size of the buffer used when decoding repetition and definition levels */ template __device__ thrust::pair page_bounds(page_state_s* const s, From d6156253e3130b2657aad5844bed554dc054167c Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 4 Aug 2023 16:20:09 -0700 Subject: [PATCH 24/77] more suggestions from review --- cpp/src/io/parquet/delta_binary.cuh | 20 +++++++++++++------- cpp/src/io/parquet/page_delta_decode.cu | 13 +++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index be5256b242d..9632e58eabb 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -108,12 +108,12 @@ struct delta_binary_decoder { uleb128_t value[delta_rolling_buf_size]; // circular buffer of delta values - // returns the number of values encoded in the block data. when is_decode is true, + // returns the number of values encoded in the block data. when all_values is true, // account for the first value in the header. otherwise just count the values encoded // in the mini-block data. - constexpr uint32_t num_encoded_values(bool is_decode) + constexpr uint32_t num_encoded_values(bool all_values) { - return value_count == 0 ? 0 : is_decode ? value_count : value_count - 1; + return value_count == 0 ? 0 : all_values ? value_count : value_count - 1; } // read mini-block header into state object. 
should only be called from init_binary_block or @@ -123,6 +123,8 @@ struct delta_binary_decoder { // // on exit db->cur_mb is 0 and db->cur_mb_start points to the first mini-block of data, or // nullptr if out of data. + // is_decode indicates whether this is being called from initialization code (false) or + // the actual decoding (true) inline __device__ void init_mini_block(bool is_decode) { cur_mb = 0; @@ -161,6 +163,8 @@ struct delta_binary_decoder { // skip to the start of the next mini-block. should only be called on thread 0. // calls init_binary_block if currently on the last mini-block in a block. + // is_decode indicates whether this is being called from initialization code (false) or + // the actual decoding (true) inline __device__ void setup_next_mini_block(bool is_decode) { if (current_value_idx >= num_encoded_values(is_decode)) { return; } @@ -242,7 +246,7 @@ struct delta_binary_decoder { // save value from last lane in warp. this will become the 'first value' added to the // deltas calculated in the next iteration (or invocation). - if (lane_id == 31) { last_value = delta; } + if (lane_id == warp_size - 1) { last_value = delta; } __syncwarp(); } } @@ -251,11 +255,12 @@ struct delta_binary_decoder { // called by all threads in a thread block. inline __device__ void skip_values(int skip) { + using cudf::detail::warp_size; int const t = threadIdx.x; - int const lane_id = t & 0x1f; + int const lane_id = t % warp_size; while (current_value_idx < skip && current_value_idx < num_encoded_values(true)) { - if (t < 32) { + if (t < warp_size) { calc_mini_block_values(lane_id); if (lane_id == 0) { setup_next_mini_block(true); } } @@ -267,8 +272,9 @@ struct delta_binary_decoder { // a single warp. inline __device__ void decode_batch() { + using cudf::detail::warp_size; int const t = threadIdx.x; - int const lane_id = t & 0x1f; + int const lane_id = t % warp_size; // unpack deltas and save in db->value calc_mini_block_values(lane_id); diff --git a/cpp/src/io/parquet/page_delta_decode.cu b/cpp/src/io/parquet/page_delta_decode.cu index e8fafe4445f..6e9377627d2 100644 --- a/cpp/src/io/parquet/page_delta_decode.cu +++ b/cpp/src/io/parquet/page_delta_decode.cu @@ -35,6 +35,7 @@ template __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( PageInfo* pages, device_span chunks, size_t min_row, size_t num_rows) { + using cudf::detail::warp_size; __shared__ __align__(16) delta_binary_decoder db_state; __shared__ __align__(16) page_state_s state_g; __shared__ __align__(16) page_state_buffers_s state_buffers; @@ -43,7 +44,7 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( auto* const sb = &state_buffers; int const page_idx = blockIdx.x; int const t = threadIdx.x; - int const lane_id = t & 0x1f; + int const lane_id = t % warp_size; auto* const db = &db_state; [[maybe_unused]] null_count_back_copier _{s, t}; @@ -82,9 +83,9 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( uint32_t target_pos; uint32_t const src_pos = s->src_pos; - if (t < 64) { // warp0..1 + if (t < 2 * warp_size) { // warp0..1 target_pos = min(src_pos + 2 * batch_size, s->nz_count + batch_size); - } else { // warp2... + } else { // warp2 target_pos = min(s->nz_count, src_pos + batch_size); } __syncthreads(); @@ -92,18 +93,18 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary( // warp0 will decode the rep/def levels, warp1 will unpack a mini-batch of deltas. // warp2 waits one cycle for warps 0/1 to produce a batch, and then stuffs values // into the proper location in the output. 
- if (t < 32) { + if (t < warp_size) { // warp 0 // decode repetition and definition levels. // - update validity vectors // - updates offsets (for nested columns) // - produces non-NULL value indices in s->nz_idx for subsequent decoding gpuDecodeLevels(s, sb, target_pos, rep, def, t); - } else if (t < 64) { + } else if (t < 2 * warp_size) { // warp 1 db->decode_batch(); - } else if (t < 96 && src_pos < target_pos) { + } else if (src_pos < target_pos) { // warp 2 // nesting level that is storing actual leaf values int const leaf_level_index = s->col.max_nesting_depth - 1; From 835e8660ead27353dfc7ba255311165a97be2f92 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Fri, 4 Aug 2023 23:35:11 -0700 Subject: [PATCH 25/77] restore old unrolled loop for testing --- cpp/src/io/parquet/delta_binary.cuh | 35 +++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index 9632e58eabb..bd68193c06d 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -223,10 +223,45 @@ struct delta_binary_decoder { uint32_t c = 8 - ofs; // 0 - 7 bits delta = (*p++) >> ofs; +#if 1 while (c < mb_bits && p < block_end) { delta |= static_cast(*p++) << c; c += 8; } +#else + // bring back unrolled unpacker for perf testing + if (c < mb_bits && p < block_end) { // up to 8 bits + delta |= (*p++) << c; + c += 8; + if (c < mb_bits && p < block_end) { // up to 16 bits + delta |= (*p++) << c; + c += 8; + if (c < mb_bits && p < block_end) { // 24 bits + delta |= (*p++) << c; + c += 8; + if (c < mb_bits && p < block_end) { // 32 bits + delta |= static_cast(*p++) << c; + c += 8; + if (c < mb_bits && p < block_end) { // 40 bits + delta |= static_cast(*p++) << c; + c += 8; + if (c < mb_bits && p < block_end) { // 48 bits + delta |= static_cast(*p++) << c; + c += 8; + if (c < mb_bits && p < block_end) { // 56 bits + delta |= static_cast(*p++) << c; + c += 8; + if (c < mb_bits && p < block_end) { // 64 bits + delta |= static_cast(*p++) << c; + } + } + } + } + } + } + } + } +#endif delta &= (static_cast(1) << mb_bits) - 1; } } From 6353a4aa2bb706ab24a8a9dfe3e6c2291726c84c Mon Sep 17 00:00:00 2001 From: seidl Date: Mon, 7 Aug 2023 11:25:51 -0700 Subject: [PATCH 26/77] add note to revisit bit unpacker with delta_byte_array --- cpp/src/io/parquet/delta_binary.cuh | 41 ++++------------------------- 1 file changed, 5 insertions(+), 36 deletions(-) diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index bd68193c06d..d29c52248d9 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -213,7 +213,11 @@ struct delta_binary_decoder { // unpack deltas. modified from version in gpuDecodeDictionaryIndices(), but // that one only unpacks up to bitwidths of 24. simplified some since this - // will always do batches of 32. also replaced branching with a loop. + // will always do batches of 32. + // NOTE: because this needs to handle up to 64 bits, the branching used in the other + // implementation has been replaced with a loop. While this uses more registers, the + // looping version is just as fast and easier to read. Might need to revisit this when + // DELTA_BYTE_ARRAY is implemented. 
zigzag128_t delta = 0;
     if (lane_id + current_value_idx < value_count) {
       int32_t ofs = (lane_id - warp_size) * mb_bits;
@@ -223,45 +227,10 @@ struct delta_binary_decoder {
       uint32_t c = 8 - ofs;  // 0 - 7 bits
       delta      = (*p++) >> ofs;

-#if 1
       while (c < mb_bits && p < block_end) {
         delta |= static_cast<uleb128_t>(*p++) << c;
         c += 8;
       }
-#else
-      // bring back unrolled unpacker for perf testing
-      if (c < mb_bits && p < block_end) {  // up to 8 bits
-        delta |= (*p++) << c;
-        c += 8;
-        if (c < mb_bits && p < block_end) {  // up to 16 bits
-          delta |= (*p++) << c;
-          c += 8;
-          if (c < mb_bits && p < block_end) {  // 24 bits
-            delta |= (*p++) << c;
-            c += 8;
-            if (c < mb_bits && p < block_end) {  // 32 bits
-              delta |= static_cast<uleb128_t>(*p++) << c;
-              c += 8;
-              if (c < mb_bits && p < block_end) {  // 40 bits
-                delta |= static_cast<uleb128_t>(*p++) << c;
-                c += 8;
-                if (c < mb_bits && p < block_end) {  // 48 bits
-                  delta |= static_cast<uleb128_t>(*p++) << c;
-                  c += 8;
-                  if (c < mb_bits && p < block_end) {  // 56 bits
-                    delta |= static_cast<uleb128_t>(*p++) << c;
-                    c += 8;
-                    if (c < mb_bits && p < block_end) {  // 64 bits
-                      delta |= static_cast<uleb128_t>(*p++) << c;
-                    }
-                  }
-                }
-              }
-            }
-          }
-        }
-      }
-#endif
       delta &= (static_cast<uleb128_t>(1) << mb_bits) - 1;
     }
   }

From e924a97fac2d68a8e91331a7b65299d9a44c15e1 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 8 Aug 2023 11:27:58 -0700
Subject: [PATCH 27/77] fix and test int8 and int16 handling

---
 cpp/src/io/parquet/page_delta_decode.cu | 23 +++++++++++++++++------
 python/cudf/cudf/tests/test_parquet.py  | 10 ++++++++++
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/cpp/src/io/parquet/page_delta_decode.cu b/cpp/src/io/parquet/page_delta_decode.cu
index 6e9377627d2..e79a479388f 100644
--- a/cpp/src/io/parquet/page_delta_decode.cu
+++ b/cpp/src/io/parquet/page_delta_decode.cu
@@ -120,12 +120,23 @@ __global__ void __launch_bounds__(96) gpuDecodeDeltaBinary(
       // place value for this thread
       if (dst_pos >= 0 && sp < target_pos) {
         void* const dst = nesting_info_base[leaf_level_index].data_out + dst_pos * s->dtype_len;
-        if (s->dtype_len == 8) {
-          *static_cast<int64_t*>(dst) =
-            db->value[rolling_index<delta_rolling_buf_size>(sp + skipped_leaf_values)];
-        } else if (s->dtype_len == 4) {
-          *static_cast<int32_t*>(dst) =
-            db->value[rolling_index<delta_rolling_buf_size>(sp + skipped_leaf_values)];
+        switch (s->dtype_len) {
+          case 1:
+            *static_cast<int8_t*>(dst) =
+              db->value[rolling_index<delta_rolling_buf_size>(sp + skipped_leaf_values)];
+            break;
+          case 2:
+            *static_cast<int16_t*>(dst) =
+              db->value[rolling_index<delta_rolling_buf_size>(sp + skipped_leaf_values)];
+            break;
+          case 4:
+            *static_cast<int32_t*>(dst) =
+              db->value[rolling_index<delta_rolling_buf_size>(sp + skipped_leaf_values)];
+            break;
+          case 8:
+            *static_cast<int64_t*>(dst) =
+              db->value[rolling_index<delta_rolling_buf_size>(sp + skipped_leaf_values)];
+            break;
         }
       }
     }
diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py
index ca3541c4c31..8b86ce67baa 100644
--- a/python/cudf/cudf/tests/test_parquet.py
+++ b/python/cudf/cudf/tests/test_parquet.py
@@ -1294,6 +1294,16 @@ def test_delta_binary(add_nulls, tmpdir):
     # Create a pandas dataframe with random data of mixed types
     arrow_table = dg.rand_dataframe(
         dtypes_meta=[
+            {
+                "dtype": "int8",
+                "null_frequency": null_frequency,
+                "cardinality": nrows,
+            },
+            {
+                "dtype": "int16",
+                "null_frequency": null_frequency,
+                "cardinality": nrows,
+            },
             {
                 "dtype": "int32",
                 "null_frequency": null_frequency,
                 "cardinality": nrows,
             },

From d0bf0cd3321ee5e09b5a12d906505cf15ea359c2 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 8 Aug 2023 13:50:06 -0700
Subject: [PATCH 28/77] fix for single row files

---
 cpp/src/io/parquet/delta_binary.cuh    | 5 ++++-
 python/cudf/cudf/tests/test_parquet.py | 4 ++--
 2 files changed, 6 
insertions(+), 3 deletions(-) diff --git a/cpp/src/io/parquet/delta_binary.cuh b/cpp/src/io/parquet/delta_binary.cuh index d29c52248d9..4fc8b9cfb8e 100644 --- a/cpp/src/io/parquet/delta_binary.cuh +++ b/cpp/src/io/parquet/delta_binary.cuh @@ -158,7 +158,9 @@ struct delta_binary_decoder { // init the first mini-block block_start = d_start; - init_mini_block(false); + + // only call init if there are actually encoded values + if (value_count > 1) { init_mini_block(false); } } // skip to the start of the next mini-block. should only be called on thread 0. @@ -197,6 +199,7 @@ struct delta_binary_decoder { value[0] = last_value; } __syncwarp(); + if (current_value_idx >= value_count) { return; } } uint32_t const mb_bits = cur_bitwidths[cur_mb]; diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 8b86ce67baa..b35526cb397 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -1286,9 +1286,9 @@ def test_parquet_reader_v2(tmpdir, simple_pdf): assert_eq(cudf.read_parquet(pdf_fname), simple_pdf) +@pytest.mark.parametrize("nrows", [1, 100000]) @pytest.mark.parametrize("add_nulls", [True, False]) -def test_delta_binary(add_nulls, tmpdir): - nrows = 100000 +def test_delta_binary(nrows, add_nulls, tmpdir): null_frequency = 0.25 if add_nulls else 0 # Create a pandas dataframe with random data of mixed types From 5ac20a08491384aed366550f01b6080225681f81 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 9 Aug 2023 16:36:14 -0700 Subject: [PATCH 29/77] clean up some docstrings --- cpp/src/io/parquet/page_decode.cuh | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/cpp/src/io/parquet/page_decode.cuh b/cpp/src/io/parquet/page_decode.cuh index fddb82cd3b2..cab4ef8ecbd 100644 --- a/cpp/src/io/parquet/page_decode.cuh +++ b/cpp/src/io/parquet/page_decode.cuh @@ -166,6 +166,7 @@ inline __device__ bool is_page_contained(page_state_s* const s, size_t start_row * @param[in] s Page state input * @param[out] sb Page state buffer output * @param[in] src_pos Source position + * @tparam state_buf Typename of the `state_buf` (usually inferred) * * @return A pair containing a pointer to the string and its length */ @@ -208,6 +209,8 @@ inline __device__ cuda::std::pair gpuGetStringData(page_sta * @param[in] target_pos Target index position in dict_idx buffer (may exceed this value by up to * 31) * @param[in] t Warp1 thread ID (0..31) + * @tparam sizes_only True if only sizes are to be calculated + * @tparam state_buf Typename of the `state_buf` (usually inferred) * * @return A pair containing the new output position, and the total length of strings decoded (this * will only be valid on thread 0 and if sizes_only is true). 
In the event that this function @@ -327,6 +330,7 @@ __device__ cuda::std::pair gpuDecodeDictionaryIndices( * @param[out] sb Page state buffer output * @param[in] target_pos Target write position * @param[in] t Thread ID + * @tparam state_buf Typename of the `state_buf` (usually inferred) * * @return The new output position */ @@ -396,6 +400,8 @@ inline __device__ int gpuDecodeRleBooleans(page_state_s volatile* s, * @param[out] sb Page state buffer output * @param[in] target_pos Target output position * @param[in] t Thread ID + * @tparam sizes_only True if only sizes are to be calculated + * @tparam state_buf Typename of the `state_buf` (usually inferred) * * @return Total length of strings processed */ @@ -441,10 +447,13 @@ __device__ size_type gpuInitStringDescriptors(page_state_s volatile* s, /** * @brief Decode values out of a definition or repetition stream * + * @param[out] output Pointer to buffer to store level data to * @param[in,out] s Page state input/output - * @param[in] t target_count Target count of stream values on output + * @param[in] target_count Target count of stream values on output * @param[in] t Warp0 thread ID (0..31) * @param[in] lvl The level type we are decoding - DEFINITION or REPETITION + * @tparam level_t Type used to store decoded repetition and definition levels + * @tparam rolling_buf_size Size of the cyclic buffer used to store value data */ template __device__ void gpuDecodeStream( @@ -535,6 +544,7 @@ __device__ void gpuDecodeStream( * * @param[in,out] nesting_info The page/nesting information to store the mask in. The validity map * offset is also updated + * @param[in,out] valid_map Pointer to bitmask to store validity information to * @param[in] valid_mask The validity mask to be stored * @param[in] value_count # of bits in the validity mask */ @@ -592,6 +602,8 @@ inline __device__ void store_validity(int valid_map_offset, * @param[in] input_value_count The current count of input level values we have processed * @param[in] target_input_value_count The desired # of input level values we want to process * @param[in] t Thread index + * @tparam rolling_buf_size Size of the cyclic buffer used to store value data + * @tparam level_t Type used to store decoded repetition and definition levels */ template inline __device__ void get_nesting_bounds(int& start_depth, @@ -636,6 +648,9 @@ inline __device__ void get_nesting_bounds(int& start_depth, * @param[in] rep Repetition level buffer * @param[in] def Definition level buffer * @param[in] t Thread index + * @tparam level_t Type used to store decoded repetition and definition levels + * @tparam state_buf Typename of the `state_buf` (usually inferred) + * @tparam rolling_buf_size Size of the cyclic buffer used to store value data */ template __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value_count, @@ -818,6 +833,9 @@ __device__ void gpuUpdateValidityOffsetsAndRowIndices(int32_t target_input_value * @param[in] rep Repetition level buffer * @param[in] def Definition level buffer * @param[in] t Thread index + * @tparam rolling_buf_size Size of the cyclic buffer used to store value data + * @tparam level_t Type used to store decoded repetition and definition levels + * @tparam state_buf Typename of the `state_buf` (usually inferred) */ template __device__ void gpuDecodeLevels(page_state_s* s, @@ -861,9 +879,7 @@ __device__ void gpuDecodeLevels(page_state_s* s, * @param[in,out] s The page state * @param[in] cur The current data position * @param[in] end The end of the data - * @param[in] 
level_bits The bits required - * @param[in] is_decode_step True if we are performing the decode step. - * @param[in,out] decoders The repetition and definition level stream decoders + * @param[in] lvl Enum indicating whether this is to initialize repetition or definition level data * * @return The length of the section */ @@ -973,9 +989,9 @@ struct mask_filter { * @param[in] num_rows Maximum number of rows to read * @param[in] filter Filtering function used to decide which pages to operate on * @param[in] is_decode_step If we are setting up for the decode step (instead of the preprocess) - * @param[in] decoders rle_stream decoders which will be used for decoding levels. Optional. * @tparam Filter Function that takes a PageInfo reference and returns true if the given page should * be operated on Currently only used by gpuComputePageSizes step) + * @return True if this page should be processed further */ template inline __device__ bool setupLocalPageInfo(page_state_s* const s, From 1d137ec058d593f4bd7f2d524f6c3481c8599a5b Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Thu, 17 Aug 2023 13:00:28 -0700 Subject: [PATCH 30/77] Apply suggestions from code review Co-authored-by: nvdbaranec <56695930+nvdbaranec@users.noreply.github.com> --- cpp/src/io/parquet/decode_preprocess.cu | 2 +- cpp/src/io/parquet/page_string_utils.cuh | 2 +- cpp/src/io/parquet/rle_stream.cuh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/io/parquet/decode_preprocess.cu b/cpp/src/io/parquet/decode_preprocess.cu index 03510cc693e..fb1910f693b 100644 --- a/cpp/src/io/parquet/decode_preprocess.cu +++ b/cpp/src/io/parquet/decode_preprocess.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/io/parquet/page_string_utils.cuh b/cpp/src/io/parquet/page_string_utils.cuh index 3be52642f3a..a68af7cfb16 100644 --- a/cpp/src/io/parquet/page_string_utils.cuh +++ b/cpp/src/io/parquet/page_string_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
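The rle_stream.cuh change below marks a local inside `rle_stream_required_run_buffer_size` as `constexpr`. Note that this only compiles once the argument is itself a compile-time constant; a function parameter is never a constant expression, which is why patch 32 below has to turn `num_threads` into a template parameter ("need to pass num_threads as template param to make constexpr"). A minimal illustration of the distinction, with hypothetical names:

    // ill-formed: a runtime function parameter cannot initialize a constexpr local
    constexpr int bad(int num_threads)
    {
      constexpr int num_warps = (num_threads / 32) - 1;  // error: not a constant expression
      return num_warps * 2;
    }

    // well-formed: a non-type template argument is a constant expression
    template <int num_threads>
    constexpr int good()
    {
      constexpr int num_warps = (num_threads / 32) - 1;  // OK
      return num_warps * 2;
    }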
diff --git a/cpp/src/io/parquet/rle_stream.cuh b/cpp/src/io/parquet/rle_stream.cuh index 42e1833008a..189e1bd9750 100644 --- a/cpp/src/io/parquet/rle_stream.cuh +++ b/cpp/src/io/parquet/rle_stream.cuh @@ -24,7 +24,7 @@ namespace cudf::io::parquet::gpu { constexpr int rle_stream_required_run_buffer_size(int num_threads) { - int num_rle_stream_decode_warps = (num_threads / cudf::detail::warp_size) - 1; + constexpr int num_rle_stream_decode_warps = (num_threads / cudf::detail::warp_size) - 1; return (num_rle_stream_decode_warps * 2); } From 6cd3e001f729681bd6d5cc5813310033f0e63765 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 17 Aug 2023 13:05:32 -0700 Subject: [PATCH 31/77] fix docstring --- cpp/src/io/parquet/parquet_gpu.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/io/parquet/parquet_gpu.hpp b/cpp/src/io/parquet/parquet_gpu.hpp index 7ac67d27614..cab6d3469b0 100644 --- a/cpp/src/io/parquet/parquet_gpu.hpp +++ b/cpp/src/io/parquet/parquet_gpu.hpp @@ -592,6 +592,7 @@ void DecodeStringPageData(cudf::detail::hostdevice_vector& pages, * @param[in] chunks All chunks to be decoded * @param[in] num_rows Total number of rows to read * @param[in] min_row Minimum number of rows to read + * @param[in] level_type_size Size in bytes of the type for level decoding * @param[in] stream CUDA stream to use, default 0 */ void DecodeDeltaBinary(cudf::detail::hostdevice_vector& pages, From a774ac1099491dde0dcbe060e2ee998a0450157a Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 17 Aug 2023 13:47:07 -0700 Subject: [PATCH 32/77] need to pass num_threads as template param to make constexpr --- cpp/src/io/parquet/decode_preprocess.cu | 2 +- cpp/src/io/parquet/page_string_decode.cu | 2 +- cpp/src/io/parquet/rle_stream.cuh | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/io/parquet/decode_preprocess.cu b/cpp/src/io/parquet/decode_preprocess.cu index fb1910f693b..8de3702bc2e 100644 --- a/cpp/src/io/parquet/decode_preprocess.cu +++ b/cpp/src/io/parquet/decode_preprocess.cu @@ -35,7 +35,7 @@ constexpr int preprocess_block_size = 512; // the required number of runs in shared memory we will need to provide the // rle_stream object -constexpr int rle_run_buffer_size = rle_stream_required_run_buffer_size(preprocess_block_size); +constexpr int rle_run_buffer_size = rle_stream_required_run_buffer_size(); // the size of the rolling batch buffer constexpr int rolling_buf_size = LEVEL_DECODE_BUF_SIZE; diff --git a/cpp/src/io/parquet/page_string_decode.cu b/cpp/src/io/parquet/page_string_decode.cu index 561677106ae..40a7be800f4 100644 --- a/cpp/src/io/parquet/page_string_decode.cu +++ b/cpp/src/io/parquet/page_string_decode.cu @@ -485,7 +485,7 @@ __global__ void __launch_bounds__(preprocess_block_size) gpuComputePageStringSiz // the required number of runs in shared memory we will need to provide the // rle_stream object - constexpr int rle_run_buffer_size = rle_stream_required_run_buffer_size(preprocess_block_size); + constexpr int rle_run_buffer_size = rle_stream_required_run_buffer_size(); // the level stream decoders __shared__ rle_run def_runs[rle_run_buffer_size]; diff --git a/cpp/src/io/parquet/rle_stream.cuh b/cpp/src/io/parquet/rle_stream.cuh index 189e1bd9750..2545a074a38 100644 --- a/cpp/src/io/parquet/rle_stream.cuh +++ b/cpp/src/io/parquet/rle_stream.cuh @@ -22,7 +22,8 @@ namespace cudf::io::parquet::gpu { -constexpr int rle_stream_required_run_buffer_size(int num_threads) +template +constexpr int rle_stream_required_run_buffer_size() { constexpr int 
num_rle_stream_decode_warps = (num_threads / cudf::detail::warp_size) - 1; return (num_rle_stream_decode_warps * 2); @@ -149,7 +150,7 @@ struct rle_stream { static constexpr int num_rle_stream_decode_warps = (num_rle_stream_decode_threads / cudf::detail::warp_size) - 1; - static constexpr int run_buffer_size = rle_stream_required_run_buffer_size(decode_threads); + static constexpr int run_buffer_size = rle_stream_required_run_buffer_size(); int level_bits; uint8_t const* start; From d7112a59f46c4d7a38de312788df491ce2ca2d39 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 18 Aug 2023 16:01:45 -0700 Subject: [PATCH 33/77] refactor stream pool --- cpp/CMakeLists.txt | 1 + cpp/src/io/parquet/reader_impl.cpp | 44 +++++--------- cpp/src/io/parquet/stream_pool.cpp | 94 ++++++++++++++++++++++++++++++ cpp/src/io/parquet/stream_pool.hpp | 78 +++++++++++++++++++++++++ 4 files changed, 187 insertions(+), 30 deletions(-) create mode 100644 cpp/src/io/parquet/stream_pool.cpp create mode 100644 cpp/src/io/parquet/stream_pool.hpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 516865e5782..007ce4e65f8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -402,6 +402,7 @@ add_library( src/io/parquet/reader_impl.cpp src/io/parquet/reader_impl_helpers.cpp src/io/parquet/reader_impl_preprocess.cu + src/io/parquet/stream_pool.cpp src/io/parquet/writer_impl.cu src/io/statistics/orc_column_statistics.cu src/io/statistics/parquet_column_statistics.cu diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 87ea404e3e7..11ad68925bc 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -15,34 +15,18 @@ */ #include "reader_impl.hpp" +#include "stream_pool.hpp" #include #include #include #include +#include #include namespace cudf::io::detail::parquet { -namespace { - -int constexpr NUM_DECODERS = 3; // how many decode kernels are there to run -int constexpr APPROX_NUM_THREADS = 4; // guestimate from DaveB -int constexpr STREAM_POOL_SIZE = NUM_DECODERS * APPROX_NUM_THREADS; - -auto& get_stream_pool() -{ - // TODO: creating this on the heap because there were issues with trying to call the - // stream pool destructor during cuda shutdown that lead to a segmentation fault in - // nvbench. this allocation is being deliberately leaked to avoid the above, but still - // results in non-fatal warnings when running nvbench in cuda-gdb. 
- static auto pool = new rmm::cuda_stream_pool{STREAM_POOL_SIZE}; - return *pool; -} - -} // namespace - void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) { auto& chunks = _file_itm_data.chunks; @@ -178,34 +162,34 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) chunks.host_to_device_async(_stream); chunk_nested_valids.host_to_device_async(_stream); chunk_nested_data.host_to_device_async(_stream); - _stream.synchronize(); - auto const level_type_size = _file_itm_data.level_type_size; + // get the number of streams we need from the pool and tell them to wait on the H2D copies + int nkernels = std::bitset<32>(kernel_mask).count(); + auto streams = get_streams(nkernels); + fork_streams(streams, _stream); - // vector of launched streams - std::vector streams; + auto const level_type_size = _file_itm_data.level_type_size; // launch string decoder + int s_idx = 0; if (has_strings) { - streams.push_back(get_stream_pool().get_stream()); - chunk_nested_str_data.host_to_device_async(streams.back()); - gpu::DecodeStringPageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); + auto& stream = streams[s_idx++]; + chunk_nested_str_data.host_to_device_async(stream); + gpu::DecodeStringPageData(pages, chunks, num_rows, skip_rows, level_type_size, stream); } // launch delta binary decoder if ((kernel_mask & gpu::KERNEL_MASK_DELTA_BINARY) != 0) { - streams.push_back(get_stream_pool().get_stream()); - gpu::DecodeDeltaBinary(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); + gpu::DecodeDeltaBinary(pages, chunks, num_rows, skip_rows, level_type_size, streams[s_idx++]); } // launch the catch-all page decoder if ((kernel_mask & gpu::KERNEL_MASK_GENERAL) != 0) { - streams.push_back(get_stream_pool().get_stream()); - gpu::DecodePageData(pages, chunks, num_rows, skip_rows, level_type_size, streams.back()); + gpu::DecodePageData(pages, chunks, num_rows, skip_rows, level_type_size, streams[s_idx++]); } // synchronize the streams - std::for_each(streams.begin(), streams.end(), [](auto& stream) { stream.synchronize(); }); + join_streams(streams, _stream); pages.device_to_host_async(_stream); page_nesting.device_to_host_async(_stream); diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp new file mode 100644 index 00000000000..20d1824fe6a --- /dev/null +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "stream_pool.hpp" + +#include + +namespace cudf::io::detail::parquet { + +namespace { + +std::size_t constexpr STREAM_POOL_SIZE = 32; + +// doing lazy initialization to avoid allocating this before cuInit is called (was +// causing issues with compute-sanitizer). +// deliberately allowing the pool to leak to avoid deleting streams after cuda shutdown. 
+rmm::cuda_stream_pool* pool = nullptr;
+
+std::mutex stream_pool_mutex;
+
+void init()
+{
+  std::lock_guard lock(stream_pool_mutex);
+  if (pool == nullptr) { pool = new rmm::cuda_stream_pool(STREAM_POOL_SIZE); }
+}
+
+std::atomic_size_t stream_idx{};
+
+}  // anonymous namespace
+
+rmm::cuda_stream_view get_stream()
+{
+  init();
+  return pool->get_stream();
+}
+
+rmm::cuda_stream_view get_stream(std::size_t stream_id)
+{
+  init();
+  return pool->get_stream(stream_id);
+}
+
+std::vector<rmm::cuda_stream_view> get_streams(uint32_t count)
+{
+  init();
+
+  // TODO maybe add mutex to be sure streams don't overlap
+  auto streams = std::vector<rmm::cuda_stream_view>();
+  for (uint32_t i = 0; i < count; i++) {
+    streams.emplace_back(pool->get_stream((stream_idx++)));
+  }
+  return streams;
+}
+
+void fork_streams(std::vector<rmm::cuda_stream_view>& streams, rmm::cuda_stream_view stream)
+{
+  cudaEvent_t event;
+  CUDF_CUDA_TRY(cudaEventCreate(&event));
+  CUDF_CUDA_TRY(cudaEventRecord(event, stream));
+  std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
+    CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0));
+  });
+  CUDF_CUDA_TRY(cudaEventDestroy(event));
+}
+
+void join_streams(std::vector<rmm::cuda_stream_view>& streams, rmm::cuda_stream_view stream)
+{
+  cudaEvent_t event;
+  CUDF_CUDA_TRY(cudaEventCreate(&event));
+  std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
+    CUDF_CUDA_TRY(cudaEventRecord(event, strm));
+    CUDF_CUDA_TRY(cudaStreamWaitEvent(stream, event, 0));
+  });
+  CUDF_CUDA_TRY(cudaEventDestroy(event));
+}
+
+std::size_t get_stream_pool_size() { return STREAM_POOL_SIZE; }
+
+}  // namespace cudf::io::detail::parquet
diff --git a/cpp/src/io/parquet/stream_pool.hpp b/cpp/src/io/parquet/stream_pool.hpp
new file mode 100644
index 00000000000..e9bef3bda35
--- /dev/null
+++ b/cpp/src/io/parquet/stream_pool.hpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include 
+
+namespace cudf::io::detail::parquet {
+
+/**
+ * @brief Get a `cuda_stream_view` of a stream in the pool.
+ *
+ * This function is thread safe with respect to other calls to the same function.
+ *
+ * @return Stream view.
+ */
+rmm::cuda_stream_view get_stream();
+
+/**
+ * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`.
+ *
+ * Equivalent values of `stream_id` return a stream_view to the same underlying stream.
+ * This function is thread safe with respect to other calls to the same function.
+ *
+ * @param stream_id Unique identifier for the desired stream
+ * @return Requested stream view.
+ */
+rmm::cuda_stream_view get_stream(std::size_t stream_id);
+
+/**
+ * @brief Get a set of `cuda_stream_view` objects from the pool.
+ *
+ * This function is thread safe with respect to other calls to the same function.
+ *
+ * @param count The number of stream views to return.
+ * @return Vector containing `count` stream views.
+ */
+std::vector<rmm::cuda_stream_view> get_streams(uint32_t count);
+
+/**
+ * @brief Synchronize a set of streams to an event on another stream. 
+ * + * @param streams Vector of streams to synchronize on. + * @param stream Stream to synchronize the other streams to, usually the default stream. + */ +void fork_streams(std::vector& streams, rmm::cuda_stream_view stream); + +/** + * @brief Synchronize a stream to an event on a set of streams. + * + * @param streams Vector of streams to synchronize on. + * @param stream Stream to synchronize the other streams to, usually the default stream. + */ +void join_streams(std::vector& streams, rmm::cuda_stream_view stream); + +/** + * @brief Get the number of streams in the pool. + * + * This function is thread safe with respect to other calls to the same function. + * + * @return the number of streams in the pool + */ +std::size_t get_stream_pool_size(); + +} // namespace cudf::io::detail::parquet From 71ece7394741e622f6d4fe5f58170f8579f0f609 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Wed, 23 Aug 2023 22:38:36 -0700 Subject: [PATCH 34/77] checkpoint --- cpp/src/io/parquet/stream_pool.cpp | 32 +++++++++++++----------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 20d1824fe6a..013f0bf2116 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -26,43 +26,39 @@ namespace { std::size_t constexpr STREAM_POOL_SIZE = 32; -// doing lazy initialization to avoid allocating this before cuInit is called (was -// causing issues with compute-sanitizer). -// deliberately allowing the pool to leak to avoid deleting streams after cuda shutdown. -rmm::cuda_stream_pool* pool = nullptr; - std::mutex stream_pool_mutex; -void init() +auto& get_stream_pool() { - std::lock_guard lock(stream_pool_mutex); - if (pool == nullptr) { pool = new rmm::cuda_stream_pool(STREAM_POOL_SIZE); } + // TODO: is the following still true? test this again. + // TODO: creating this on the heap because there were issues with trying to call the + // stream pool destructor during cuda shutdown that lead to a segmentation fault in + // nvbench. this allocation is being deliberately leaked to avoid the above, but still + // results in non-fatal warnings when running nvbench in cuda-gdb. 
+ static auto pool = new rmm::cuda_stream_pool{STREAM_POOL_SIZE}; + return *pool; } -std::atomic_size_t stream_idx{}; - } // anonymous namespace rmm::cuda_stream_view get_stream() { - init(); - return pool->get_stream(); + return get_stream_pool().get_stream(); } rmm::cuda_stream_view get_stream(std::size_t stream_id) { - init(); - return pool->get_stream(stream_id); + return get_stream_pool().get_stream(stream_id); } std::vector get_streams(uint32_t count) { - init(); - - // TODO maybe add mutex to be sure streams don't overlap + // TODO: if count > STREAM_POOL_SIZE log a warning + CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE); auto streams = std::vector(); + std::lock_guard lock(stream_pool_mutex); for (uint32_t i = 0; i < count; i++) { - streams.emplace_back(pool->get_stream((stream_idx++))); + streams.emplace_back(pool->get_stream()); } return streams; } From db1d08dd8912ea226416cb2aa965b797a375a8df Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Wed, 23 Aug 2023 22:42:10 -0700 Subject: [PATCH 35/77] remove comment --- cpp/src/io/parquet/stream_pool.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 013f0bf2116..687b107ce7d 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -53,7 +53,6 @@ rmm::cuda_stream_view get_stream(std::size_t stream_id) std::vector get_streams(uint32_t count) { - // TODO: if count > STREAM_POOL_SIZE log a warning CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE); auto streams = std::vector(); std::lock_guard lock(stream_pool_mutex); From f233870c86c9aac1ae9c811ed0fde95ec743d89b Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 08:49:33 -0700 Subject: [PATCH 36/77] compiles now --- cpp/src/io/parquet/stream_pool.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 687b107ce7d..45ea57b653a 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -18,6 +18,7 @@ #include "stream_pool.hpp" +#include #include namespace cudf::io::detail::parquet { @@ -41,10 +42,7 @@ auto& get_stream_pool() } // anonymous namespace -rmm::cuda_stream_view get_stream() -{ - return get_stream_pool().get_stream(); -} +rmm::cuda_stream_view get_stream() { return get_stream_pool().get_stream(); } rmm::cuda_stream_view get_stream(std::size_t stream_id) { @@ -53,11 +51,13 @@ rmm::cuda_stream_view get_stream(std::size_t stream_id) std::vector get_streams(uint32_t count) { - CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE); + if (count > STREAM_POOL_SIZE) { + CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE); + } auto streams = std::vector(); std::lock_guard lock(stream_pool_mutex); for (uint32_t i = 0; i < count; i++) { - streams.emplace_back(pool->get_stream()); + streams.emplace_back(get_stream_pool().get_stream()); } return streams; } From 301e596acba728c6393edd4d62dee4d5023d5b82 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 10:37:11 -0700 Subject: [PATCH 37/77] update some comments --- cpp/src/io/parquet/stream_pool.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 45ea57b653a..33bae3a77a7 100644 --- 
a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -25,19 +25,27 @@ namespace cudf::io::detail::parquet { namespace { +// TODO: what is a good number here. what's the penalty for making it larger? std::size_t constexpr STREAM_POOL_SIZE = 32; std::mutex stream_pool_mutex; auto& get_stream_pool() { - // TODO: is the following still true? test this again. // TODO: creating this on the heap because there were issues with trying to call the // stream pool destructor during cuda shutdown that lead to a segmentation fault in - // nvbench. this allocation is being deliberately leaked to avoid the above, but still - // results in non-fatal warnings when running nvbench in cuda-gdb. + // nvbench. this allocation is being deliberately leaked to avoid the above. +#if 1 static auto pool = new rmm::cuda_stream_pool{STREAM_POOL_SIZE}; return *pool; +#else + // FIXME + // running ./benchmarks/PARQUET_READER_NVBENCH -b parquet_read_decode --axis data_type=STRUCT + // with this code results in a segmentation fault in the cuda_stream_pool dtor during shutdown. + // it seems cudaStreamDestroy is called twice on the streams in the pool. + static rmm::cuda_stream_pool pool{STREAM_POOL_SIZE}; + return pool; +#endif } } // anonymous namespace From e3cfa89afcf35897c94b3ff1757c5da5febb9078 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 12:20:49 -0700 Subject: [PATCH 38/77] implement Vukasin's idea for making the pool extensible --- cpp/src/io/parquet/stream_pool.cpp | 65 +++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 33bae3a77a7..c45a1de849c 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -19,6 +19,7 @@ #include "stream_pool.hpp" #include +#include #include namespace cudf::io::detail::parquet { @@ -28,33 +29,61 @@ namespace { // TODO: what is a good number here. what's the penalty for making it larger? std::size_t constexpr STREAM_POOL_SIZE = 32; -std::mutex stream_pool_mutex; +class cuda_stream_pool { + public: + virtual ~cuda_stream_pool() = default; + + virtual rmm::cuda_stream_view get_stream() = 0; + virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0; +}; + +class rmm_cuda_stream_pool : public cuda_stream_pool { + rmm::cuda_stream_pool _pool; + + public: + rmm_cuda_stream_pool() : _pool{STREAM_POOL_SIZE} {} + rmm::cuda_stream_view get_stream() override { return _pool.get_stream(); } + rmm::cuda_stream_view get_stream(std::size_t stream_id) override + { + return _pool.get_stream(stream_id); + } +}; + +class debug_cuda_stream_pool : public cuda_stream_pool { + public: + rmm::cuda_stream_view get_stream() override { return cudf::get_default_stream(); } + rmm::cuda_stream_view get_stream(std::size_t stream_id) override + { + return cudf::get_default_stream(); + } +}; -auto& get_stream_pool() +cuda_stream_pool* create_global_cuda_stream_pool() { - // TODO: creating this on the heap because there were issues with trying to call the - // stream pool destructor during cuda shutdown that lead to a segmentation fault in - // nvbench. this allocation is being deliberately leaked to avoid the above. 
-#if 1 - static auto pool = new rmm::cuda_stream_pool{STREAM_POOL_SIZE}; + if (getenv("LIBCUDF_USE_DEBUG_STREAM_POOL")) return new debug_cuda_stream_pool(); + + return new rmm_cuda_stream_pool(); +} + +// TODO: hidden for now...can move out of the anonymous namespace if this needs to be exposed +// to users. +// TODO: move get_streams(uint32_t) into the interface, or leave as is? +cuda_stream_pool& global_cuda_stream_pool() +{ + static cuda_stream_pool* pool = create_global_cuda_stream_pool(); return *pool; -#else - // FIXME - // running ./benchmarks/PARQUET_READER_NVBENCH -b parquet_read_decode --axis data_type=STRUCT - // with this code results in a segmentation fault in the cuda_stream_pool dtor during shutdown. - // it seems cudaStreamDestroy is called twice on the streams in the pool. - static rmm::cuda_stream_pool pool{STREAM_POOL_SIZE}; - return pool; -#endif } +std::mutex stream_pool_mutex; + } // anonymous namespace -rmm::cuda_stream_view get_stream() { return get_stream_pool().get_stream(); } +// TODO: these next 2 (3?) can go away if we expose global_cuda_stream_pool() +rmm::cuda_stream_view get_stream() { return global_cuda_stream_pool().get_stream(); } rmm::cuda_stream_view get_stream(std::size_t stream_id) { - return get_stream_pool().get_stream(stream_id); + return global_cuda_stream_pool().get_stream(stream_id); } std::vector get_streams(uint32_t count) @@ -65,7 +94,7 @@ std::vector get_streams(uint32_t count) auto streams = std::vector(); std::lock_guard lock(stream_pool_mutex); for (uint32_t i = 0; i < count; i++) { - streams.emplace_back(get_stream_pool().get_stream()); + streams.emplace_back(global_cuda_stream_pool().get_stream()); } return streams; } From e74149f0db12310ec1e0fa6154da4e54074d314e Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 13:23:06 -0700 Subject: [PATCH 39/77] more api jiggering --- cpp/src/io/parquet/reader_impl.cpp | 2 +- cpp/src/io/parquet/stream_pool.cpp | 51 +++++++++++++++--------------- cpp/src/io/parquet/stream_pool.hpp | 14 ++++++++ 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index d72e90d3fdb..f0b3bfa25d9 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -165,7 +165,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) // get the number of streams we need from the pool and tell them to wait on the H2D copies int nkernels = std::bitset<32>(kernel_mask).count(); - auto streams = get_streams(nkernels); + auto streams = global_cuda_stream_pool().get_streams(nkernels); fork_streams(streams, _stream); auto const level_type_size = _file_itm_data.level_type_size; diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index c45a1de849c..cecb4176209 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -29,14 +29,6 @@ namespace { // TODO: what is a good number here. what's the penalty for making it larger? 
std::size_t constexpr STREAM_POOL_SIZE = 32; -class cuda_stream_pool { - public: - virtual ~cuda_stream_pool() = default; - - virtual rmm::cuda_stream_view get_stream() = 0; - virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0; -}; - class rmm_cuda_stream_pool : public cuda_stream_pool { rmm::cuda_stream_pool _pool; @@ -47,6 +39,21 @@ class rmm_cuda_stream_pool : public cuda_stream_pool { { return _pool.get_stream(stream_id); } + + std::vector get_streams(uint32_t count) + { + static std::mutex stream_pool_mutex; + + if (count > STREAM_POOL_SIZE) { + CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE); + } + auto streams = std::vector(); + std::lock_guard lock(stream_pool_mutex); + for (uint32_t i = 0; i < count; i++) { + streams.emplace_back(_pool.get_stream()); + } + return streams; + } }; class debug_cuda_stream_pool : public cuda_stream_pool { @@ -56,6 +63,11 @@ class debug_cuda_stream_pool : public cuda_stream_pool { { return cudf::get_default_stream(); } + + std::vector get_streams(uint32_t count) + { + return std::vector(count, cudf::get_default_stream()); + } }; cuda_stream_pool* create_global_cuda_stream_pool() @@ -65,20 +77,16 @@ cuda_stream_pool* create_global_cuda_stream_pool() return new rmm_cuda_stream_pool(); } -// TODO: hidden for now...can move out of the anonymous namespace if this needs to be exposed -// to users. -// TODO: move get_streams(uint32_t) into the interface, or leave as is? +} // anonymous namespace + cuda_stream_pool& global_cuda_stream_pool() { static cuda_stream_pool* pool = create_global_cuda_stream_pool(); return *pool; } -std::mutex stream_pool_mutex; - -} // anonymous namespace - -// TODO: these next 2 (3?) can go away if we expose global_cuda_stream_pool() +#if 0 +// TODO: these next 3 can go away if we expose global_cuda_stream_pool() rmm::cuda_stream_view get_stream() { return global_cuda_stream_pool().get_stream(); } rmm::cuda_stream_view get_stream(std::size_t stream_id) @@ -88,16 +96,9 @@ rmm::cuda_stream_view get_stream(std::size_t stream_id) std::vector get_streams(uint32_t count) { - if (count > STREAM_POOL_SIZE) { - CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE); - } - auto streams = std::vector(); - std::lock_guard lock(stream_pool_mutex); - for (uint32_t i = 0; i < count; i++) { - streams.emplace_back(global_cuda_stream_pool().get_stream()); - } - return streams; + return global_cuda_stream_pool().get_streams(count); } +#endif void fork_streams(std::vector& streams, rmm::cuda_stream_view stream) { diff --git a/cpp/src/io/parquet/stream_pool.hpp b/cpp/src/io/parquet/stream_pool.hpp index e9bef3bda35..c2fb328e433 100644 --- a/cpp/src/io/parquet/stream_pool.hpp +++ b/cpp/src/io/parquet/stream_pool.hpp @@ -20,6 +20,19 @@ namespace cudf::io::detail::parquet { +// TODO move docstrings +class cuda_stream_pool { + public: + virtual ~cuda_stream_pool() = default; + + virtual rmm::cuda_stream_view get_stream() = 0; + virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0; + virtual std::vector get_streams(uint32_t count) = 0; +}; + +cuda_stream_pool& global_cuda_stream_pool(); + +#if 0 /** * @brief Get a `cuda_stream_view` of a stream in the pool. * @@ -49,6 +62,7 @@ rmm::cuda_stream_view get_stream(std::size_t stream_id); * @return Vector containing `count` stream views. */ std::vector get_streams(uint32_t count); +#endif /** * @brief Synchronize a set of streams to an event on another stream. 
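The function-local static in `global_cuda_stream_pool()` above is the load-bearing trick in this part of the series: the pool is constructed on first use (after the CUDA runtime is initialized) and deliberately never destroyed, so no `cudaStreamDestroy` can run during static teardown after the runtime has already shut down. A standalone sketch of the idiom (the function name is illustrative; only `rmm::cuda_stream_pool` is assumed):

#include <rmm/cuda_stream_pool.hpp>

// Returns a process-lifetime stream pool. The pool is constructed on first
// call (thread safe since C++11) and intentionally leaked: a static object's
// destructor would run during shutdown, when destroying CUDA streams can
// fail or crash (as seen with nvbench earlier in this series).
rmm::cuda_stream_pool& process_stream_pool()
{
  static auto* const pool = new rmm::cuda_stream_pool{32};  // never deleted
  return *pool;
}
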
From 0d946e8b70cf5bb87cfa3277e29d818e59ec2136 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 13:31:07 -0700 Subject: [PATCH 40/77] clean up some --- cpp/src/io/parquet/stream_pool.cpp | 15 ------- cpp/src/io/parquet/stream_pool.hpp | 64 +++++++++++++++--------------- 2 files changed, 32 insertions(+), 47 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index cecb4176209..9db91013637 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -85,21 +85,6 @@ cuda_stream_pool& global_cuda_stream_pool() return *pool; } -#if 0 -// TODO: these next 3 can go away if we expose global_cuda_stream_pool() -rmm::cuda_stream_view get_stream() { return global_cuda_stream_pool().get_stream(); } - -rmm::cuda_stream_view get_stream(std::size_t stream_id) -{ - return global_cuda_stream_pool().get_stream(stream_id); -} - -std::vector get_streams(uint32_t count) -{ - return global_cuda_stream_pool().get_streams(count); -} -#endif - void fork_streams(std::vector& streams, rmm::cuda_stream_view stream) { cudaEvent_t event; diff --git a/cpp/src/io/parquet/stream_pool.hpp b/cpp/src/io/parquet/stream_pool.hpp index c2fb328e433..a871338a553 100644 --- a/cpp/src/io/parquet/stream_pool.hpp +++ b/cpp/src/io/parquet/stream_pool.hpp @@ -20,49 +20,49 @@ namespace cudf::io::detail::parquet { -// TODO move docstrings class cuda_stream_pool { public: virtual ~cuda_stream_pool() = default; - virtual rmm::cuda_stream_view get_stream() = 0; - virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0; - virtual std::vector get_streams(uint32_t count) = 0; -}; + /** + * @brief Get a `cuda_stream_view` of a stream in the pool. + * + * This function is thread safe with respect to other calls to the same function. + * + * @return Stream view. + */ + virtual rmm::cuda_stream_view get_stream() = 0; -cuda_stream_pool& global_cuda_stream_pool(); - -#if 0 -/** - * @brief Get a `cuda_stream_view` of a stream in the pool. - * - * This function is thread safe with respect to other calls to the same function. - * - * @return Stream view. - */ -rmm::cuda_stream_view get_stream(); + /** + * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`. + * + * Equivalent values of `stream_id` return a stream_view to the same underlying stream. + * This function is thread safe with respect to other calls to the same function. + * + * @param stream_id Unique identifier for the desired stream + * @return Requested stream view. + */ + virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0; -/** - * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`. - * - * Equivalent values of `stream_id` return a stream_view to the same underlying stream. - * This function is thread safe with respect to other calls to the same function. - * - * @param stream_id Unique identifier for the desired stream - * @return Requested stream view. - */ -rmm::cuda_stream_view get_stream(std::size_t stream_id); + /** + * @brief Get a set of `cuda_stream_view` objects from the pool. + * + * This function is thread safe with respect to other calls to the same function. + * + * @param count The number of stream views to return. + * @return Vector containing `count` stream views. + */ + virtual std::vector get_streams(uint32_t count) = 0; +}; /** - * @brief Get a set of `cuda_stream_view` objects from the pool. + * @brief Return the global cuda_stream_pool object. 
* - * This function is thread safe with respect to other calls to the same function. + * TODO: document how to control the implementation * - * @param count The number of stream views to return. - * @return Vector containing `count` stream views. + * @return The cuda_stream_pool singleton. */ -std::vector get_streams(uint32_t count); -#endif +cuda_stream_pool& global_cuda_stream_pool(); /** * @brief Synchronize a set of streams to an event on another stream. From 0c1faed5dc2e925996196595b4ac891600c7695e Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 13:37:23 -0700 Subject: [PATCH 41/77] stub in docstring --- cpp/src/io/parquet/stream_pool.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cpp/src/io/parquet/stream_pool.hpp b/cpp/src/io/parquet/stream_pool.hpp index a871338a553..99325ad331d 100644 --- a/cpp/src/io/parquet/stream_pool.hpp +++ b/cpp/src/io/parquet/stream_pool.hpp @@ -20,6 +20,13 @@ namespace cudf::io::detail::parquet { +/** + * @brief A pool of CUDA stream objects + * + * Meant to provide efficient on-demand access to CUDA streams. + * + * TODO: better docs! + */ class cuda_stream_pool { public: virtual ~cuda_stream_pool() = default; From a31056c5d1a50ee10d1db4438640d06dcdcfb4ae Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 13:45:03 -0700 Subject: [PATCH 42/77] move get_stream_pool_size into object --- cpp/src/io/parquet/stream_pool.cpp | 6 ++++-- cpp/src/io/parquet/stream_pool.hpp | 18 +++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 9db91013637..66020004d59 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -54,6 +54,8 @@ class rmm_cuda_stream_pool : public cuda_stream_pool { } return streams; } + + std::size_t get_stream_pool_size() const override { return STREAM_POOL_SIZE; } }; class debug_cuda_stream_pool : public cuda_stream_pool { @@ -68,6 +70,8 @@ class debug_cuda_stream_pool : public cuda_stream_pool { { return std::vector(count, cudf::get_default_stream()); } + + std::size_t get_stream_pool_size() const override { return 1UL; } }; cuda_stream_pool* create_global_cuda_stream_pool() @@ -107,6 +111,4 @@ void join_streams(std::vector& streams, rmm::cuda_stream_ CUDF_CUDA_TRY(cudaEventDestroy(event)); } -std::size_t get_stream_pool_size() { return STREAM_POOL_SIZE; } - } // namespace cudf::io::detail::parquet diff --git a/cpp/src/io/parquet/stream_pool.hpp b/cpp/src/io/parquet/stream_pool.hpp index 99325ad331d..ea09144ccc9 100644 --- a/cpp/src/io/parquet/stream_pool.hpp +++ b/cpp/src/io/parquet/stream_pool.hpp @@ -60,6 +60,15 @@ class cuda_stream_pool { * @return Vector containing `count` stream views. */ virtual std::vector get_streams(uint32_t count) = 0; + + /** + * @brief Get the number of streams in the pool. + * + * This function is thread safe with respect to other calls to the same function. + * + * @return the number of streams in the pool + */ + virtual std::size_t get_stream_pool_size() const = 0; }; /** @@ -87,13 +96,4 @@ void fork_streams(std::vector& streams, rmm::cuda_stream_ */ void join_streams(std::vector& streams, rmm::cuda_stream_view stream); -/** - * @brief Get the number of streams in the pool. - * - * This function is thread safe with respect to other calls to the same function. 
- * - * @return the number of streams in the pool - */ -std::size_t get_stream_pool_size(); - } // namespace cudf::io::detail::parquet From 576230a301966322e9cfe3fdef706a723e44729d Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 13:47:38 -0700 Subject: [PATCH 43/77] forgot some overrides --- cpp/src/io/parquet/stream_pool.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 66020004d59..96ec47e512b 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -40,7 +40,7 @@ class rmm_cuda_stream_pool : public cuda_stream_pool { return _pool.get_stream(stream_id); } - std::vector get_streams(uint32_t count) + std::vector get_streams(uint32_t count) override { static std::mutex stream_pool_mutex; @@ -66,7 +66,7 @@ class debug_cuda_stream_pool : public cuda_stream_pool { return cudf::get_default_stream(); } - std::vector get_streams(uint32_t count) + std::vector get_streams(uint32_t count) override { return std::vector(count, cudf::get_default_stream()); } From 21b4443c2b6ee5e369e5a585a4340efde74d26e7 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 24 Aug 2023 13:54:45 -0700 Subject: [PATCH 44/77] pass host_span to fork/join_streams --- cpp/src/io/parquet/stream_pool.cpp | 4 ++-- cpp/src/io/parquet/stream_pool.hpp | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 96ec47e512b..64d83a4f225 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -89,7 +89,7 @@ cuda_stream_pool& global_cuda_stream_pool() return *pool; } -void fork_streams(std::vector& streams, rmm::cuda_stream_view stream) +void fork_streams(host_span streams, rmm::cuda_stream_view stream) { cudaEvent_t event; CUDF_CUDA_TRY(cudaEventCreate(&event)); @@ -100,7 +100,7 @@ void fork_streams(std::vector& streams, rmm::cuda_stream_ CUDF_CUDA_TRY(cudaEventDestroy(event)); } -void join_streams(std::vector& streams, rmm::cuda_stream_view stream) +void join_streams(host_span streams, rmm::cuda_stream_view stream) { cudaEvent_t event; CUDF_CUDA_TRY(cudaEventCreate(&event)); diff --git a/cpp/src/io/parquet/stream_pool.hpp b/cpp/src/io/parquet/stream_pool.hpp index ea09144ccc9..d11ac5cfbf5 100644 --- a/cpp/src/io/parquet/stream_pool.hpp +++ b/cpp/src/io/parquet/stream_pool.hpp @@ -16,6 +16,8 @@ #pragma once +#include + #include namespace cudf::io::detail::parquet { @@ -86,7 +88,7 @@ cuda_stream_pool& global_cuda_stream_pool(); * @param streams Vector of streams to synchronize on. * @param stream Stream to synchronize the other streams to, usually the default stream. */ -void fork_streams(std::vector& streams, rmm::cuda_stream_view stream); +void fork_streams(host_span streams, rmm::cuda_stream_view stream); /** * @brief Synchronize a stream to an event on a set of streams. @@ -94,6 +96,6 @@ void fork_streams(std::vector& streams, rmm::cuda_stream_ * @param streams Vector of streams to synchronize on. * @param stream Stream to synchronize the other streams to, usually the default stream. 
 */
-void join_streams(std::vector<rmm::cuda_stream_view>& streams, rmm::cuda_stream_view stream);
+void join_streams(host_span<rmm::cuda_stream_view> streams, rmm::cuda_stream_view stream);

 } // namespace cudf::io::detail::parquet

From 3f3c5b53118adbe533aaf5c66a6db154fff47f47 Mon Sep 17 00:00:00 2001
From: seidl
Date: Thu, 24 Aug 2023 15:33:43 -0700
Subject: [PATCH 45/77] add to TODO

---
 cpp/src/io/parquet/stream_pool.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp
index 64d83a4f225..0f9de8a566c 100644
--- a/cpp/src/io/parquet/stream_pool.cpp
+++ b/cpp/src/io/parquet/stream_pool.cpp
@@ -27,6 +27,11 @@ namespace cudf::io::detail::parquet {
 namespace {

 // TODO: what is a good number here. what's the penalty for making it larger?
+// Dave Baranec rule of thumb was max_streams_needed * num_concurrent_threads,
+// where num_concurrent_threads was estimated to be 4. so using 32 will allow
+// for 8 streams per thread, which should be plenty (decoding will be up to 4
+// kernels when delta_byte_array decoding is added). rmm::cuda_stream_pool
+// defaults to 16.
 std::size_t constexpr STREAM_POOL_SIZE = 32;

From 3059d942fbc68dd9edfd2e0cd892538d71c494d6 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 29 Aug 2023 08:42:54 -0700
Subject: [PATCH 46/77] use static events and disable timing per suggestion
 from review

switch to for loop since the lambda couldn't capture the static stream (could
be convinced to use a functor)
---
 cpp/src/io/parquet/stream_pool.cpp | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp
index 0f9de8a566c..0e12fab863b 100644
--- a/cpp/src/io/parquet/stream_pool.cpp
+++ b/cpp/src/io/parquet/stream_pool.cpp
@@ -86,6 +86,13 @@ cuda_stream_pool* create_global_cuda_stream_pool()
   return new rmm_cuda_stream_pool();
 }

+cudaEvent_t create_event()
+{
+  cudaEvent_t event;
+  CUDF_CUDA_TRY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+  return event;
+}
+
 } // anonymous namespace

 cuda_stream_pool& global_cuda_stream_pool()
@@ -103,7 +110,7 @@ cuda_stream_pool& global_cuda_stream_pool()

 void fork_streams(host_span<rmm::cuda_stream_view> streams, rmm::cuda_stream_view stream)
 {
-  cudaEvent_t event;
-  CUDF_CUDA_TRY(cudaEventCreate(&event));
+  static cudaEvent_t event = create_event();
   CUDF_CUDA_TRY(cudaEventRecord(event, stream));
-  std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
+  for (auto& strm : streams) {
     CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0));
-  });
-  CUDF_CUDA_TRY(cudaEventDestroy(event));
+  }
 }

 void join_streams(host_span<rmm::cuda_stream_view> streams, rmm::cuda_stream_view stream)
 {
-  cudaEvent_t event;
-  CUDF_CUDA_TRY(cudaEventCreate(&event));
+  static cudaEvent_t event = create_event();
-  std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
+  for (auto& strm : streams) {
     CUDF_CUDA_TRY(cudaEventRecord(event, strm));
     CUDF_CUDA_TRY(cudaStreamWaitEvent(stream, event, 0));
-  });
-  CUDF_CUDA_TRY(cudaEventDestroy(event));
+  }
 }

From bc62b928ab6b00af170a93bf2342b21474434781 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 29 Aug 2023 11:55:31 -0700
Subject: [PATCH 47/77] add Vukasin's per-thread-default-event implementation

---
 cpp/src/io/parquet/stream_pool.cpp | 38 +++++++++++++++++++++------
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp
index
0e12fab863b..e6e1569fb41 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -86,11 +86,37 @@ cuda_stream_pool* create_global_cuda_stream_pool() return new rmm_cuda_stream_pool(); } -cudaEvent_t create_event() +// implementation of per-thread-default-event. +class cuda_event_map { + public: + cuda_event_map() {} + + cudaEvent_t find(std::thread::id thread_id) + { + std::lock_guard lock(map_mutex_); + auto it = event_map_.find(thread_id); + if (it != event_map_.end()) { + return it->second; + } else { + cudaEvent_t event; + CUDF_CUDA_TRY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + event_map_[thread_id] = event; + return event; + } + } + + cuda_event_map(cuda_event_map const&) = delete; + void operator=(cuda_event_map const&) = delete; + + private: + std::unordered_map event_map_; + std::mutex map_mutex_; +}; + +cudaEvent_t event_for_thread() { - cudaEvent_t event; - CUDF_CUDA_TRY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - return event; + static cuda_event_map instance; + return instance.find(std::this_thread::get_id()); } } // anonymous namespace @@ -103,7 +129,7 @@ cuda_stream_pool& global_cuda_stream_pool() void fork_streams(host_span streams, rmm::cuda_stream_view stream) { - static cudaEvent_t event = create_event(); + static cudaEvent_t event = event_for_thread(); CUDF_CUDA_TRY(cudaEventRecord(event, stream)); for (auto& strm : streams) { CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0)); @@ -112,7 +138,7 @@ void fork_streams(host_span streams, rmm::cuda_stream_vie void join_streams(host_span streams, rmm::cuda_stream_view stream) { - static cudaEvent_t event = create_event(); + static cudaEvent_t event = event_for_thread(); for (auto& strm : streams) { CUDF_CUDA_TRY(cudaEventRecord(event, strm)); CUDF_CUDA_TRY(cudaStreamWaitEvent(stream, event, 0)); From 4c2c17b2c8c7e6be2209e64469aa5ad7f301728e Mon Sep 17 00:00:00 2001 From: seidl Date: Tue, 29 Aug 2023 11:58:37 -0700 Subject: [PATCH 48/77] remove static from fork/join events --- cpp/src/io/parquet/stream_pool.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index e6e1569fb41..8d3b8693a03 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/io/parquet/stream_pool.cpp @@ -129,7 +129,7 @@ cuda_stream_pool& global_cuda_stream_pool() void fork_streams(host_span streams, rmm::cuda_stream_view stream) { - static cudaEvent_t event = event_for_thread(); + cudaEvent_t event = event_for_thread(); CUDF_CUDA_TRY(cudaEventRecord(event, stream)); for (auto& strm : streams) { CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0)); @@ -138,7 +138,7 @@ void fork_streams(host_span streams, rmm::cuda_stream_vie void join_streams(host_span streams, rmm::cuda_stream_view stream) { - static cudaEvent_t event = event_for_thread(); + cudaEvent_t event = event_for_thread(); for (auto& strm : streams) { CUDF_CUDA_TRY(cudaEventRecord(event, strm)); CUDF_CUDA_TRY(cudaStreamWaitEvent(stream, event, 0)); From 06d2a75c049359900e640c63744bcbbadea8a188 Mon Sep 17 00:00:00 2001 From: seidl Date: Tue, 29 Aug 2023 13:10:39 -0700 Subject: [PATCH 49/77] replace event map with thread_local event struct --- cpp/src/io/parquet/stream_pool.cpp | 37 ++++++++++++------------------ 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/io/parquet/stream_pool.cpp index 8d3b8693a03..6808b80ae7d 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ 
b/cpp/src/io/parquet/stream_pool.cpp @@ -86,37 +86,30 @@ cuda_stream_pool* create_global_cuda_stream_pool() return new rmm_cuda_stream_pool(); } -// implementation of per-thread-default-event. -class cuda_event_map { - public: - cuda_event_map() {} - - cudaEvent_t find(std::thread::id thread_id) +struct cuda_event { + cuda_event() + : e_{[]() { + cudaEvent_t event; + CUDF_CUDA_TRY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + return event; + }()} { - std::lock_guard lock(map_mutex_); - auto it = event_map_.find(thread_id); - if (it != event_map_.end()) { - return it->second; - } else { - cudaEvent_t event; - CUDF_CUDA_TRY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - event_map_[thread_id] = event; - return event; - } } - cuda_event_map(cuda_event_map const&) = delete; - void operator=(cuda_event_map const&) = delete; + operator cudaEvent_t() { return e_.get(); } private: - std::unordered_map event_map_; - std::mutex map_mutex_; + struct deleter { + using pointer = cudaEvent_t; + auto operator()(cudaEvent_t e) { cudaEventDestroy(e); } + }; + std::unique_ptr e_; }; cudaEvent_t event_for_thread() { - static cuda_event_map instance; - return instance.find(std::this_thread::get_id()); + thread_local cuda_event thread_event; + return thread_event; } } // anonymous namespace From ddfc1180e5174aa121277276ca556366dc07cae7 Mon Sep 17 00:00:00 2001 From: seidl Date: Tue, 29 Aug 2023 16:34:46 -0700 Subject: [PATCH 50/77] move stream pool to cudf::detail --- conda/recipes/libcudf/meta.yaml | 1 + cpp/CMakeLists.txt | 2 +- .../cudf/detail/utilities}/stream_pool.hpp | 4 ++-- cpp/src/io/parquet/reader_impl.cpp | 8 ++++---- cpp/src/{io/parquet => utilities}/stream_pool.cpp | 7 +++---- 5 files changed, 11 insertions(+), 11 deletions(-) rename cpp/{src/io/parquet => include/cudf/detail/utilities}/stream_pool.hpp (97%) rename cpp/src/{io/parquet => utilities}/stream_pool.cpp (97%) diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index de32facba74..760c83a3119 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -174,6 +174,7 @@ outputs: - test -f $PREFIX/include/cudf/detail/utilities/logger.hpp - test -f $PREFIX/include/cudf/detail/utilities/pinned_host_vector.hpp - test -f $PREFIX/include/cudf/detail/utilities/stacktrace.hpp + - test -f $PREFIX/include/cudf/detail/utilities/stream_pool.hpp - test -f $PREFIX/include/cudf/detail/utilities/vector_factories.hpp - test -f $PREFIX/include/cudf/detail/utilities/visitor_overload.hpp - test -f $PREFIX/include/cudf/dictionary/detail/concatenate.hpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 007ce4e65f8..c37d05a21c7 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -402,7 +402,6 @@ add_library( src/io/parquet/reader_impl.cpp src/io/parquet/reader_impl_helpers.cpp src/io/parquet/reader_impl_preprocess.cu - src/io/parquet/stream_pool.cpp src/io/parquet/writer_impl.cu src/io/statistics/orc_column_statistics.cu src/io/statistics/parquet_column_statistics.cu @@ -634,6 +633,7 @@ add_library( src/utilities/linked_column.cpp src/utilities/logger.cpp src/utilities/stacktrace.cpp + src/utilities/stream_pool.cpp src/utilities/traits.cpp src/utilities/type_checks.cpp src/utilities/type_dispatcher.cpp diff --git a/cpp/src/io/parquet/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp similarity index 97% rename from cpp/src/io/parquet/stream_pool.hpp rename to cpp/include/cudf/detail/utilities/stream_pool.hpp index d11ac5cfbf5..c1e66db2b20 100644 --- 
a/cpp/src/io/parquet/stream_pool.hpp +++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp @@ -20,7 +20,7 @@ #include -namespace cudf::io::detail::parquet { +namespace cudf::detail { /** * @brief A pool of CUDA stream objects @@ -98,4 +98,4 @@ void fork_streams(host_span streams, rmm::cuda_stream_vie */ void join_streams(host_span streams, rmm::cuda_stream_view stream); -} // namespace cudf::io::detail::parquet +} // namespace cudf::detail diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index b54789d4fc8..16011706ea9 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -15,10 +15,10 @@ */ #include "reader_impl.hpp" -#include "stream_pool.hpp" #include #include +#include #include #include @@ -165,8 +165,8 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) // get the number of streams we need from the pool and tell them to wait on the H2D copies int nkernels = std::bitset<32>(kernel_mask).count(); - auto streams = global_cuda_stream_pool().get_streams(nkernels); - fork_streams(streams, _stream); + auto streams = cudf::detail::global_cuda_stream_pool().get_streams(nkernels); + cudf::detail::fork_streams(streams, _stream); auto const level_type_size = _file_itm_data.level_type_size; @@ -189,7 +189,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) } // synchronize the streams - join_streams(streams, _stream); + cudf::detail::join_streams(streams, _stream); pages.device_to_host_async(_stream); page_nesting.device_to_host_async(_stream); diff --git a/cpp/src/io/parquet/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp similarity index 97% rename from cpp/src/io/parquet/stream_pool.cpp rename to cpp/src/utilities/stream_pool.cpp index 6808b80ae7d..7066fc64898 100644 --- a/cpp/src/io/parquet/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -16,13 +16,12 @@ #include -#include "stream_pool.hpp" - #include +#include #include #include -namespace cudf::io::detail::parquet { +namespace cudf::detail { namespace { @@ -138,4 +137,4 @@ void join_streams(host_span streams, rmm::cuda_stream_vie } } -} // namespace cudf::io::detail::parquet +} // namespace cudf::detail From ac55b7e7309f1de6856d809f64d224f99921dbc4 Mon Sep 17 00:00:00 2001 From: Ed Seidl Date: Tue, 29 Aug 2023 20:31:53 -0700 Subject: [PATCH 51/77] start cleaning up docstrings --- cpp/include/cudf/detail/utilities/stream_pool.hpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp index c1e66db2b20..af06ec802d3 100644 --- a/cpp/include/cudf/detail/utilities/stream_pool.hpp +++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp @@ -28,6 +28,7 @@ namespace cudf::detail { * Meant to provide efficient on-demand access to CUDA streams. * * TODO: better docs! + * if env LIBCUDF_USE_DEBUG_STREAM_POOL is set, then use default stream... */ class cuda_stream_pool { public: @@ -56,6 +57,10 @@ class cuda_stream_pool { /** * @brief Get a set of `cuda_stream_view` objects from the pool. * + * An attempt is made to ensure that the returned vector does not contain duplicate + * streams, but this cannot be guaranteed if `count` is greater than + * `get_stream_pool_size()`. + * * This function is thread safe with respect to other calls to the same function. * * @param count The number of stream views to return. 
@@ -74,9 +79,7 @@ class cuda_stream_pool { }; /** - * @brief Return the global cuda_stream_pool object. - * - * TODO: document how to control the implementation + * @brief Return a reference to the global cuda_stream_pool object. * * @return The cuda_stream_pool singleton. */ @@ -85,7 +88,7 @@ cuda_stream_pool& global_cuda_stream_pool(); /** * @brief Synchronize a set of streams to an event on another stream. * - * @param streams Vector of streams to synchronize on. + * @param streams Vector of streams to synchronize. * @param stream Stream to synchronize the other streams to, usually the default stream. */ void fork_streams(host_span streams, rmm::cuda_stream_view stream); @@ -93,8 +96,8 @@ void fork_streams(host_span streams, rmm::cuda_stream_vie /** * @brief Synchronize a stream to an event on a set of streams. * - * @param streams Vector of streams to synchronize on. - * @param stream Stream to synchronize the other streams to, usually the default stream. + * @param streams Vector of streams to synchronize to. + * @param stream Stream to synchronize, usually the default stream. */ void join_streams(host_span streams, rmm::cuda_stream_view stream); From 4455230086e85c7e9f49ca549a355024930abaea Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 30 Aug 2023 13:57:14 -0700 Subject: [PATCH 52/77] use new stream pool in multibyte_split --- cpp/src/io/text/multibyte_split.cu | 39 +++--------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/cpp/src/io/text/multibyte_split.cu b/cpp/src/io/text/multibyte_split.cu index 818bbc0a18a..23a888de9fd 100644 --- a/cpp/src/io/text/multibyte_split.cu +++ b/cpp/src/io/text/multibyte_split.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,6 @@ #include #include -#include #include #include #include @@ -301,37 +301,6 @@ namespace io { namespace text { namespace detail { -void fork_stream(std::vector streams, rmm::cuda_stream_view stream) -{ - cudaEvent_t event; - CUDF_CUDA_TRY(cudaEventCreate(&event)); - CUDF_CUDA_TRY(cudaEventRecord(event, stream)); - for (uint32_t i = 0; i < streams.size(); i++) { - CUDF_CUDA_TRY(cudaStreamWaitEvent(streams[i], event, 0)); - } - CUDF_CUDA_TRY(cudaEventDestroy(event)); -} - -void join_stream(std::vector streams, rmm::cuda_stream_view stream) -{ - cudaEvent_t event; - CUDF_CUDA_TRY(cudaEventCreate(&event)); - for (uint32_t i = 0; i < streams.size(); i++) { - CUDF_CUDA_TRY(cudaEventRecord(event, streams[i])); - CUDF_CUDA_TRY(cudaStreamWaitEvent(stream, event, 0)); - } - CUDF_CUDA_TRY(cudaEventDestroy(event)); -} - -std::vector get_streams(int32_t count, rmm::cuda_stream_pool& stream_pool) -{ - auto streams = std::vector(); - for (int32_t i = 0; i < count; i++) { - streams.emplace_back(stream_pool.get_stream()); - } - return streams; -} - std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source const& source, std::string const& delimiter, byte_range_info byte_range, @@ -366,7 +335,7 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source "delimiter contains too many total tokens to produce a deterministic result."); auto concurrency = 2; - auto streams = get_streams(concurrency, stream_pool); + auto streams = cudf::detail::global_cuda_stream_pool().get_streams(concurrency); // must be at least 32 when using warp-reduce on partials // must be at least 1 more than max possible concurrent tiles @@ -411,7 +380,7 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source output_builder row_offset_storage(ITEMS_PER_CHUNK, 
max_growth, stream); output_builder char_storage(ITEMS_PER_CHUNK, max_growth, stream); - fork_stream(streams, stream); + cudf::detail::fork_streams(streams, stream); cudaEvent_t last_launch_event; CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event)); @@ -532,7 +501,7 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source CUDF_CUDA_TRY(cudaEventDestroy(last_launch_event)); - join_stream(streams, stream); + cudf::detail::join_streams(streams, stream); // if the input was empty, we didn't find a delimiter at all, // or the first delimiter was also the last: empty output From b5c55bb6440b803af5f6bd3ac4b9524313627983 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 30 Aug 2023 16:53:34 -0700 Subject: [PATCH 53/77] forgot to get rid of rmm stream pool --- cpp/src/io/text/multibyte_split.cu | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/src/io/text/multibyte_split.cu b/cpp/src/io/text/multibyte_split.cu index 23a888de9fd..9718aec05bf 100644 --- a/cpp/src/io/text/multibyte_split.cu +++ b/cpp/src/io/text/multibyte_split.cu @@ -306,8 +306,7 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source byte_range_info byte_range, bool strip_delimiters, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr, - rmm::cuda_stream_pool& stream_pool) + rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); @@ -571,11 +570,10 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source parse_options options, rmm::mr::device_memory_resource* mr) { - auto stream = cudf::get_default_stream(); - auto stream_pool = rmm::cuda_stream_pool(2); + auto stream = cudf::get_default_stream(); auto result = detail::multibyte_split( - source, delimiter, options.byte_range, options.strip_delimiters, stream, mr, stream_pool); + source, delimiter, options.byte_range, options.strip_delimiters, stream, mr); return result; } From f5ff4d08ad7c2973b72483b59b6e0b6c67013fb4 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 31 Aug 2023 12:34:05 -0700 Subject: [PATCH 54/77] add more documentation --- .../cudf/detail/utilities/stream_pool.hpp | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp index af06ec802d3..c44d544b676 100644 --- a/cpp/include/cudf/detail/utilities/stream_pool.hpp +++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp @@ -25,10 +25,28 @@ namespace cudf::detail { /** * @brief A pool of CUDA stream objects * - * Meant to provide efficient on-demand access to CUDA streams. + * Provides efficient access to a collection of asynchronous (i.e. non-default) CUDA stream objects. * - * TODO: better docs! - * if env LIBCUDF_USE_DEBUG_STREAM_POOL is set, then use default stream... + * The default implementation uses an underlying `rmm::cuda_stream_pool`. The only other + * implementation at present is a debugging version that always returns the stream returned by + * `cudf::get_default_stream()`. To use this debugging version, set the environment variable + * `LIBCUDF_USE_DEBUG_STREAM_POOL`. + * + * Access to the global `cuda_stream_pool` is granted via `cudf::detail::global_cuda_stream_pool()`. 
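+ *
+ * Note that a request for more streams than the pool holds may return duplicate
+ * stream views (see `get_streams()`).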
+ * + * Example usage: + * @code{.cpp} + * auto stream = cudf::get_default_stream(); + * auto const num_streams = 2; + * // do work on stream + * auto streams = cudf::detail::global_cuda_stream_pool().get_streams(num_streams); + * // wait for event on stream before executing on any of streams + * cudf::detail::fork_streams(streams, stream); + * // invoke kernel on streams[0] + * // invoke kernel on streams[1] + * // wait for event on streams before executing on stream + * cudf::detail::join_streams(streams, stream); + * @endcode */ class cuda_stream_pool { public: @@ -46,7 +64,7 @@ class cuda_stream_pool { /** * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`. * - * Equivalent values of `stream_id` return a stream_view to the same underlying stream. + * Equivalent values of `stream_id` return a `cuda_stream_view` to the same underlying stream. * This function is thread safe with respect to other calls to the same function. * * @param stream_id Unique identifier for the desired stream @@ -58,7 +76,7 @@ class cuda_stream_pool { * @brief Get a set of `cuda_stream_view` objects from the pool. * * An attempt is made to ensure that the returned vector does not contain duplicate - * streams, but this cannot be guaranteed if `count` is greater than + * streams, but this cannot be guaranteed if `count` is greater than the value returned by * `get_stream_pool_size()`. * * This function is thread safe with respect to other calls to the same function. @@ -69,17 +87,17 @@ class cuda_stream_pool { virtual std::vector get_streams(uint32_t count) = 0; /** - * @brief Get the number of streams in the pool. + * @brief Get the number of stream objects in the pool. * * This function is thread safe with respect to other calls to the same function. * - * @return the number of streams in the pool + * @return the number of stream objects in the pool */ virtual std::size_t get_stream_pool_size() const = 0; }; /** - * @brief Return a reference to the global cuda_stream_pool object. + * @brief Return a reference to the global `cuda_stream_pool` object. * * @return The cuda_stream_pool singleton. */ From 1dc75d6f83797539c2b019c8ebb331b3d58c74c8 Mon Sep 17 00:00:00 2001 From: seidl Date: Thu, 31 Aug 2023 13:02:22 -0700 Subject: [PATCH 55/77] fix formatting --- cpp/include/cudf/detail/utilities/stream_pool.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp index c44d544b676..4c3c98b5097 100644 --- a/cpp/include/cudf/detail/utilities/stream_pool.hpp +++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp @@ -33,7 +33,7 @@ namespace cudf::detail { * `LIBCUDF_USE_DEBUG_STREAM_POOL`. * * Access to the global `cuda_stream_pool` is granted via `cudf::detail::global_cuda_stream_pool()`. - * + * * Example usage: * @code{.cpp} * auto stream = cudf::get_default_stream(); From 1ceaedf10ef412667d894e5bb034b9ba691fbe1e Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 1 Sep 2023 08:45:44 -0700 Subject: [PATCH 56/77] remove mutex from get_streams() --- cpp/src/utilities/stream_pool.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index 7066fc64898..be3f20dc267 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -14,8 +14,6 @@ * limitations under the License. 
*/ -#include - #include #include #include @@ -46,13 +44,10 @@ class rmm_cuda_stream_pool : public cuda_stream_pool { std::vector get_streams(uint32_t count) override { - static std::mutex stream_pool_mutex; - if (count > STREAM_POOL_SIZE) { CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE); } auto streams = std::vector(); - std::lock_guard lock(stream_pool_mutex); for (uint32_t i = 0; i < count; i++) { streams.emplace_back(_pool.get_stream()); } From 8b316ba223e43f431d315f09ee1c6e9929ca0565 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 1 Sep 2023 16:37:38 -0700 Subject: [PATCH 57/77] change fork_streams as suggested in review --- .../cudf/detail/utilities/stream_pool.hpp | 20 +++++++++---------- cpp/src/io/parquet/reader_impl.cpp | 5 ++--- cpp/src/io/text/multibyte_split.cu | 5 ++--- cpp/src/utilities/stream_pool.cpp | 7 ++++++- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp index 4c3c98b5097..282f7441c63 100644 --- a/cpp/include/cudf/detail/utilities/stream_pool.hpp +++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp @@ -39,12 +39,10 @@ namespace cudf::detail { * auto stream = cudf::get_default_stream(); * auto const num_streams = 2; * // do work on stream - * auto streams = cudf::detail::global_cuda_stream_pool().get_streams(num_streams); - * // wait for event on stream before executing on any of streams - * cudf::detail::fork_streams(streams, stream); - * // invoke kernel on streams[0] - * // invoke kernel on streams[1] - * // wait for event on streams before executing on stream + * // allocate streams and wait for an event on stream before executing on any of streams + * auto streams = cudf::detail::fork_streams(stream, num_streams); + * // do work on streams[0] and streams[1] + * // wait for event on streams before continuing to do work on stream * cudf::detail::join_streams(streams, stream); * @endcode */ @@ -104,12 +102,14 @@ class cuda_stream_pool { cuda_stream_pool& global_cuda_stream_pool(); /** - * @brief Synchronize a set of streams to an event on another stream. + * @brief Acquire a set of `cuda_stream_view` objects and synchronize them to an event on another + * stream. * - * @param streams Vector of streams to synchronize. - * @param stream Stream to synchronize the other streams to, usually the default stream. + * @param stream Stream to synchronize the returned streams to, usually the default stream. + * @param count The number of stream views to return. + * @return Vector containing `count` stream views. */ -void fork_streams(host_span streams, rmm::cuda_stream_view stream); +std::vector fork_streams(rmm::cuda_stream_view stream, uint32_t count); /** * @brief Synchronize a stream to an event on a set of streams. 
diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index 6380ddfca34..8b0a0bd4eb0 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -164,9 +164,8 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows) chunk_nested_data.host_to_device_async(_stream); // get the number of streams we need from the pool and tell them to wait on the H2D copies - int nkernels = std::bitset<32>(kernel_mask).count(); - auto streams = cudf::detail::global_cuda_stream_pool().get_streams(nkernels); - cudf::detail::fork_streams(streams, _stream); + int const nkernels = std::bitset<32>(kernel_mask).count(); + auto streams = cudf::detail::fork_streams(_stream, nkernels); auto const level_type_size = _file_itm_data.level_type_size; diff --git a/cpp/src/io/text/multibyte_split.cu b/cpp/src/io/text/multibyte_split.cu index 9718aec05bf..772bcad8ada 100644 --- a/cpp/src/io/text/multibyte_split.cu +++ b/cpp/src/io/text/multibyte_split.cu @@ -333,8 +333,7 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source CUDF_EXPECTS(delimiter.size() < multistate::max_segment_value, "delimiter contains too many total tokens to produce a deterministic result."); - auto concurrency = 2; - auto streams = cudf::detail::global_cuda_stream_pool().get_streams(concurrency); + auto const concurrency = 2; // must be at least 32 when using warp-reduce on partials // must be at least 1 more than max possible concurrent tiles @@ -379,7 +378,7 @@ std::unique_ptr multibyte_split(cudf::io::text::data_chunk_source output_builder row_offset_storage(ITEMS_PER_CHUNK, max_growth, stream); output_builder char_storage(ITEMS_PER_CHUNK, max_growth, stream); - cudf::detail::fork_streams(streams, stream); + auto streams = cudf::detail::fork_streams(stream, concurrency); cudaEvent_t last_launch_event; CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event)); diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index be3f20dc267..09913c44cc9 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -114,13 +114,18 @@ cuda_stream_pool& global_cuda_stream_pool() return *pool; } -void fork_streams(host_span streams, rmm::cuda_stream_view stream) +std::vector fork_streams(rmm::cuda_stream_view stream, uint32_t count) { + auto streams = global_cuda_stream_pool().get_streams(count); cudaEvent_t event = event_for_thread(); CUDF_CUDA_TRY(cudaEventRecord(event, stream)); + std::for_each(streams.begin(), streams.end(), [&](auto& strm) { + CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0)); + }); for (auto& strm : streams) { CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0)); } + return streams; } void join_streams(host_span streams, rmm::cuda_stream_view stream) From d538be902f1dd8fd28f47898a2a567d3d70a60d6 Mon Sep 17 00:00:00 2001 From: seidl Date: Fri, 1 Sep 2023 17:07:21 -0700 Subject: [PATCH 58/77] hide the actual stream pool --- .../cudf/detail/utilities/stream_pool.hpp | 70 ++----------------- cpp/src/utilities/stream_pool.cpp | 57 ++++++++++++++- 2 files changed, 59 insertions(+), 68 deletions(-) diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp index 282f7441c63..47c1e1d4e7f 100644 --- a/cpp/include/cudf/detail/utilities/stream_pool.hpp +++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp @@ -18,22 +18,19 @@ #include -#include +#include namespace cudf::detail { /** - * @brief A pool of CUDA stream objects - * - * Provides efficient access 
to a collection of asynchronous (i.e. non-default) CUDA stream objects. + * @brief Acquire a set of `cuda_stream_view` objects and synchronize them to an event on another + * stream. * - * The default implementation uses an underlying `rmm::cuda_stream_pool`. The only other + * By default an underlying `rmm::cuda_stream_pool` is used to obtain the streams. The only other * implementation at present is a debugging version that always returns the stream returned by * `cudf::get_default_stream()`. To use this debugging version, set the environment variable * `LIBCUDF_USE_DEBUG_STREAM_POOL`. * - * Access to the global `cuda_stream_pool` is granted via `cudf::detail::global_cuda_stream_pool()`. - * * Example usage: * @code{.cpp} * auto stream = cudf::get_default_stream(); @@ -45,65 +42,6 @@ namespace cudf::detail { * // wait for event on streams before continuing to do work on stream * cudf::detail::join_streams(streams, stream); * @endcode - */ -class cuda_stream_pool { - public: - virtual ~cuda_stream_pool() = default; - - /** - * @brief Get a `cuda_stream_view` of a stream in the pool. - * - * This function is thread safe with respect to other calls to the same function. - * - * @return Stream view. - */ - virtual rmm::cuda_stream_view get_stream() = 0; - - /** - * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`. - * - * Equivalent values of `stream_id` return a `cuda_stream_view` to the same underlying stream. - * This function is thread safe with respect to other calls to the same function. - * - * @param stream_id Unique identifier for the desired stream - * @return Requested stream view. - */ - virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0; - - /** - * @brief Get a set of `cuda_stream_view` objects from the pool. - * - * An attempt is made to ensure that the returned vector does not contain duplicate - * streams, but this cannot be guaranteed if `count` is greater than the value returned by - * `get_stream_pool_size()`. - * - * This function is thread safe with respect to other calls to the same function. - * - * @param count The number of stream views to return. - * @return Vector containing `count` stream views. - */ - virtual std::vector get_streams(uint32_t count) = 0; - - /** - * @brief Get the number of stream objects in the pool. - * - * This function is thread safe with respect to other calls to the same function. - * - * @return the number of stream objects in the pool - */ - virtual std::size_t get_stream_pool_size() const = 0; -}; - -/** - * @brief Return a reference to the global `cuda_stream_pool` object. - * - * @return The cuda_stream_pool singleton. - */ -cuda_stream_pool& global_cuda_stream_pool(); - -/** - * @brief Acquire a set of `cuda_stream_view` objects and synchronize them to an event on another - * stream. * * @param stream Stream to synchronize the returned streams to, usually the default stream. * @param count The number of stream views to return. diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index 09913c44cc9..43921815ad9 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -19,6 +19,8 @@ #include #include +#include + namespace cudf::detail { namespace { @@ -31,6 +33,57 @@ namespace { // defaults to 16. std::size_t constexpr STREAM_POOL_SIZE = 32; +class cuda_stream_pool { + public: + virtual ~cuda_stream_pool() = default; + + /** + * @brief Get a `cuda_stream_view` of a stream in the pool. 
+   *
+   * This function is thread safe with respect to other calls to the same function.
+   *
+   * @return Stream view.
+   */
+  virtual rmm::cuda_stream_view get_stream() = 0;
+
+  /**
+   * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`.
+   *
+   * Equivalent values of `stream_id` return a `cuda_stream_view` to the same underlying stream.
+   * This function is thread safe with respect to other calls to the same function.
+   *
+   * @param stream_id Unique identifier for the desired stream
+   * @return Requested stream view.
+   */
+  virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0;
+
+  /**
+   * @brief Get a set of `cuda_stream_view` objects from the pool.
+   *
+   * An attempt is made to ensure that the returned vector does not contain duplicate
+   * streams, but this cannot be guaranteed if `count` is greater than the value returned by
+   * `get_stream_pool_size()`.
+   *
+   * This function is thread safe with respect to other calls to the same function.
+   *
+   * @param count The number of stream views to return.
+   * @return Vector containing `count` stream views.
+   */
+  virtual std::vector<rmm::cuda_stream_view> get_streams(uint32_t count) = 0;
+
+  /**
+   * @brief Get the number of stream objects in the pool.
+   *
+   * This function is thread safe with respect to other calls to the same function.
+   *
+   * @return the number of stream objects in the pool
+   */
+  virtual std::size_t get_stream_pool_size() const = 0;
+};
+
+/**
+ * @brief Implementation of `cuda_stream_pool` that wraps an `rmm::cuda_stream_pool`.
+ */
 class rmm_cuda_stream_pool : public cuda_stream_pool {
   rmm::cuda_stream_pool _pool;
 
@@ -106,14 +159,14 @@ cudaEvent_t event_for_thread()
   return thread_event;
 }
 
-}  // anonymous namespace
-
 cuda_stream_pool& global_cuda_stream_pool()
 {
   static cuda_stream_pool* pool = create_global_cuda_stream_pool();
   return *pool;
 }
 
+}  // anonymous namespace
+
 std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, uint32_t count)
 {
   auto streams = global_cuda_stream_pool().get_streams(count);

From f8af2b650c117e5e5c6c0b3dedcaf71b5e961d61 Mon Sep 17 00:00:00 2001
From: seidl
Date: Fri, 1 Sep 2023 17:13:18 -0700
Subject: [PATCH 59/77] add TODO

---
 cpp/src/utilities/stream_pool.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 43921815ad9..1c1ad384080 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -33,6 +33,7 @@ namespace {
 // defaults to 16.
 std::size_t constexpr STREAM_POOL_SIZE = 32;
 
+// TODO: if this stays internal, then strip out the unused bits and trim the docs.
 class cuda_stream_pool {
  public:
   virtual ~cuda_stream_pool() = default;
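Patches 58 and 59 establish the pattern the rest of the series refines: the public header keeps only free functions, while the polymorphic pool class moves into an unnamed namespace in the .cpp, giving it internal linkage so no client can name it. A condensed, self-contained sketch of that shape (the pool, counting_pool and acquire names are invented for illustration; this is not the cuDF source):

// implementation file: everything in the unnamed namespace is invisible to
// other translation units, so the class is a pure implementation detail.
#include <cstddef>
#include <vector>

namespace {

class pool {
 public:
  virtual ~pool()   = default;
  virtual int get() = 0;  // stand-in for get_stream()
};

class counting_pool final : public pool {
  int next_ = 0;

 public:
  int get() override { return next_++; }
};

pool& global_pool()
{
  static counting_pool p;  // hidden process-wide singleton
  return p;
}

}  // anonymous namespace

// the only symbol with external linkage; this is what the header would declare
std::vector<int> acquire(std::size_t count)
{
  std::vector<int> out;
  for (std::size_t i = 0; i < count; ++i) { out.push_back(global_pool().get()); }
  return out;
}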
From d7d30f9c1a91e9d864e23ae7e556fa6b7e58d206 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 5 Sep 2023 16:12:22 -0700
Subject: [PATCH 60/77] rename fork_stream

---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 2 +-
 cpp/src/io/parquet/reader_impl.cpp                | 2 +-
 cpp/src/io/text/multibyte_split.cu                | 2 +-
 cpp/src/utilities/stream_pool.cpp                 | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index 47c1e1d4e7f..ed4d161bae0 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -47,7 +47,7 @@ namespace cudf::detail {
 * @param count The number of stream views to return.
 * @return Vector containing `count` stream views.
 */
-std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, uint32_t count);
+std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, uint32_t count);
 
 /**
  * @brief Synchronize a stream to an event on a set of streams.
diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp
index 8b0a0bd4eb0..6a8b71b78ce 100644
--- a/cpp/src/io/parquet/reader_impl.cpp
+++ b/cpp/src/io/parquet/reader_impl.cpp
@@ -165,7 +165,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows)
 
   // get the number of streams we need from the pool and tell them to wait on the H2D copies
   int const nkernels = std::bitset<32>(kernel_mask).count();
-  auto streams       = cudf::detail::fork_streams(_stream, nkernels);
+  auto streams       = cudf::detail::fork_stream(_stream, nkernels);
 
   auto const level_type_size = _file_itm_data.level_type_size;
 
diff --git a/cpp/src/io/text/multibyte_split.cu b/cpp/src/io/text/multibyte_split.cu
index 772bcad8ada..4a248f094c7 100644
--- a/cpp/src/io/text/multibyte_split.cu
+++ b/cpp/src/io/text/multibyte_split.cu
@@ -378,7 +378,7 @@ std::unique_ptr<cudf::column> multibyte_split(cudf::io::text::data_chunk_source
   output_builder<byte_offset> row_offset_storage(ITEMS_PER_CHUNK, max_growth, stream);
   output_builder<char> char_storage(ITEMS_PER_CHUNK, max_growth, stream);
 
-  auto streams = cudf::detail::fork_streams(stream, concurrency);
+  auto streams = cudf::detail::fork_stream(stream, concurrency);
 
   cudaEvent_t last_launch_event;
   CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event));
 
diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 1c1ad384080..8ffeec04635 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -168,7 +168,7 @@ cuda_stream_pool& global_cuda_stream_pool()
 
 }  // anonymous namespace
 
-std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, uint32_t count)
+std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, uint32_t count)
 {
   auto streams = global_cuda_stream_pool().get_streams(count);

From 64b4e42ab2dc4b2f735f1193da585a0d065e1b27 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 5 Sep 2023 16:19:55 -0700
Subject: [PATCH 61/77] add some documentation

---
 cpp/src/utilities/stream_pool.cpp | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 8ffeec04635..704e6f67bc6 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -33,7 +33,6 @@ namespace {
 // defaults to 16.
 std::size_t constexpr STREAM_POOL_SIZE = 32;
 
-// TODO: if this stays internal, then strip out the unused bits and trim the docs.
 class cuda_stream_pool {
  public:
   virtual ~cuda_stream_pool() = default;
@@ -111,6 +111,9 @@ class rmm_cuda_stream_pool : public cuda_stream_pool {
   std::size_t get_stream_pool_size() const override { return STREAM_POOL_SIZE; }
 };
 
+/**
+ * @brief Implementation of `cuda_stream_pool` that always returns `cudf::get_default_stream()`
+ */
 class debug_cuda_stream_pool : public cuda_stream_pool {
  public:
   rmm::cuda_stream_view get_stream() override { return cudf::get_default_stream(); }
@@ -127,6 +129,9 @@ class debug_cuda_stream_pool : public cuda_stream_pool {
   std::size_t get_stream_pool_size() const override { return 1UL; }
 };
 
+/**
+ * @brief Initialize global stream pool.
+ */
 cuda_stream_pool* create_global_cuda_stream_pool()
 {
   if (getenv("LIBCUDF_USE_DEBUG_STREAM_POOL")) return new debug_cuda_stream_pool();
@@ -134,6 +139,9 @@ cuda_stream_pool* create_global_cuda_stream_pool()
   return new rmm_cuda_stream_pool();
 }
 
+/**
+ * @brief RAII struct to wrap a cuda event and ensure it's proper destruction.
+ */
 struct cuda_event {
   cuda_event()
     : e_{[]() {
@@ -154,12 +162,18 @@ struct cuda_event {
   std::unique_ptr<cudaEvent_t, deleter> e_;
 };
 
+/**
+ * @brief Returns a cudaEvent_t for the current thread.
+ */
 cudaEvent_t event_for_thread()
 {
   thread_local cuda_event thread_event;
   return thread_event;
 }
 
+/**
+ * Returns a reference to the global stream ppol.
+ */
 cuda_stream_pool& global_cuda_stream_pool()
 {
   static cuda_stream_pool* pool = create_global_cuda_stream_pool();

From 98263661769b7cdbc70ad8ce731ce95c18c67b01 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 5 Sep 2023 16:59:25 -0700
Subject: [PATCH 62/77] can use std::for_each again

---
 cpp/src/utilities/stream_pool.cpp | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 704e6f67bc6..c0b22db16fa 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -190,19 +190,16 @@ std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, uin
   std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
     CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0));
   });
-  for (auto& strm : streams) {
-    CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0));
-  }
   return streams;
 }
 
 void join_streams(host_span<rmm::cuda_stream_view> streams, rmm::cuda_stream_view stream)
 {
   cudaEvent_t event = event_for_thread();
-  for (auto& strm : streams) {
+  std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
     CUDF_CUDA_TRY(cudaEventRecord(event, strm));
     CUDF_CUDA_TRY(cudaStreamWaitEvent(stream, event, 0));
-  }
+  });
 }
 
 }  // namespace cudf::detail

From 2a2006767c43d32fbf52e21cefff5bcb80aa44cd Mon Sep 17 00:00:00 2001
From: Ed Seidl
Date: Tue, 5 Sep 2023 21:39:46 -0700
Subject: [PATCH 63/77] fix typo

Co-authored-by: Mark Harris <783069+harrism@users.noreply.github.com>
---
 cpp/src/utilities/stream_pool.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index c0b22db16fa..3ef6a34ec69 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -140,7 +140,7 @@ cuda_stream_pool* create_global_cuda_stream_pool()
 }
 
 /**
- * @brief RAII struct to wrap a cuda event and ensure it's proper destruction.
+ * @brief RAII struct to wrap a cuda event and ensure its proper destruction.
  */
 struct cuda_event {
   cuda_event()

From 4ed8082cab7ac5fc15ec33559c39b0217c6effe9 Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 09:34:11 -0700
Subject: [PATCH 64/77] add alias stream_id_t

---
 cpp/src/utilities/stream_pool.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 3ef6a34ec69..74c02d3c917 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -35,6 +35,9 @@ std::size_t constexpr STREAM_POOL_SIZE = 32;
 
 class cuda_stream_pool {
  public:
+  // matching type used in rmm::cuda_stream_pool::get_stream(stream_id)
+  using stream_id_t = std::size_t;
+
   virtual ~cuda_stream_pool() = default;
 
   /**
@@ -55,7 +58,7 @@ class cuda_stream_pool {
   * @param stream_id Unique identifier for the desired stream
   * @return Requested stream view.
   */
-  virtual rmm::cuda_stream_view get_stream(std::size_t stream_id) = 0;
+  virtual rmm::cuda_stream_view get_stream(stream_id_t stream_id) = 0;
 
   /**
    * @brief Get a set of `cuda_stream_view` objects from the pool.
@@ -90,7 +93,7 @@ class rmm_cuda_stream_pool : public cuda_stream_pool {
  public:
   rmm_cuda_stream_pool() : _pool{STREAM_POOL_SIZE} {}
   rmm::cuda_stream_view get_stream() override { return _pool.get_stream(); }
-  rmm::cuda_stream_view get_stream(std::size_t stream_id) override
+  rmm::cuda_stream_view get_stream(stream_id_t stream_id) override
   {
     return _pool.get_stream(stream_id);
   }

From 272d883268e8c3520e3ccf8c31576bd30fb1ce3c Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 09:37:03 -0700
Subject: [PATCH 65/77] change stream count to size_t

---
 cpp/src/utilities/stream_pool.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 74c02d3c917..f93eecadd1a 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -72,7 +72,7 @@ class cuda_stream_pool {
   * @param count The number of stream views to return.
   * @return Vector containing `count` stream views.
   */
-  virtual std::vector<rmm::cuda_stream_view> get_streams(uint32_t count) = 0;
+  virtual std::vector<rmm::cuda_stream_view> get_streams(std::size_t count) = 0;
 
   /**
    * @brief Get the number of stream objects in the pool.
@@ -98,7 +98,7 @@ class rmm_cuda_stream_pool : public cuda_stream_pool {
     return _pool.get_stream(stream_id);
   }
 
-  std::vector<rmm::cuda_stream_view> get_streams(uint32_t count) override
+  std::vector<rmm::cuda_stream_view> get_streams(std::size_t count) override
   {
     if (count > STREAM_POOL_SIZE) {
       CUDF_LOG_WARN("get_streams called with count ({}) > pool size ({})", count, STREAM_POOL_SIZE);
@@ -124,7 +124,7 @@ class debug_cuda_stream_pool : public cuda_stream_pool {
     return cudf::get_default_stream();
   }
 
-  std::vector<rmm::cuda_stream_view> get_streams(uint32_t count) override
+  std::vector<rmm::cuda_stream_view> get_streams(std::size_t count) override
   {
     return std::vector<rmm::cuda_stream_view>(count, cudf::get_default_stream());
   }

From 8b4f15d01a9262b0871f19c2bf6042896088d8fe Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 09:58:26 -0700
Subject: [PATCH 66/77] wrap cudaEventDestroy in a debug-only assert

---
 cpp/src/utilities/stream_pool.cpp | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index f93eecadd1a..f9a60e0ea02 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -33,6 +33,25 @@ namespace {
 // defaults to 16.
 std::size_t constexpr STREAM_POOL_SIZE = 32;
 
+// FIXME: "borrowed" from rmm...remove when this stream pool is moved there
+#ifdef NDEBUG
+#define CUDF_ASSERT_CUDA_SUCCESS(_call) \
+  do { \
+    (_call); \
+  } while (0);
+#else
+#define CUDF_ASSERT_CUDA_SUCCESS(_call) \
+  do { \
+    cudaError_t const status__ = (_call); \
+    if (status__ != cudaSuccess) { \
+      std::cerr << "CUDA Error detected. " << cudaGetErrorName(status__) << " " \
" << cudaGetErrorName(status__) << " " \ + << cudaGetErrorString(status__) << std::endl; \ + } \ + /* NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay) */ \ + assert(status__ == cudaSuccess); \ + } while (0) +#endif + class cuda_stream_pool { public: // matching type used in rmm::cuda_stream_pool::get_stream(stream_id) @@ -160,7 +179,7 @@ struct cuda_event { private: struct deleter { using pointer = cudaEvent_t; - auto operator()(cudaEvent_t e) { cudaEventDestroy(e); } + auto operator()(cudaEvent_t e) { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e)); } }; std::unique_ptr e_; }; From 457c6fba05b5bafc35d3c266d5d74b3ebda96076 Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 6 Sep 2023 12:03:36 -0700 Subject: [PATCH 67/77] modify event_for_thread() to take into account multiple devices --- cpp/src/utilities/stream_pool.cpp | 41 +++++++++++++++++++------------ 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index f9a60e0ea02..2ae15926ef9 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -161,27 +161,32 @@ cuda_stream_pool* create_global_cuda_stream_pool() return new rmm_cuda_stream_pool(); } +// FIXME: these will be available in rmm soon +inline int get_num_cuda_devices() +{ + rmm::cuda_device_id::value_type num_dev{}; + CUDF_CUDA_TRY(cudaGetDeviceCount(&num_dev)); + return num_dev; +} + +rmm::cuda_device_id get_current_cuda_device() +{ + int device_id; + CUDF_CUDA_TRY(cudaGetDevice(&device_id)); + return rmm::cuda_device_id{device_id}; +} + /** * @brief RAII struct to wrap a cuda event and ensure its proper destruction. */ struct cuda_event { - cuda_event() - : e_{[]() { - cudaEvent_t event; - CUDF_CUDA_TRY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - return event; - }()} - { - } + cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); } + virtual ~cuda_event() { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_)); } - operator cudaEvent_t() { return e_.get(); } + operator cudaEvent_t() { return e_; } private: - struct deleter { - using pointer = cudaEvent_t; - auto operator()(cudaEvent_t e) { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e)); } - }; - std::unique_ptr e_; + cudaEvent_t e_; }; /** @@ -189,8 +194,12 @@ struct cuda_event { */ cudaEvent_t event_for_thread() { - thread_local cuda_event thread_event; - return thread_event; + thread_local std::vector> thread_events(get_num_cuda_devices()); + auto const device_id = get_current_cuda_device(); + if (not thread_events[device_id.value()]) { + thread_events[device_id.value()] = std::make_unique(); + } + return *thread_events[device_id.value()]; } /** From 04dbba59aa17ba829d0073292f4feb148b7565ca Mon Sep 17 00:00:00 2001 From: seidl Date: Wed, 6 Sep 2023 16:16:12 -0700 Subject: [PATCH 68/77] per-device stream pools --- cpp/src/utilities/stream_pool.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp index 2ae15926ef9..b9c230693dc 100644 --- a/cpp/src/utilities/stream_pool.cpp +++ b/cpp/src/utilities/stream_pool.cpp @@ -203,12 +203,21 @@ cudaEvent_t event_for_thread() } /** - * Returns a reference to the global stream ppol. + * Returns a reference to the global stream pool for the current device. 
From 457c6fba05b5bafc35d3c266d5d74b3ebda96076 Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 12:03:36 -0700
Subject: [PATCH 67/77] modify event_for_thread() to take into account multiple
 devices

---
 cpp/src/utilities/stream_pool.cpp | 41 +++++++++++++++++++------------
 1 file changed, 25 insertions(+), 16 deletions(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index f9a60e0ea02..2ae15926ef9 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -161,27 +161,32 @@ cuda_stream_pool* create_global_cuda_stream_pool()
   return new rmm_cuda_stream_pool();
 }
 
+// FIXME: these will be available in rmm soon
+inline int get_num_cuda_devices()
+{
+  rmm::cuda_device_id::value_type num_dev{};
+  CUDF_CUDA_TRY(cudaGetDeviceCount(&num_dev));
+  return num_dev;
+}
+
+rmm::cuda_device_id get_current_cuda_device()
+{
+  int device_id;
+  CUDF_CUDA_TRY(cudaGetDevice(&device_id));
+  return rmm::cuda_device_id{device_id};
+}
+
 /**
  * @brief RAII struct to wrap a cuda event and ensure its proper destruction.
  */
 struct cuda_event {
-  cuda_event()
-    : e_{[]() {
-        cudaEvent_t event;
-        CUDF_CUDA_TRY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-        return event;
-      }()}
-  {
-  }
+  cuda_event() { CUDF_CUDA_TRY(cudaEventCreateWithFlags(&e_, cudaEventDisableTiming)); }
+  virtual ~cuda_event() { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e_)); }
 
-  operator cudaEvent_t() { return e_.get(); }
+  operator cudaEvent_t() { return e_; }
 
  private:
-  struct deleter {
-    using pointer = cudaEvent_t;
-    auto operator()(cudaEvent_t e) { CUDF_ASSERT_CUDA_SUCCESS(cudaEventDestroy(e)); }
-  };
-  std::unique_ptr<cudaEvent_t, deleter> e_;
+  cudaEvent_t e_;
 };
 
 /**
@@ -189,8 +194,12 @@ struct cuda_event {
  */
 cudaEvent_t event_for_thread()
 {
-  thread_local cuda_event thread_event;
-  return thread_event;
+  thread_local std::vector<std::unique_ptr<cuda_event>> thread_events(get_num_cuda_devices());
+  auto const device_id = get_current_cuda_device();
+  if (not thread_events[device_id.value()]) {
+    thread_events[device_id.value()] = std::make_unique<cuda_event>();
+  }
+  return *thread_events[device_id.value()];
 }
 
 /**

From 04dbba59aa17ba829d0073292f4feb148b7565ca Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 16:16:12 -0700
Subject: [PATCH 68/77] per-device stream pools

---
 cpp/src/utilities/stream_pool.cpp | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 2ae15926ef9..b9c230693dc 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -203,12 +203,21 @@ cudaEvent_t event_for_thread()
 }
 
 /**
- * Returns a reference to the global stream ppol.
+ * Returns a reference to the global stream pool for the current device.
 */
 cuda_stream_pool& global_cuda_stream_pool()
 {
-  static cuda_stream_pool* pool = create_global_cuda_stream_pool();
-  return *pool;
+  // using bare pointers here to deliberately allow them to leak. otherwise we wind up with
+  // seg faults trying to destroy stream objects after the context has shut down.
+  static std::vector<cuda_stream_pool*> pools(get_num_cuda_devices());
+  static std::mutex mutex;
+  auto const device_id = get_current_cuda_device();
+
+  std::lock_guard lock(mutex);
+  if (pools[device_id.value()] == nullptr) {
+    pools[device_id.value()] = create_global_cuda_stream_pool();
+  }
+  return *pools[device_id.value()];
 }
 
 }  // anonymous namespace
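The "bare pointers ... deliberately allow them to leak" comment in patch 68 is the crux of the change: owning objects with static storage duration would run CUDA teardown during static destruction, after the driver context may already be gone. In isolation, the per-device, lazily-constructed, intentionally-leaked singleton looks roughly like this (pool, device_count and current_device are stand-ins for the real types and CUDA queries, not the cuDF code):

#include <mutex>
#include <vector>

struct pool { int device; };        // stand-in for the per-device resource

int device_count() { return 2; }    // stand-in for cudaGetDeviceCount
int current_device() { return 0; }  // stand-in for cudaGetDevice

pool& pool_for_current_device()
{
  // Raw pointers leak on purpose: the objects are never destroyed, so nothing
  // touches the CUDA driver during static teardown at program exit.
  static std::vector<pool*> pools(device_count());
  static std::mutex m;

  std::lock_guard<std::mutex> lock(m);
  auto const dev = current_device();
  if (pools[dev] == nullptr) { pools[dev] = new pool{dev}; }
  return *pools[dev];
}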
From ded5900e8c3fed5d5df6ed3de6ba55b289d37c06 Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 16:18:48 -0700
Subject: [PATCH 69/77] use size_t for fork_stream too

---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 2 +-
 cpp/src/utilities/stream_pool.cpp                 | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index ed4d161bae0..e55011842d0 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -47,7 +47,7 @@ namespace cudf::detail {
 * @param count The number of stream views to return.
 * @return Vector containing `count` stream views.
 */
-std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, uint32_t count);
+std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, std::size_t count);
 
 /**
  * @brief Synchronize a stream to an event on a set of streams.
diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index b9c230693dc..87aa36a9aac 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -222,7 +222,7 @@ cuda_stream_pool& global_cuda_stream_pool()
 
 }  // anonymous namespace
 
-std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, uint32_t count)
+std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, std::size_t count)
 {
   auto streams = global_cuda_stream_pool().get_streams(count);

From 1c5ae323f5b31d4adf284a88dfa290078c05179d Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 16:20:15 -0700
Subject: [PATCH 70/77] rename stream_id_t to stream_id_type

---
 cpp/src/utilities/stream_pool.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 87aa36a9aac..25359fb8567 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -55,7 +55,7 @@ class cuda_stream_pool {
  public:
   // matching type used in rmm::cuda_stream_pool::get_stream(stream_id)
-  using stream_id_t = std::size_t;
+  using stream_id_type = std::size_t;
 
   virtual ~cuda_stream_pool() = default;
 
@@ -77,7 +77,7 @@ class cuda_stream_pool {
   * @param stream_id Unique identifier for the desired stream
   * @return Requested stream view.
   */
-  virtual rmm::cuda_stream_view get_stream(stream_id_t stream_id) = 0;
+  virtual rmm::cuda_stream_view get_stream(stream_id_type stream_id) = 0;
 
   /**
    * @brief Get a set of `cuda_stream_view` objects from the pool.
@@ -112,7 +112,7 @@ class rmm_cuda_stream_pool : public cuda_stream_pool {
  public:
   rmm_cuda_stream_pool() : _pool{STREAM_POOL_SIZE} {}
   rmm::cuda_stream_view get_stream() override { return _pool.get_stream(); }
-  rmm::cuda_stream_view get_stream(stream_id_t stream_id) override
+  rmm::cuda_stream_view get_stream(stream_id_type stream_id) override
   {
     return _pool.get_stream(stream_id);
   }
@@ -138,7 +138,7 @@ class rmm_cuda_stream_pool : public cuda_stream_pool {
 class debug_cuda_stream_pool : public cuda_stream_pool {
  public:
   rmm::cuda_stream_view get_stream() override { return cudf::get_default_stream(); }
-  rmm::cuda_stream_view get_stream(stream_id_t stream_id) override
+  rmm::cuda_stream_view get_stream(stream_id_type stream_id) override
   {
     return cudf::get_default_stream();
   }

From 5ffc75a2ffe7edccd1c716bac4ff041608878260 Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 16:30:27 -0700
Subject: [PATCH 71/77] add some more docstrings

---
 cpp/src/utilities/stream_pool.cpp | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 25359fb8567..c61f9905b99 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -191,6 +191,10 @@ struct cuda_event {
 
 /**
  * @brief Returns a cudaEvent_t for the current thread.
+ *
+ * The returned event is valid for the current device.
+ *
+ * @return A cudaEvent_t unique to the current thread and valid on the current device.
  */
 cudaEvent_t event_for_thread()
 {
@@ -203,7 +207,8 @@ cudaEvent_t event_for_thread()
 }
 
 /**
- * Returns a reference to the global stream pool for the current device.
+ * @brief Returns a reference to the global stream pool for the current device.
+ * @return `cuda_stream_pool` valid on the current device.
  */
 cuda_stream_pool& global_cuda_stream_pool()

From 0a7035f5f62073d8eb12dbcbd3ec735c1a0343c9 Mon Sep 17 00:00:00 2001
From: Ed Seidl
Date: Wed, 6 Sep 2023 16:32:21 -0700
Subject: [PATCH 72/77] implement suggestion from review

Co-authored-by: Bradley Dice
---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index e55011842d0..40407f9b731 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -52,8 +52,8 @@ std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, std
 /**
  * @brief Synchronize a stream to an event on a set of streams.
  *
- * @param streams Vector of streams to synchronize to.
- * @param stream Stream to synchronize, usually the default stream.
+ * @param streams Streams to wait on.
+ * @param stream Joined stream that synchronizes with the waited-on streams.
 */
 void join_streams(host_span<rmm::cuda_stream_view> streams, rmm::cuda_stream_view stream);

From 9fb0958c1712e5bac13b8af4e5a5979d7e11f308 Mon Sep 17 00:00:00 2001
From: seidl
Date: Wed, 6 Sep 2023 16:35:57 -0700
Subject: [PATCH 73/77] more docstring cleanup

---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index 40407f9b731..b0a9bf3eeaf 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -37,14 +37,14 @@ namespace cudf::detail {
 * auto const num_streams = 2;
 * // do work on stream
 * // allocate streams and wait for an event on stream before executing on any of streams
- * auto streams = cudf::detail::fork_streams(stream, num_streams);
+ * auto streams = cudf::detail::fork_stream(stream, num_streams);
 * // do work on streams[0] and streams[1]
 * // wait for event on streams before continuing to do work on stream
 * cudf::detail::join_streams(streams, stream);
 * @endcode
 *
- * @param stream Stream to synchronize the returned streams to, usually the default stream.
- * @param count The number of stream views to return.
+ * @param stream Stream that the returned streams will wait on.
+ * @param count The number of `cuda_stream_view` objects to return.
  * @return Vector containing `count` stream views.
 */
 std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, std::size_t count);
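Before the rename back to fork_streams lands in patch 74 below, it is worth seeing the caller-side pattern the series converges on, condensed from the reader_impl.cpp hunks (a sketch, not a complete function; the join call sits outside the hunks shown and is assumed):

// One pooled stream per decoder kernel type flagged in kernel_mask.
int const nkernels = std::bitset<32>(kernel_mask).count();
auto streams       = cudf::detail::fork_streams(_stream, nkernels);

// ... launch one decode kernel per set bit, each on its own streams[i] ...

cudf::detail::join_streams(streams, _stream);  // _stream waits for all decoders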
From 2629a7927998fff4f340fc6be4f79548ed226423 Mon Sep 17 00:00:00 2001
From: seidl
Date: Thu, 7 Sep 2023 10:59:54 -0700
Subject: [PATCH 74/77] fork_streams is back on the menu

---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 2 +-
 cpp/src/io/parquet/reader_impl.cpp                | 2 +-
 cpp/src/io/text/multibyte_split.cu                | 2 +-
 cpp/src/utilities/stream_pool.cpp                 | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index b0a9bf3eeaf..be3bfafbe93 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -47,7 +47,7 @@ namespace cudf::detail {
 * @param count The number of `cuda_stream_view` objects to return.
 * @return Vector containing `count` stream views.
 */
-std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, std::size_t count);
+std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, std::size_t count);
 
 /**
  * @brief Synchronize a stream to an event on a set of streams.
diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp
index 6a8b71b78ce..8b0a0bd4eb0 100644
--- a/cpp/src/io/parquet/reader_impl.cpp
+++ b/cpp/src/io/parquet/reader_impl.cpp
@@ -165,7 +165,7 @@ void reader::impl::decode_page_data(size_t skip_rows, size_t num_rows)
 
   // get the number of streams we need from the pool and tell them to wait on the H2D copies
   int const nkernels = std::bitset<32>(kernel_mask).count();
-  auto streams       = cudf::detail::fork_stream(_stream, nkernels);
+  auto streams       = cudf::detail::fork_streams(_stream, nkernels);
 
   auto const level_type_size = _file_itm_data.level_type_size;
 
diff --git a/cpp/src/io/text/multibyte_split.cu b/cpp/src/io/text/multibyte_split.cu
index 4a248f094c7..772bcad8ada 100644
--- a/cpp/src/io/text/multibyte_split.cu
+++ b/cpp/src/io/text/multibyte_split.cu
@@ -378,7 +378,7 @@ std::unique_ptr<cudf::column> multibyte_split(cudf::io::text::data_chunk_source
   output_builder<byte_offset> row_offset_storage(ITEMS_PER_CHUNK, max_growth, stream);
   output_builder<char> char_storage(ITEMS_PER_CHUNK, max_growth, stream);
 
-  auto streams = cudf::detail::fork_stream(stream, concurrency);
+  auto streams = cudf::detail::fork_streams(stream, concurrency);
 
   cudaEvent_t last_launch_event;
   CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event));
 
diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index c61f9905b99..4e2db0b835e 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -227,7 +227,7 @@ cuda_stream_pool& global_cuda_stream_pool()
 
 }  // anonymous namespace
 
-std::vector<rmm::cuda_stream_view> fork_stream(rmm::cuda_stream_view stream, std::size_t count)
+std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, std::size_t count)
 {
   auto streams = global_cuda_stream_pool().get_streams(count);

From 87626b665bed91acbbe86ac011efa6da6b1ebfca Mon Sep 17 00:00:00 2001
From: seidl
Date: Thu, 7 Sep 2023 12:41:17 -0700
Subject: [PATCH 75/77] add nodiscard to fork_streams

---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index be3bfafbe93..d62ed69fcae 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -47,7 +47,8 @@ namespace cudf::detail {
 * @param count The number of `cuda_stream_view` objects to return.
 * @return Vector containing `count` stream views.
 */
-std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, std::size_t count);
+[[nodiscard]] std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream,
+                                                              std::size_t count);
 
 /**
  * @brief Synchronize a stream to an event on a set of streams.
From 560c03ca49837ed9dacac9a75a1f5bd29e456bc8 Mon Sep 17 00:00:00 2001
From: Ed Seidl
Date: Fri, 8 Sep 2023 11:14:31 -0700
Subject: [PATCH 76/77] Apply suggestions from code review

Co-authored-by: Yunsong Wang
---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 3 +++
 cpp/src/utilities/stream_pool.cpp                 | 6 ++++++
 2 files changed, 9 insertions(+)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index d62ed69fcae..d338ce7f7ea 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -20,6 +20,9 @@
 
 #include 
 
+#include 
+#include 
+
 namespace cudf::detail {
 
 /**
diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index 4e2db0b835e..e2097c60dc4 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -21,6 +21,12 @@
 
 #include 
 
+#include 
+#include 
+#include 
+#include 
+#include 
+
 namespace cudf::detail {
 
 namespace {

From afd71f900386469ac75b2dc35d009573d0e34354 Mon Sep 17 00:00:00 2001
From: seidl
Date: Tue, 12 Sep 2023 13:48:57 -0700
Subject: [PATCH 77/77] add consts per review comments

---
 cpp/include/cudf/detail/utilities/stream_pool.hpp | 2 +-
 cpp/src/utilities/stream_pool.cpp                 | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/cpp/include/cudf/detail/utilities/stream_pool.hpp b/cpp/include/cudf/detail/utilities/stream_pool.hpp
index d338ce7f7ea..95384a9d73e 100644
--- a/cpp/include/cudf/detail/utilities/stream_pool.hpp
+++ b/cpp/include/cudf/detail/utilities/stream_pool.hpp
@@ -59,6 +59,6 @@ namespace cudf::detail {
 * @param streams Streams to wait on.
 * @param stream Joined stream that synchronizes with the waited-on streams.
 */
-void join_streams(host_span<rmm::cuda_stream_view> streams, rmm::cuda_stream_view stream);
+void join_streams(host_span<rmm::cuda_stream_view const> streams, rmm::cuda_stream_view stream);
 
 }  // namespace cudf::detail
diff --git a/cpp/src/utilities/stream_pool.cpp b/cpp/src/utilities/stream_pool.cpp
index e2097c60dc4..b3b20889ef8 100644
--- a/cpp/src/utilities/stream_pool.cpp
+++ b/cpp/src/utilities/stream_pool.cpp
@@ -235,8 +235,8 @@ cuda_stream_pool& global_cuda_stream_pool()
 
 std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, std::size_t count)
 {
-  auto streams = global_cuda_stream_pool().get_streams(count);
-  cudaEvent_t event = event_for_thread();
+  auto const streams = global_cuda_stream_pool().get_streams(count);
+  auto const event = event_for_thread();
   CUDF_CUDA_TRY(cudaEventRecord(event, stream));
   std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
     CUDF_CUDA_TRY(cudaStreamWaitEvent(strm, event, 0));
@@ -244,9 +244,9 @@ std::vector<rmm::cuda_stream_view> fork_streams(rmm::cuda_stream_view stream, st
 }
 
-void join_streams(host_span<rmm::cuda_stream_view> streams, rmm::cuda_stream_view stream)
+void join_streams(host_span<rmm::cuda_stream_view const> streams, rmm::cuda_stream_view stream)
 {
-  cudaEvent_t event = event_for_thread();
+  auto const event = event_for_thread();
   std::for_each(streams.begin(), streams.end(), [&](auto& strm) {
     CUDF_CUDA_TRY(cudaEventRecord(event, strm));
     CUDF_CUDA_TRY(cudaStreamWaitEvent(stream, event, 0));