Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added method comments explaining what MX_PRIMITIVE_TYPE is
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushghai committed Jan 8, 2019
1 parent 155321f commit cbda37e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,15 @@

package org.apache.mxnet

/**
* This defines the basic primitives we can use in Scala for mathematical
* computations in NDArrays. This gives us a flexibility to expand to
* more supported primitives in the future. Currently Float and Double
* are supported.
*/
object MX_PRIMITIVES {

/**
* This defines the basic primitives we can use in Scala for mathematical
* computations in NDArrays.This gives us a flexibility to expand to
* more supported primitives in the future. Currently Float and Double
* * are supported.
* are supported. The functions which accept MX_PRIMITIVE_TYPE as input can also accept
* plain old Float and Double data as inputs because of the underlying
* implicit conversion between primitives to MX_PRIMITIVE_TYPE.
*/
trait MX_PRIMITIVE_TYPE extends Ordered[MX_PRIMITIVE_TYPE]{

Expand All @@ -47,7 +43,7 @@ object MX_PRIMITIVES {
implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering

/**
* Mimics Float in Scala.
* Wrapper over Float in Scala.
* @param data
*/
class MX_FLOAT(val data: Float) extends MX_PRIMITIVE_TYPE {
Expand All @@ -68,7 +64,7 @@ object MX_PRIMITIVES {
implicit def IntToMX_Float(d: Int): MX_FLOAT = new MX_FLOAT(d.toFloat)

/**
* Mimics Double in Scala.
* Wrapper over Double in Scala.
* @param data
*/
class MX_Double(val data: Double) extends MX_PRIMITIVE_TYPE {
Expand Down
87 changes: 86 additions & 1 deletion scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,29 @@ object NDArray extends NDArrayBase {
}


// Perform power operator
/**
* Perform power operation on NDArray. Returns result as NDArray
* @param lhs
* @param rhs
*/
def power(lhs: NDArray, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs))
}

/**
* Perform scalar power operation on NDArray. Returns result as NDArray
* @param lhs NDArray on which to perform the operation on.
* @param rhs The scalar input. Can be of type Float/Double
*/
def power(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs))
}

/**
* Perform scalar power operation on NDArray. Returns result as NDArray
* @param lhs The scalar input. Can be of type Float/Double
* @param rhs NDArray on which to perform the operation on.
*/
def power(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs))
}
Expand All @@ -297,10 +311,20 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs))
}

/**
* Perform the max operation on NDArray. Returns the result as NDArray.
* @param lhs NDArray on which to perform the operation on.
* @param rhs The scalar input. Can be of type Float/Double
*/
def maximum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
}

/**
* Perform the max operation on NDArray. Returns the result as NDArray.
* @param lhs The scalar input. Can be of type Float/Double
* @param rhs NDArray on which to perform the operation on.
*/
def maximum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
}
Expand All @@ -310,10 +334,20 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs))
}

/**
* Perform the min operation on NDArray. Returns the result as NDArray.
* @param lhs NDArray on which to perform the operation on.
* @param rhs The scalar input. Can be of type Float/Double
*/
def minimum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
}

/**
* Perform the min operation on NDArray. Returns the result as NDArray.
* @param lhs The scalar input. Can be of type Float/Double
* @param rhs NDArray on which to perform the operation on.
*/
def minimum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
}
Expand All @@ -327,6 +361,14 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs))
}

/**
* Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
* For each element in input arrays, return 1(true) if corresponding elements are same,
* otherwise return 0(false).
*
* @param lhs NDArray
* @param rhs The scalar input. Can be of type Float/Double
*/
def equal(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs))
}
Expand All @@ -341,6 +383,14 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs))
}

/**
* Returns the result of element-wise **not equal to** (!=) comparison operation
* with broadcasting.
* For each element in input arrays, return 1(true) if corresponding elements are different,
* otherwise return 0(false).
* @param lhs NDArray
* @param rhs The scalar input. Can be of type Float/Double
*/
def notEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs))
}
Expand All @@ -355,6 +405,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs))
}

/**
* Returns the result of element-wise **greater than** (>) comparison operation
* with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
* otherwise return 0(false).
*
* @param lhs NDArray
* @param rhs The scalar input. Can be of type Float/Double
*/
def greater(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs))
}
Expand All @@ -369,6 +428,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs))
}

/**
* Returns the result of element-wise **greater than or equal to** (>=) comparison
* operation with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are greater than equal to
* rhs, otherwise return 0(false).
*
* @param lhs NDArray
* @param rhs The scalar input. Can be of type Float/Double
*/
def greaterEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs))
}
Expand All @@ -383,6 +451,14 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs))
}

/**
* Returns the result of element-wise **lesser than** (<) comparison operation
* with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are less than rhs,
* otherwise return 0(false).
* @param lhs NDArray
* @param rhs The scalar input. Can be of type Float/Double
*/
def lesser(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs))
}
Expand All @@ -397,6 +473,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs))
}

/**
* Returns the result of element-wise **lesser than or equal to** (<=) comparison
* operation with broadcasting.
* For each element in input arrays, return 1(true) if lhs elements are
* lesser than equal to rhs, otherwise return 0(false).
*
* @param lhs NDArray
* @param rhs The scalar input. Can be of type Float/Double
*/
def lesserEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs))
}
Expand Down

0 comments on commit cbda37e

Please sign in to comment.