[SPARK-44349][R] Add math functions to SparkR
### What changes were proposed in this pull request?
Add the following math functions to `SparkR` (a brief usage sketch follows the list):

- ~~e~~ (discarded along with `pi`; both are math constants)
- ~~pi~~ (conflicts with R's built-in `pi` constant)
- std
- ln
- negative
- positive
- ~~pow~~ (discarded along with `power`)
- ~~power~~ (conflicts with R's built-in `power` function)
- width_bucket
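
A minimal usage sketch of the new functions (the session setup and example data are illustrative assumptions, not part of this patch):

```r
library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(x = c(1.0, 2.5, 10.0)))

# Element-wise helpers added in this patch
head(select(df,
  ln(df$x),                                         # natural logarithm, alias for log()
  positive(df$x),                                   # unary plus, returns the value unchanged
  negative(df$x),                                   # alias for negate()
  width_bucket(df$x, lit(0.0), lit(10.0), lit(5L))  # bucket number in a 5-bucket histogram over [0, 10)
))

# std() is an aggregate alias for stddev(), so it is used in a global aggregation here
head(select(df, std(df$x)))
```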

### Why are the changes needed?
For feature parity with the other Spark language APIs.

### Does this PR introduce _any_ user-facing change?
Yes, it adds new user-facing SparkR functions.

### How was this patch tested?
Updated unit tests.

Closes #41914 from zhengruifeng/math_r.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jul 12, 2023
1 parent 482497c commit 8fd6d4d
Showing 4 changed files with 100 additions and 0 deletions.
5 changes: 5 additions & 0 deletions R/pkg/NAMESPACE
@@ -343,6 +343,7 @@ exportMethods("%<=>%",
"levenshtein",
"like",
"lit",
"ln",
"locate",
"log",
"log10",
@@ -374,6 +375,7 @@ exportMethods("%<=>%",
"n_distinct",
"nanvl",
"negate",
"negative",
"next_day",
"not",
"nth_value",
@@ -387,6 +389,7 @@
"pmod",
"posexplode",
"posexplode_outer",
"positive",
"product",
"quarter",
"radians",
@@ -429,6 +432,7 @@ exportMethods("%<=>%",
"soundex",
"spark_partition_id",
"split_string",
"std",
"stddev",
"stddev_pop",
"stddev_samp",
@@ -467,6 +471,7 @@ exportMethods("%<=>%",
"vector_to_array",
"weekofyear",
"when",
"width_bucket",
"window",
"withField",
"xxhash64",
72 changes: 72 additions & 0 deletions R/pkg/R/functions.R
@@ -1389,6 +1389,18 @@ setMethod("log",
column(jc)
})

#' @details
#' \code{ln}: Alias for \code{log}.
#'
#' @rdname column_math_functions
#' @aliases ln ln,Column-method
#' @note ln since 3.5.0
setMethod("ln",
signature(x = "Column"),
function(x) {
log(x)
})

#' @details
#' \code{log10}: Computes the logarithm of the given value in base 10.
#'
@@ -1677,6 +1689,54 @@ setMethod("negate",
column(jc)
})

#' @details
#' \code{negative}: Alias for \code{negate}.
#'
#' @rdname column_nonaggregate_functions
#' @aliases negative negative,Column-method
#' @note negative since 3.5.0
setMethod("negative",
signature(x = "Column"),
function(x) {
negate(x)
})

#' @details
#' \code{positive}: Unary plus, i.e. return the expression.
#'
#' @rdname column_nonaggregate_functions
#' @aliases positive positive,Column-method
#' @note positive since 3.5.0
setMethod("positive",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "positive", x@jc)
column(jc)
})

#' @details
#' \code{width_bucket}: Returns the bucket number into which the value of this expression would
#' fall after being evaluated. Note that input arguments must follow conditions listed below;
#' otherwise, the method will return null.
#'
#' @param v value to compute a bucket number in the histogram.
#' @param min minimum value of the histogram
#' @param max maximum value of the histogram
#' @param numBucket the number of buckets
#'
#' @rdname column_math_functions
#' @aliases width_bucket width_bucket,Column-method
#' @note width_bucket since 3.5.0
setMethod("width_bucket",
signature(v = "Column", min = "Column", max = "Column", numBucket = "Column"),
function(v, min, max, numBucket) {
jc <- callJStatic(
"org.apache.spark.sql.functions", "width_bucket",
v@jc, min@jc, max@jc, numBucket@jc
)
column(jc)
})
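
For context, here is a plain-R sketch of the equi-width bucket rule that `width_bucket` is expected to follow (assumed standard SQL semantics, not the Spark implementation; the null-returning input checks mentioned above are omitted, and `width_bucket_local` is a hypothetical helper for illustration):

```r
# Assumed semantics: numBucket equal-width buckets over [min, max);
# values below min map to 0, values at or above max map to numBucket + 1.
width_bucket_local <- function(v, min, max, numBucket) {
  if (v < min) return(0L)
  if (v >= max) return(numBucket + 1L)
  as.integer(floor((v - min) * numBucket / (max - min)) + 1L)
}

width_bucket_local(2.5, 2.0, 3.0, 10L)  # 6, same inputs as the updated test below
```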

#' @details
#' \code{octet_length}: Calculates the byte length for the specified string column.
#'
@@ -2031,6 +2091,18 @@ setMethod("stddev",
column(jc)
})

#' @details
#' \code{std}: Alias for \code{stddev}.
#'
#' @rdname column_aggregate_functions
#' @aliases std std,Column-method
#' @note std since 3.5.0
setMethod("std",
signature(x = "Column"),
function(x) {
stddev(x)
})

#' @details
#' \code{stddev_pop}: Returns the population standard deviation of the expression in a group.
#'
21 changes: 21 additions & 0 deletions R/pkg/R/generics.R
@@ -1152,6 +1152,10 @@ setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") })
#' @name NULL
setGeneric("lit", function(x) { standardGeneric("lit") })

#' @rdname column_math_functions
#' @name NULL
setGeneric("ln", function(x) { standardGeneric("ln") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") })
@@ -1248,6 +1252,10 @@ setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") })
#' @name NULL
setGeneric("negate", function(x) { standardGeneric("negate") })

#' @rdname column_nonaggregate_functions
#' @name NULL
setGeneric("negative", function(x) { standardGeneric("negative") })

#' @rdname not
setGeneric("not", function(x) { standardGeneric("not") })

@@ -1284,6 +1292,10 @@ setGeneric("percentile_approx",
#' @name NULL
setGeneric("pmod", function(y, x) { standardGeneric("pmod") })

#' @rdname column_nonaggregate_functions
#' @name NULL
setGeneric("positive", function(x) { standardGeneric("positive") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("posexplode", function(x) { standardGeneric("posexplode") })
@@ -1441,6 +1453,10 @@ setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spar
#' @name NULL
setGeneric("stddev", function(x) { standardGeneric("stddev") })

#' @rdname column_aggregate_functions
#' @name NULL
setGeneric("std", function(x) { standardGeneric("std") })

#' @rdname column_aggregate_functions
#' @name NULL
setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") })
@@ -1565,6 +1581,11 @@ setGeneric("vector_to_array", function(x, ...) { standardGeneric("vector_to_arra
#' @name NULL
setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") })

#' @rdname column_math_functions
#' @name NULL
setGeneric("width_bucket",
function(v, min, max, numBucket) { standardGeneric("width_bucket") })

#' @rdname column_datetime_functions
#' @name NULL
setGeneric("window", function(x, ...) { standardGeneric("window") })
2 changes: 2 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1472,6 +1472,8 @@ test_that("column functions", {
c29 <- acosh(c1) + asinh(c1) + atanh(c1)
c30 <- product(c1) + product(c1 * 0.5)
c31 <- sec(c1) + csc(c1) + cot(c1)
c32 <- ln(c1) + positive(c2) + negative(c3)
c33 <- width_bucket(lit(2.5), lit(2.0), lit(3.0), lit(10L))

# Test if base::is.nan() is exposed
expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE))