Skip to content

Commit

Permalink
add config to switch
Browse files Browse the repository at this point in the history
Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven committed Apr 8, 2024
1 parent 23b8dbf commit 5682864
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
7 changes: 4 additions & 3 deletions integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
else:
pytestmark = pytest.mark.regexp

_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True }
_regexp_conf = { 'spark.rapids.sql.regexp.enabled': True,
'spark.rapids.sql.rLikeRegexRewrite.enabled': False}

def mk_str_gen(pattern):
    """Build a StringGen for `pattern` with two extra special inputs.

    The generator is created from `pattern`, then given the empty string as a
    special case and '.{0,10}' as a special pattern.
    """
    base_gen = StringGen(pattern)
    with_empty_case = base_gen.with_special_case('')
    return with_empty_case.with_special_pattern('.{0,10}')
Expand Down Expand Up @@ -445,8 +446,8 @@ def test_regexp_like():
conf=_regexp_conf)

@pytest.mark.skipif(is_before_spark_320(), reason='regexp_like is synonym for RLike starting in Spark 3.2.0')
def test_regexp_rlike_startswith():
gen = mk_str_gen('[abcd]{3,4}[0-9]{0,2}')
def test_regexp_rlike_contains():
gen = mk_str_gen('[abcd]{3,6}')
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, gen).selectExpr(
'a',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,12 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.booleanConf
.createWithDefault(true)

// Internal boolean switch, on by default, for the RLike rewrite optimization.
// It is read through RapidsConf.isRlikeRegexRewriteEnabled; setting it to false
// (as the regexp integration tests do) keeps RLike on the regular-expression path.
val ENABLE_RLIKE_REGEX_REWRITE = conf("spark.rapids.sql.rLikeRegexRewrite.enabled")
.doc("Enable the optimization to rewrite rlike regex to contains in some cases.")
.internal()
.booleanConf
.createWithDefault(true)

// FILE FORMATS
val MULTITHREAD_READ_NUM_THREADS = conf("spark.rapids.sql.multiThreadedRead.numThreads")
.doc("The maximum number of threads on each executor to use for reading small " +
Expand Down Expand Up @@ -2570,6 +2576,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val isTieredProjectEnabled: Boolean = get(ENABLE_TIERED_PROJECT)

// Value of ENABLE_RLIKE_REGEX_REWRITE ("spark.rapids.sql.rLikeRegexRewrite.enabled").
lazy val isRlikeRegexRewriteEnabled: Boolean = get(ENABLE_RLIKE_REGEX_REWRITE)

lazy val isExpandPreprojectEnabled: Boolean = get(ENABLE_EXPAND_PREPROJECT)

lazy val multiThreadReadNumThreads: Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,39 @@ class GpuRLikeMeta(

private var pattern: Option[String] = None

// Characters that make a pattern unsuitable for rewriting into a plain literal match.
// '(' , ')' and the backslash are included on purpose: removeBrackets only strips a single
// outermost pair of parentheses, so any parenthesis that survives it (e.g. "(a)(b)" or
// "a(b)c") still denotes a regex group, and a backslash introduces a regex escape. Treating
// such patterns as literals would change the match result, so they must stay on the regex path.
val specialChars: Seq[Char] =
  Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}', '(', ')', '\\')

// Trailing group that optimizeSimplePattern looks for and strips: a capture group matching any
// run of characters other than line terminators.
// NOTE(review): presumably the transpiled form of a trailing "(.*)" - confirm against the
// regex transpiler.
val startWithSuffix: String = "([^\n\r\u0085\u2028\u2029]*)"

// Candidate pattern families that are not handled yet:
// val endWithPatterns = Seq(".*$", "(.*)$")
// val startWithPatterns = Seq("^.*", "^(.*)")
// val allMatchPatterns = Seq(".*", "(.*)")

/**
 * Reports whether `pattern` is free of every character listed in `specialChars`.
 *
 * @param pattern the candidate string to inspect
 * @return true when no character of `pattern` occurs in `specialChars`; an empty string
 *         therefore counts as simple
 */
def isSimplePattern(pattern: String): Boolean = {
  val hasSpecialChar = pattern.exists(ch => specialChars.contains(ch))
  !hasSpecialChar
}

/**
 * Drops one enclosing pair of parentheses from `pattern`, if present.
 *
 * Only the first and last characters are examined; the text in between is not inspected.
 *
 * @param pattern the string to unwrap
 * @return `pattern` without its first and last character when it starts with "(" and ends
 *         with ")", otherwise `pattern` unchanged
 */
def removeBrackets(pattern: String): String = {
  val isWrapped = pattern.startsWith("(") && pattern.endsWith(")")
  if (isWrapped) pattern.drop(1).dropRight(1) else pattern
}

/**
 * Chooses the GPU expression used to evaluate an RLike with the given pattern.
 *
 * When the rewrite is enabled in the conf and `pattern` ends with `startWithSuffix`, the
 * suffix is stripped and one enclosing pair of parentheses is removed; if what remains passes
 * `isSimplePattern`, the result is a `GpuContains` on that literal. In every other case the
 * result is a `GpuRLike` built from the unmodified pattern.
 *
 * @param rhs     right-hand (pattern) expression, forwarded to `GpuRLike`
 * @param lhs     left-hand (input string) expression
 * @param pattern the pattern string to analyse
 * @return `GpuContains(lhs, literal)` when the rewrite applies, otherwise
 *         `GpuRLike(lhs, rhs, pattern)`
 */
def optimizeSimplePattern(rhs: Expression, lhs: Expression, pattern: String): GpuExpression = {
  // Literal to search for, present only when the rewrite is enabled and applicable.
  val containsLiteral: Option[String] =
    if (conf.isRlikeRegexRewriteEnabled && pattern.endsWith(startWithSuffix)) {
      Some(removeBrackets(pattern.stripSuffix(startWithSuffix))).filter(isSimplePattern)
    } else {
      None
    }
  containsLiteral match {
    case Some(literal) => GpuContains(lhs, GpuLiteral(literal, StringType))
    case None => GpuRLike(lhs, rhs, pattern)
  }
}

override def tagExprForGpu(): Unit = {
GpuRegExpUtils.tagForRegExpEnabled(this)
expr.right match {
Expand All @@ -1086,49 +1119,8 @@ class GpuRLikeMeta(
throw new IllegalStateException("Expression has not been tagged with cuDF regex pattern"))
// if the pattern can be rewritten as a plain substring match, we can use
// GpuContains instead of a regular expression to get better performance
GpuRLike.optimizeSimplePattern(rhs, lhs, patternStr)
}
}

object GpuRLike {

// // '(' and ')' are allowed
val specialChars = Seq('^', '$', '.', '|', '*', '?', '+', '[', ']', '{', '}')

val startWithSuffix = "([^\n\r\u0085\u2028\u2029]*)"

// val endWithPatterns = Seq(".*$", "(.*)$")
// val startWithPatterns = Seq("^.*", "^(.*)")
// val allMatchPatterns = Seq(".*", "(.*)")

def isSimplePattern(pattern: String): Boolean = {
pattern.forall(c => !specialChars.contains(c))
}

def removeBrackets(pattern: String): String = {
if (pattern.startsWith("(") && pattern.endsWith(")")) {
pattern.substring(1, pattern.length - 1)
} else {
pattern
optimizeSimplePattern(rhs, lhs, patternStr)
}
}

def optimizeSimplePattern(rhs: Expression, lhs: Expression, pattern: String): GpuExpression = {
// check if the pattern is end with startWithSuffix
if (pattern.endsWith(startWithSuffix)) {
val startWithPattern = removeBrackets(pattern.stripSuffix(startWithSuffix))
if (isSimplePattern(startWithPattern)) {
// println(s"Optimizing $pattern to GpuStartsWith $startWithPattern")
GpuStartsWith(lhs, GpuLiteral(startWithPattern, StringType))
} else {
// println(s"Optimizing $pattern to gpurlike")
GpuRLike(lhs, rhs, pattern)
}
} else {
// println(s"Optimizing $pattern to gpurlike")
GpuRLike(lhs, rhs, pattern)
}
}
}

case class GpuRLike(left: Expression, right: Expression, pattern: String)
Expand Down

0 comments on commit 5682864

Please sign in to comment.