Skip to content

Commit

Permalink
Merge branch 'mengxr-SPARK-5566'
Browse files Browse the repository at this point in the history
  • Loading branch information
sagacifyTestUser committed Mar 25, 2015
2 parents 5f09434 + cb07021 commit 716d257
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 74 deletions.
21 changes: 10 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param sets regex as splitting on gaps (true) or matching tokens (false)
* @group param
*/
val gaps: BooleanParam = new BooleanParam(this, "gaps",
"Set regex to match gaps or tokens", Some(false))
val gaps: BooleanParam = new BooleanParam(
this, "gaps", "Set regex to match gaps or tokens", Some(false))

/** @group setParam */
def setGaps(value: Boolean): this.type = set(gaps, value)
Expand All @@ -81,21 +81,20 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param sets regex pattern used by tokenizer
* @group param
*/
val pattern: Param[scala.util.matching.Regex] = new Param(this, "pattern",
"regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+".r))
val pattern: Param[String] = new Param(
this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))

/** @group setParam */
def setPattern(value: String): this.type = set(pattern, value.r)
def setPattern(value: String): this.type = set(pattern, value)

/** @group getParam */
def getPattern: String = get(pattern).toString
def getPattern: String = get(pattern)

override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>

val re = paramMap(pattern)
val tokens = if(paramMap(gaps)) re.split(str).toList else (re.findAllIn(str)).toList

tokens.filter(_.length >= paramMap(minTokenLength))
val re = paramMap(pattern).r
val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
val minLength = paramMap(minTokenLength)
tokens.filter(_.length >= minLength)
}

override protected def validateInputType(inputType: DataType): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;

import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
Expand Down Expand Up @@ -48,26 +46,26 @@ public void tearDown() {
}

@Test
public void RegexTokenizer() {
public void regexTokenizer() {
RegexTokenizer myRegExTokenizer = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")
.setPattern("\\s")
.setGaps(true)
.setMinTokenLength(0);

List<String> t = Arrays.asList(
"{\"rawText\": \"Test of tok.\", \"wantedTokens\": [\"Test\", \"of\", \"tok.\"]}",
"{\"rawText\": \"Te,st. punct\", \"wantedTokens\": [\"Te,st.\", \"\", \"punct\"]}");
.setMinTokenLength(3);

JavaRDD<String> myRdd = jsc.parallelize(t);
DataFrame dataset = jsql.jsonRDD(myRdd);
JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Lists.newArrayList(
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
));
DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);

Row[] pairs = myRegExTokenizer.transform(dataset)
.select("tokens","wantedTokens")
.select("tokens", "wantedTokens")
.collect();

Assert.assertEquals(pairs[0].get(0), pairs[0].get(1));
Assert.assertEquals(pairs[1].get(0), pairs[1].get(1));
for (Row r : pairs) {
Assert.assertEquals(r.get(0), r.get(1));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@

package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.scalatest.FunSuite

import org.apache.spark.SparkException
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

case class TextData(rawText: String, wantedTokens: Seq[String])
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) {
/** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */
def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq)
}

class TokenizerSuite extends FunSuite with MLlibTestSparkContext {
class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._

@transient var sqlContext: SQLContext = _

Expand All @@ -35,66 +41,45 @@ class TokenizerSuite extends FunSuite with MLlibTestSparkContext {
}

test("RegexTokenizer") {
val myRegExTokenizer = new RegexTokenizer()
val tokenizer = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")

var dataset = sqlContext.createDataFrame(
sc.parallelize(Seq(
TextData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
TextData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
)))
testRegexTokenizer(myRegExTokenizer, dataset)
val dataset0 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
))
testRegexTokenizer(tokenizer, dataset0)

dataset = sqlContext.createDataFrame(
sc.parallelize(Seq(
TextData("Test for tokenization.", Seq("Test", "for", "tokenization")),
TextData("Te,st. punct", Seq("punct"))
)))
myRegExTokenizer.asInstanceOf[RegexTokenizer]
.setMinTokenLength(3)
testRegexTokenizer(myRegExTokenizer, dataset)
val dataset1 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Seq("punct"))
))

myRegExTokenizer.asInstanceOf[RegexTokenizer]
tokenizer.setMinTokenLength(3)
testRegexTokenizer(tokenizer, dataset1)

tokenizer
.setPattern("\\s")
.setGaps(true)
.setMinTokenLength(0)
dataset = sqlContext.createDataFrame(
sc.parallelize(Seq(
TextData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
TextData("Te,st. punct", Seq("Te,st.", "", "punct"))
)))
testRegexTokenizer(myRegExTokenizer, dataset)
}

test("Tokenizer") {
val oldTokenizer = new Tokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")
var dataset = sqlContext.createDataFrame(
sc.parallelize(Seq(
TextData("Test for tokenization.", Seq("test", "for", "tokenization.")),
TextData("Te,st. punct", Seq("te,st.", "", "punct"))
)))
testTokenizer(oldTokenizer, dataset)
val dataset2 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct"))
))
testRegexTokenizer(tokenizer, dataset2)
}
}

def testTokenizer(t: Tokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
.select("tokens", "wantedTokens")
.collect().foreach {
case Row(tokens: Seq[Any], wantedTokens: Seq[Any]) =>
assert(tokens === wantedTokens)
}
}
object RegexTokenizerSuite extends FunSuite {

def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
.select("tokens", "wantedTokens")
.collect().foreach {
case Row(tokens: Seq[Any], wantedTokens: Seq[Any]) =>
.collect()
.foreach {
case Row(tokens, wantedTokens) =>
assert(tokens === wantedTokens)
}
}
}

}

0 comments on commit 716d257

Please sign in to comment.