Skip to content

Commit

Permalink
Expose stream parameter in public strings replace APIs (#14261)
Browse files Browse the repository at this point in the history
Add stream parameter to public APIs:

- `cudf::strings::replace()` (x2)
- `cudf::strings::replace_slice()`
- `cudf::strings::replace_re()` (x2)
- `cudf::strings::replace_with_backrefs()`

Also cleaned up some of the doxygen comments and added stream-tests.

Reference #13744

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Mike Wilson (https://github.com/hyperbolic2346)
  - Nghia Truong (https://github.com/ttnghia)

URL: #14261
  • Loading branch information
davidwendt authored Oct 12, 2023
1 parent 737b759 commit fa4e8ab
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 39 deletions.
42 changes: 24 additions & 18 deletions cpp/include/cudf/strings/replace.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,19 +54,21 @@ namespace strings {
*
* @throw cudf::logic_error if target is an empty string.
*
* @param strings Strings column for this operation.
* @param target String to search for within each string.
* @param repl Replacement string if target is found.
* @param input Strings column for this operation
* @param target String to search for within each string
* @param repl Replacement string if target is found
* @param maxrepl Maximum times to replace if target appears multiple times in the input string.
* Default of -1 specifies replace all occurrences of target in each string.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return New strings column
*/
std::unique_ptr<column> replace(
strings_column_view const& strings,
strings_column_view const& input,
string_scalar const& target,
string_scalar const& repl,
int32_t maxrepl = -1,
cudf::size_type maxrepl = -1,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -92,21 +94,23 @@ std::unique_ptr<column> replace(
*
* @throw cudf::logic_error if start is greater than stop.
*
* @param strings Strings column for this operation.
* @param input Strings column for this operation.
* @param repl Replacement string for specified positions found.
* Default is empty string.
* @param start Start position where repl will be added.
* Default is 0, first character position.
* @param stop End position (exclusive) to use for replacement.
* Default of -1 specifies the end of each string.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return New strings column
*/
std::unique_ptr<column> replace_slice(
strings_column_view const& strings,
strings_column_view const& input,
string_scalar const& repl = string_scalar(""),
size_type start = 0,
size_type stop = -1,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -141,16 +145,18 @@ std::unique_ptr<column> replace_slice(
* if repls is a single string.
* @throw cudf::logic_error if targets or repls contain null entries.
*
* @param strings Strings column for this operation.
* @param targets Strings to search for in each string.
* @param repls Corresponding replacement strings for target strings.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
* @param input Strings column for this operation
* @param targets Strings to search for in each string
* @param repls Corresponding replacement strings for target strings
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return New strings column
*/
std::unique_ptr<column> replace(
strings_column_view const& strings,
strings_column_view const& input,
strings_column_view const& targets,
strings_column_view const& repls,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
Expand Down
28 changes: 17 additions & 11 deletions cpp/include/cudf/strings/replace_re.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,22 @@ struct regex_program;
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation
* @param input Strings instance for this operation
* @param prog Regex program instance
* @param replacement The string used to replace the matched sequence in each string.
* Default is an empty string.
* @param max_replace_count The maximum number of times to replace the matched pattern
* within each string. Default replaces every substring that is matched.
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return New strings column
*/
std::unique_ptr<column> replace_re(
strings_column_view const& strings,
strings_column_view const& input,
regex_program const& prog,
string_scalar const& replacement = string_scalar(""),
std::optional<size_type> max_replace_count = std::nullopt,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -67,18 +69,20 @@ std::unique_ptr<column> replace_re(
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param patterns The regular expression patterns to search within each string.
* @param replacements The strings used for replacement.
* @param flags Regex flags for interpreting special characters in the patterns.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New strings column.
* @param input Strings instance for this operation
* @param patterns The regular expression patterns to search within each string
* @param replacements The strings used for replacement
* @param flags Regex flags for interpreting special characters in the patterns
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return New strings column
*/
std::unique_ptr<column> replace_re(
strings_column_view const& strings,
strings_column_view const& input,
std::vector<std::string> const& patterns,
strings_column_view const& replacements,
regex_flags const flags = regex_flags::DEFAULT,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -92,16 +96,18 @@ std::unique_ptr<column> replace_re(
* @throw cudf::logic_error if capture index values in `replacement` are not in range 0-99, and also
* if the index exceeds the group count specified in the pattern
*
* @param strings Strings instance for this operation
* @param input Strings instance for this operation
* @param prog Regex program instance
* @param replacement The replacement template for creating the output string
* @param stream CUDA stream used for device memory operations and kernel launches
* @param mr Device memory resource used to allocate the returned column's device memory
* @return New strings column
*/
std::unique_ptr<column> replace_with_backrefs(
strings_column_view const& strings,
strings_column_view const& input,
regex_program const& prog,
std::string_view replacement,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace strings
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/strings/replace/backref_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,11 @@ std::unique_ptr<column> replace_with_backrefs(strings_column_view const& input,
std::unique_ptr<column> replace_with_backrefs(strings_column_view const& strings,
regex_program const& prog,
std::string_view replacement,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace_with_backrefs(strings, prog, replacement, cudf::get_default_stream(), mr);
return detail::replace_with_backrefs(strings, prog, replacement, stream, mr);
}

} // namespace strings
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/strings/replace/multi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,11 @@ std::unique_ptr<column> replace(strings_column_view const& input,
std::unique_ptr<column> replace(strings_column_view const& strings,
strings_column_view const& targets,
strings_column_view const& repls,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace(strings, targets, repls, cudf::get_default_stream(), mr);
return detail::replace(strings, targets, repls, stream, mr);
}

} // namespace strings
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/strings/replace/multi_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,11 @@ std::unique_ptr<column> replace_re(strings_column_view const& strings,
std::vector<std::string> const& patterns,
strings_column_view const& replacements,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace_re(strings, patterns, replacements, flags, cudf::get_default_stream(), mr);
return detail::replace_re(strings, patterns, replacements, flags, stream, mr);
}

} // namespace strings
Expand Down
8 changes: 5 additions & 3 deletions cpp/src/strings/replace/replace.cu
Original file line number Diff line number Diff line change
Expand Up @@ -751,21 +751,23 @@ std::unique_ptr<column> replace_nulls(strings_column_view const& strings,
std::unique_ptr<column> replace(strings_column_view const& strings,
string_scalar const& target,
string_scalar const& repl,
int32_t maxrepl,
cudf::size_type maxrepl,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace(strings, target, repl, maxrepl, cudf::get_default_stream(), mr);
return detail::replace(strings, target, repl, maxrepl, stream, mr);
}

std::unique_ptr<column> replace_slice(strings_column_view const& strings,
string_scalar const& repl,
size_type start,
size_type stop,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace_slice(strings, repl, start, stop, cudf::get_default_stream(), mr);
return detail::replace_slice(strings, repl, start, stop, stream, mr);
}

} // namespace strings
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/strings/replace/replace_re.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ std::unique_ptr<column> replace_re(strings_column_view const& strings,
regex_program const& prog,
string_scalar const& replacement,
std::optional<size_type> max_replace_count,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::replace_re(
strings, prog, replacement, max_replace_count, cudf::get_default_stream(), mr);
return detail::replace_re(strings, prog, replacement, max_replace_count, stream, mr);
}

} // namespace strings
Expand Down
10 changes: 8 additions & 2 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,14 @@ ConfigureTest(STREAM_REPLACE_TEST streams/replace_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_SEARCH_TEST streams/search_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_DICTIONARY_TEST streams/dictionary_test.cpp STREAM_MODE testing)
ConfigureTest(
STREAM_STRINGS_TEST streams/strings/case_test.cpp streams/strings/find_test.cpp
streams/strings/split_test.cpp streams/strings/strings_tests.cpp STREAM_MODE testing
STREAM_STRINGS_TEST
streams/strings/case_test.cpp
streams/strings/find_test.cpp
streams/strings/replace_test.cpp
streams/strings/split_test.cpp
streams/strings/strings_tests.cpp
STREAM_MODE
testing
)
ConfigureTest(STREAM_SORTING_TEST streams/sorting_test.cpp STREAM_MODE testing)
ConfigureTest(STREAM_TEXT_TEST streams/text/ngrams_test.cpp STREAM_MODE testing)
Expand Down
80 changes: 80 additions & 0 deletions cpp/tests/streams/strings/replace_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/default_stream.hpp>

#include <cudf/strings/regex/regex_program.hpp>
#include <cudf/strings/replace.hpp>
#include <cudf/strings/replace_re.hpp>

#include <string>

// Fixture shared by the strings replace stream tests below; it adds no state or
// behavior of its own on top of cudf::test::BaseFixture.
class StringsReplaceTest : public cudf::test::BaseFixture {};

// Calls each public replace API (scalar replace, multi-target replace,
// replace_slice and both replace_re overloads) with the test stream passed
// explicitly. Return values are discarded: nothing is asserted on the output,
// the test only exercises the stream-taking overloads.
TEST_F(StringsReplaceTest, Replace)
{
  auto const stream = cudf::test::get_default_stream();

  cudf::test::strings_column_wrapper input({"Héllo", "thesé", "tést strings", ""});
  cudf::strings_column_view const view(input);

  // Scalars are built on the same stream that the replace calls use.
  cudf::string_scalar const target("é", true, stream);
  cudf::string_scalar const repl(" ", true, stream);

  cudf::strings::replace(view, target, repl, -1, stream);
  cudf::strings::replace(view, view, view, stream);
  cudf::strings::replace_slice(view, repl, 1, 2, stream);

  // Single-pattern regex replace, limited to one replacement per string.
  std::string const pattern("[a-z]");
  auto const program = cudf::strings::regex_program::create(pattern);
  cudf::strings::replace_re(view, *program, repl, 1, stream);

  // Multi-pattern regex replace: one replacement string per pattern.
  cudf::test::strings_column_wrapper repls({"1", "a", " "});
  auto const repls_view = cudf::strings_column_view(repls);
  cudf::strings::replace_re(
    view, {pattern, pattern, pattern}, repls_view, cudf::strings::regex_flags::DEFAULT, stream);
}

// Calls both replace_re overloads (single regex_program and pattern list) with
// the test stream passed explicitly. Results are discarded; no output is checked.
TEST_F(StringsReplaceTest, ReplaceRegex)
{
  auto const stream = cudf::test::get_default_stream();

  cudf::test::strings_column_wrapper input({"Héllo", "thesé", "tést strings", ""});
  cudf::strings_column_view const view(input);

  cudf::string_scalar const repl(" ", true, stream);
  std::string const pattern("[a-z]");

  // Single-pattern overload, limited to one replacement per string.
  auto const program = cudf::strings::regex_program::create(pattern);
  cudf::strings::replace_re(view, *program, repl, 1, stream);

  // Multi-pattern overload: one replacement string per pattern.
  cudf::test::strings_column_wrapper repls({"1", "a", " "});
  auto const repls_view = cudf::strings_column_view(repls);
  cudf::strings::replace_re(
    view, {pattern, pattern, pattern}, repls_view, cudf::strings::regex_flags::DEFAULT, stream);
}

// Calls replace_with_backrefs with the test stream passed explicitly, using a
// two-group pattern and a template that references both groups. The result is
// discarded; no output is checked.
TEST_F(StringsReplaceTest, ReplaceRegexBackref)
{
  auto const stream = cudf::test::get_default_stream();

  cudf::test::strings_column_wrapper input({"Héllo thesé", "tést strings"});
  cudf::strings_column_view const view(input);

  std::string const pattern("(\\w) (\\w)");
  std::string const repl_template("\\2-\\1");

  auto const program = cudf::strings::regex_program::create(pattern);
  cudf::strings::replace_with_backrefs(view, *program, repl_template, stream);
}

0 comments on commit fa4e8ab

Please sign in to comment.