Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes stack context for json lines format that recovers from invalid JSON lines #14309

Merged
143 changes: 123 additions & 20 deletions cpp/src/io/fst/logical_stack.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/scatter.h>

#include <cub/cub.cuh>
Expand All @@ -48,6 +49,14 @@ enum class stack_op_type : int8_t {
RESET = 3 ///< Operation popping all items currently on the stack
};

/**
* @brief Describes the kind of stack operations supported by the logical stack.
*
* Used as a compile-time switch (template parameter) of `sparse_stack_op_to_top_of_stack`: the
* value is compared against `WITH_RESET_SUPPORT` to decide whether the stack-level scan has to
* account for `reset` operations or whether a plain inclusive scan over all operations is used.
*/
enum class stack_op_support : bool {
NO_RESET_SUPPORT = false, ///< A stack that only supports push(x) and pop() operations
WITH_RESET_SUPPORT = true ///< A stack that supports push(x), pop(), and reset() operations
};

namespace detail {

/**
Expand Down Expand Up @@ -130,6 +139,37 @@ struct StackSymbolToStackOp {
StackSymbolToStackOpTypeT symbol_to_stack_op_type;
};

/**
 * @brief Unary functor that flags stack symbols representing a `reset` operation.
 *
 * Evaluates to `1` for every symbol that the wrapped classifier reports as
 * `stack_op_type::RESET` and to `0` for every other symbol. Summing these flags yields a
 * counter that changes at each `reset`, i.e., at the beginning of each new segment.
 *
 * @tparam StackSymbolToStackOpTypeT Callable that returns the `stack_op_type` of a stack symbol
 */
template <typename StackSymbolToStackOpTypeT>
struct NewlineToResetStackSegmentOp {
  /// Callable that classifies a given stack symbol as a stack operation type
  StackSymbolToStackOpTypeT symbol_to_stack_op_type;

  /**
   * @brief Returns the segment-start flag for the given stack symbol.
   *
   * @param stack_symbol The stack symbol to classify
   * @return `1` if the symbol is a `reset` operation, `0` otherwise
   */
  template <typename StackSymbolT>
  constexpr CUDF_HOST_DEVICE uint32_t operator()(StackSymbolT const& stack_symbol) const
  {
    stack_op_type const op_kind = symbol_to_stack_op_type(stack_symbol);

    // A reset operation opens a new segment, anything else stays in the current one
    if (op_kind == stack_op_type::RESET) { return 1; }
    return 0;
  }
};

/**
 * @brief Unary functor that narrows a value to `TargetT`, wrapping around for values that exceed
 * the largest value representable by `TargetT`.
 *
 * The value is reduced modulo `max(TargetT) + 1` before it is converted to `TargetT`.
 *
 * NOTE(review): the modulus is computed in (the promoted type of) `T`. This presumably relies on
 * `T` being wide enough to hold `max(TargetT) + 1`; otherwise the modulus could wrap to zero or
 * overflow - confirm against the instantiations used by callers.
 *
 * @tparam TargetT The type that values are wrapped and converted to
 */
template <typename TargetT>
struct ModToTargetTypeOpT {
  /**
   * @brief Wraps the given value into the value range of `TargetT`.
   *
   * @param val The value to be wrapped
   * @return `val` modulo `max(TargetT) + 1`, converted to `TargetT`
   */
  template <typename T>
  constexpr CUDF_HOST_DEVICE TargetT operator()(T const& val) const
  {
    // One past the largest value of TargetT, expressed in the source type
    auto const wrap_modulus =
      static_cast<T>(cuda::std::numeric_limits<TargetT>::max()) + static_cast<T>(1);
    return static_cast<TargetT>(val % wrap_modulus);
  }
};

/**
* @brief Binary reduction operator to compute the absolute stack level from relative stack levels
* (i.e., +1 for a PUSH, -1 for a POP operation).
Expand All @@ -140,9 +180,7 @@ struct AddStackLevelFromStackOp {
// Combines two stack operations: the stack level of the result is derived from the levels of both
// operands, the value of the result is always the value of the right-hand operand.
constexpr CUDF_HOST_DEVICE StackOp<StackLevelT, ValueT> operator()(
StackOp<StackLevelT, ValueT> const& lhs, StackOp<StackLevelT, ValueT> const& rhs) const
{
// NOTE(review): this diff hunk lists two definitions of `new_level`. The first one (next three
// lines) appears to be the removed version: it forces the level to 0 whenever the right-hand
// operand is a RESET and otherwise adds both levels.
StackLevelT new_level = (symbol_to_stack_op_type(rhs.value) == stack_op_type::RESET)
? 0
: (lhs.stack_level + rhs.stack_level);
// The second one appears to be the replacement: a plain sum of both levels without any special
// handling of RESET inside the operator.
StackLevelT new_level = lhs.stack_level + rhs.stack_level;
return StackOp<StackLevelT, ValueT>{new_level, rhs.value};
}

Expand Down Expand Up @@ -230,6 +268,8 @@ struct RemapEmptyStack {
* onto the stack or pop something from the stack and resolves the symbol that is on top of the
* stack.
*
* @tparam SupportResetOperation Whether the logical stack also supports `reset` operations that
* reset the stack to the empty stack
* @tparam StackLevelT Signed integer type that must be sufficient to cover [-max_stack_level,
* max_stack_level] for the given sequence of stack operations. Must be signed as it needs to cover
* the stack level of any arbitrary subsequence of stack operations.
Expand Down Expand Up @@ -261,7 +301,8 @@ struct RemapEmptyStack {
* what-is-on-top-of-the-stack
* @param[in] stream The cuda stream to which to dispatch the work
*/
template <typename StackLevelT,
template <stack_op_support SupportResetOperation,
typename StackLevelT,
typename StackSymbolItT,
typename SymbolPositionT,
typename StackSymbolToStackOpTypeT,
Expand All @@ -281,6 +322,9 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols,
// Type used to hold pairs of (stack_level, value) pairs
using StackOpT = detail::StackOp<StackLevelT, StackSymbolT>;

// Type used to mark *-by-key segments after `reset` operations
using StackSegmentT = uint8_t;

// The unsigned integer type that we use for radix sorting items of type StackOpT
using StackOpUnsignedT = detail::UnsignedStackOpType<StackOpT>;
static_assert(!std::is_void<StackOpUnsignedT>(), "unsupported StackOpT size");
Expand All @@ -292,6 +336,8 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols,
using TransformInputItT =
cub::TransformInputIterator<StackOpT, StackSymbolToStackOpT, StackSymbolItT>;

constexpr bool supports_reset_op = SupportResetOperation == stack_op_support::WITH_RESET_SUPPORT;

auto const num_symbols_in = d_symbol_positions.size();

// Converting a stack symbol that may either push or pop to a stack operation:
Expand Down Expand Up @@ -330,14 +376,44 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols,

// Getting temporary storage requirements for the prefix sum of the stack level after each
// operation
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(
nullptr,
stack_level_scan_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
stream));
if constexpr (supports_reset_op) {
// Iterator that returns `1` for every symbol that corresponds to a `reset` operation
auto reset_segments_it = thrust::make_transform_iterator(
d_symbols,
detail::NewlineToResetStackSegmentOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op});

auto const fake_key_segment_it = static_cast<StackSegmentT*>(nullptr);
std::size_t gen_segments_scan_bytes = 0;
std::size_t scan_by_key_bytes = 0;
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveSum(
nullptr,
gen_segments_scan_bytes,
reset_segments_it,
thrust::make_transform_output_iterator(fake_key_segment_it,
detail::ModToTargetTypeOpT<StackSegmentT>{}),
num_symbols_in,
stream));
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScanByKey(
nullptr,
scan_by_key_bytes,
fake_key_segment_it,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
cub::Equality{},
stream));
stack_level_scan_bytes = std::max(gen_segments_scan_bytes, scan_by_key_bytes);
} else {
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(
nullptr,
stack_level_scan_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
stream));
}

// Getting temporary storage requirements for the stable radix sort (sorting by stack level of the
// operations)
Expand Down Expand Up @@ -401,14 +477,41 @@ void sparse_stack_op_to_top_of_stack(StackSymbolItT d_symbols,
d_kv_operations = cub::DoubleBuffer<StackOpT>{d_kv_ops_current.data(), d_kv_ops_alt.data()};

// Compute prefix sum of the stack level after each operation
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(
temp_storage.data(),
total_temp_storage_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
stream));
if constexpr (supports_reset_op) {
// Iterator that returns `1` for every symbol that corresponds to a `reset` operation
auto reset_segments_it = thrust::make_transform_iterator(
d_symbols,
detail::NewlineToResetStackSegmentOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op});

rmm::device_uvector<StackSegmentT> key_segments{num_symbols_in, stream};
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveSum(
temp_storage.data(),
total_temp_storage_bytes,
reset_segments_it,
thrust::make_transform_output_iterator(key_segments.data(),
detail::ModToTargetTypeOpT<StackSegmentT>{}),
elstehle marked this conversation as resolved.
Show resolved Hide resolved
num_symbols_in,
stream));
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScanByKey(
temp_storage.data(),
total_temp_storage_bytes,
key_segments.data(),
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
cub::Equality{},
stream));
} else {
CUDF_CUDA_TRY(cub::DeviceScan::InclusiveScan(
temp_storage.data(),
total_temp_storage_bytes,
stack_symbols_in,
d_kv_operations.Current(),
detail::AddStackLevelFromStackOp<StackSymbolToStackOpTypeT>{symbol_to_stack_op},
num_symbols_in,
stream));
}

// Stable radix sort, sorting by stack level of the operations
d_kv_operations_unsigned = cub::DoubleBuffer<StackOpUnsignedT>{
Expand Down
31 changes: 22 additions & 9 deletions cpp/src/io/json/nested_json_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -343,27 +343,35 @@ constexpr auto NUM_SYMBOL_GROUPS = static_cast<uint32_t>(dfa_symbol_group_id::NU
std::array<std::string, NUM_SYMBOL_GROUPS - 1> const symbol_groups{
{{"{"}, {"["}, {"}"}, {"]"}, {"\""}, {"\\"}, {"\n"}}};

// Transition table
// Transition table for the default JSON and JSON lines formats.
// Each row is the current DFA state, each column the symbol group of the symbol being read (the
// groups of `symbol_groups` in order, followed by the catch-all OTHER group); the entry is the
// next state. A quote toggles between TT_OOS and TT_STR, a backslash read in TT_STR enters
// TT_ESC, and any symbol read in TT_ESC leads back to TT_STR. The newline column is identical to
// the OTHER column, i.e., newlines get no special treatment here.
std::array<std::array<dfa_states, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const transition_table{
{/* IN_STATE { [ } ] " \ \n OTHER */
/* TT_OOS */ {{TT_OOS, TT_OOS, TT_OOS, TT_OOS, TT_STR, TT_OOS, TT_OOS, TT_OOS}},
/* TT_STR */ {{TT_STR, TT_STR, TT_STR, TT_STR, TT_OOS, TT_ESC, TT_STR, TT_STR}},
/* TT_ESC */ {{TT_STR, TT_STR, TT_STR, TT_STR, TT_STR, TT_STR, TT_STR, TT_STR}}}};

// Translation table (i.e., for each transition, what are the symbols that we output)
// Transition table for the JSON lines format that recovers from invalid JSON lines.
// Each row is the current DFA state, each column the symbol group of the symbol being read; the
// entry is the next state. In contrast to a table that keeps the string state across newlines,
// the newline column leads to TT_OOS from every state (including TT_STR and TT_ESC), so a string
// that is still open at the end of a line does not carry over into the next line.
std::array<std::array<dfa_states, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const
resetting_transition_table{
{/* IN_STATE { [ } ] " \ \n OTHER */
/* TT_OOS */ {{TT_OOS, TT_OOS, TT_OOS, TT_OOS, TT_STR, TT_OOS, TT_OOS, TT_OOS}},
/* TT_STR */ {{TT_STR, TT_STR, TT_STR, TT_STR, TT_OOS, TT_ESC, TT_OOS, TT_STR}},
/* TT_ESC */ {{TT_STR, TT_STR, TT_STR, TT_STR, TT_STR, TT_STR, TT_OOS, TT_STR}}}};

// Translation table for the default JSON and JSON lines formats.
// For each (current state, symbol group) transition, lists the symbols that are output. Only the
// TT_OOS row produces output: opening and closing braces/brackets are emitted as themselves; all
// other transitions emit nothing.
std::array<std::array<std::vector<char>, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const translation_table{
{/* IN_STATE { [ } ] " \ \n OTHER */
/* TT_OOS */ {{{'{'}, {'['}, {'}'}, {']'}, {}, {}, {}, {}}},
/* TT_STR */ {{{}, {}, {}, {}, {}, {}, {}, {}}},
/* TT_ESC */ {{{}, {}, {}, {}, {}, {}, {}, {}}}}};

// Translation table
// Translation table for the JSON lines format that recovers from invalid JSON lines.
// For each (current state, symbol group) transition, lists the symbols that are output. In
// TT_OOS, braces/brackets are emitted as themselves and a newline is emitted as '\n'.
std::array<std::array<std::vector<char>, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const
resetting_translation_table{
{/* IN_STATE { [ } ] " \ \n OTHER */
/* TT_OOS */ {{{'{'}, {'['}, {'}'}, {']'}, {}, {}, {'\n'}, {}}},
// NOTE(review): this diff hunk lists the TT_STR and TT_ESC rows twice. The next two rows appear
// to be the removed version, which emits nothing for a newline read in TT_STR or TT_ESC.
/* TT_STR */ {{{}, {}, {}, {}, {}, {}, {}, {}}},
/* TT_ESC */ {{{}, {}, {}, {}, {}, {}, {}, {}}}}};
// The following two rows appear to be the replacement: a newline read in TT_STR or TT_ESC is
// emitted as '\n' as well.
/* TT_STR */ {{{}, {}, {}, {}, {}, {}, {'\n'}, {}}},
/* TT_ESC */ {{{}, {}, {}, {}, {}, {}, {'\n'}, {}}}}};

// The DFA's starting state
constexpr auto start_state = static_cast<StateT>(TT_OOS);
Expand Down Expand Up @@ -1415,14 +1423,19 @@ void get_stack_context(device_span<SymbolT const> json_in,
constexpr auto max_translation_table_size =
to_stack_op::NUM_SYMBOL_GROUPS * to_stack_op::TT_NUM_STATES;

// Translation table specialized on the choice of whether to reset on newlines outside of strings
// Transition table specialized on the choice of whether to reset on newlines
const auto transition_table = (stack_behavior == stack_behavior_t::ResetOnDelimiter)
? to_stack_op::resetting_transition_table
: to_stack_op::transition_table;

// Translation table specialized on the choice of whether to reset on newlines
const auto translation_table = (stack_behavior == stack_behavior_t::ResetOnDelimiter)
? to_stack_op::resetting_translation_table
: to_stack_op::translation_table;

auto json_to_stack_ops_fst = fst::detail::make_fst(
fst::detail::make_symbol_group_lut(to_stack_op::symbol_groups),
fst::detail::make_transition_table(to_stack_op::transition_table),
fst::detail::make_transition_table(transition_table),
fst::detail::make_translation_table<max_translation_table_size>(translation_table),
stream);

Expand All @@ -1441,7 +1454,7 @@ void get_stack_context(device_span<SymbolT const> json_in,

// Stack operations with indices are converted to top of the stack for each character in the input
if (stack_behavior == stack_behavior_t::ResetOnDelimiter) {
fst::sparse_stack_op_to_top_of_stack<StackLevelT>(
fst::sparse_stack_op_to_top_of_stack<fst::stack_op_support::WITH_RESET_SUPPORT, StackLevelT>(
stack_ops.data(),
device_span<SymbolOffsetT>{stack_op_indices.data(), num_stack_ops},
JSONWithRecoveryToStackOp{},
Expand All @@ -1451,7 +1464,7 @@ void get_stack_context(device_span<SymbolT const> json_in,
json_in.size(),
stream);
} else {
fst::sparse_stack_op_to_top_of_stack<StackLevelT>(
fst::sparse_stack_op_to_top_of_stack<fst::stack_op_support::NO_RESET_SUPPORT, StackLevelT>(
stack_ops.data(),
device_span<SymbolOffsetT>{stack_op_indices.data(), num_stack_ops},
JSONToStackOp{},
Expand Down
17 changes: 9 additions & 8 deletions cpp/tests/io/fst/logical_stack_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,15 @@ TEST_F(LogicalStackTest, GroundTruth)
stream.value()));

// Run algorithm
fst::sparse_stack_op_to_top_of_stack<StackLevelT>(d_stack_ops.data(),
d_stack_op_idx_span,
JSONToStackOp{},
top_of_stack_gpu.device_ptr(),
empty_stack_symbol,
read_symbol,
string_size,
stream.value());
fst::sparse_stack_op_to_top_of_stack<fst::stack_op_support::NO_RESET_SUPPORT, StackLevelT>(
d_stack_ops.data(),
d_stack_op_idx_span,
JSONToStackOp{},
top_of_stack_gpu.device_ptr(),
empty_stack_symbol,
read_symbol,
string_size,
stream.value());

// Async copy results from device to host
top_of_stack_gpu.device_to_host_async(stream_view);
Expand Down
Loading