diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index e069f4c49c573..68401e36950bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -68,8 +68,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
    * param sets regex as splitting on gaps (true) or matching tokens (false)
    * @group param
    */
-  val gaps: BooleanParam = new BooleanParam(this, "gaps",
-    "Set regex to match gaps or tokens", Some(false))
+  val gaps: BooleanParam = new BooleanParam(
+    this, "gaps", "Set regex to match gaps or tokens", Some(false))
 
   /** @group setParam */
   def setGaps(value: Boolean): this.type = set(gaps, value)
@@ -81,21 +81,20 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
    * param sets regex pattern used by tokenizer
    * @group param
    */
-  val pattern: Param[scala.util.matching.Regex] = new Param(this, "pattern",
-    "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+".r))
+  val pattern: Param[String] = new Param(
+    this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))
 
   /** @group setParam */
-  def setPattern(value: String): this.type = set(pattern, value.r)
+  def setPattern(value: String): this.type = set(pattern, value)
 
   /** @group getParam */
-  def getPattern: String = get(pattern).toString
+  def getPattern: String = get(pattern)
 
   override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>
-
-    val re = paramMap(pattern)
-    val tokens = if(paramMap(gaps)) re.split(str).toList else (re.findAllIn(str)).toList
-
-    tokens.filter(_.length >= paramMap(minTokenLength))
+    val re = paramMap(pattern).r
+    val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
+    val minLength = paramMap(minTokenLength)
+    tokens.filter(_.length >= minLength)
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index 41e0aba55745c..3806f650025b2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -17,9 +17,7 @@
 
 package org.apache.spark.ml.feature;
 
-import java.util.Arrays;
-import java.util.List;
-
+import com.google.common.collect.Lists;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -48,26 +46,26 @@ public void tearDown() {
   }
 
   @Test
-  public void RegexTokenizer() {
+  public void regexTokenizer() {
     RegexTokenizer myRegExTokenizer = new RegexTokenizer()
       .setInputCol("rawText")
       .setOutputCol("tokens")
       .setPattern("\\s")
       .setGaps(true)
-      .setMinTokenLength(0);
-
-    List<String> t = Arrays.asList(
-      "{\"rawText\": \"Test of tok.\", \"wantedTokens\": [\"Test\", \"of\", \"tok.\"]}",
-      "{\"rawText\": \"Te,st.  punct\", \"wantedTokens\": [\"Te,st.\", \"\", \"punct\"]}");
+      .setMinTokenLength(3);
 
-    JavaRDD<String> myRdd = jsc.parallelize(t);
-    DataFrame dataset = jsql.jsonRDD(myRdd);
+    JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Lists.newArrayList(
+      new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
+      new TokenizerTestData("Te,st.  punct", new String[] {"Te,st.", "punct"})
punct", new String[] {"Te,st.", "punct"}) + )); + DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); Row[] pairs = myRegExTokenizer.transform(dataset) - .select("tokens","wantedTokens") + .select("tokens", "wantedTokens") .collect(); - Assert.assertEquals(pairs[0].get(0), pairs[0].get(1)); - Assert.assertEquals(pairs[1].get(0), pairs[1].get(1)); + for (Row r : pairs) { + Assert.assertEquals(r.get(0), r.get(1)); + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index ffd18de2f7d02..bf862b912d326 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -17,15 +17,21 @@ package org.apache.spark.ml.feature +import scala.beans.BeanInfo + import org.scalatest.FunSuite -import org.apache.spark.SparkException import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row, SQLContext} -case class TextData(rawText: String, wantedTokens: Seq[String]) +@BeanInfo +case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) { + /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */ + def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq) +} -class TokenizerSuite extends FunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ @transient var sqlContext: SQLContext = _ @@ -35,66 +41,45 @@ class TokenizerSuite extends FunSuite with MLlibTestSparkContext { } test("RegexTokenizer") { - val myRegExTokenizer = new RegexTokenizer() + val tokenizer = new RegexTokenizer() .setInputCol("rawText") .setOutputCol("tokens") - var dataset = sqlContext.createDataFrame( - sc.parallelize(Seq( - TextData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), - TextData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) - ))) - testRegexTokenizer(myRegExTokenizer, dataset) + val dataset0 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct")) + )) + testRegexTokenizer(tokenizer, dataset0) - dataset = sqlContext.createDataFrame( - sc.parallelize(Seq( - TextData("Test for tokenization.", Seq("Test", "for", "tokenization")), - TextData("Te,st. punct", Seq("punct")) - ))) - myRegExTokenizer.asInstanceOf[RegexTokenizer] - .setMinTokenLength(3) - testRegexTokenizer(myRegExTokenizer, dataset) + val dataset1 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")), + TokenizerTestData("Te,st. punct", Seq("punct")) + )) - myRegExTokenizer.asInstanceOf[RegexTokenizer] + tokenizer.setMinTokenLength(3) + testRegexTokenizer(tokenizer, dataset1) + + tokenizer .setPattern("\\s") .setGaps(true) .setMinTokenLength(0) - dataset = sqlContext.createDataFrame( - sc.parallelize(Seq( - TextData("Test for tokenization.", Seq("Test", "for", "tokenization.")), - TextData("Te,st. 
punct", Seq("Te,st.", "", "punct")) - ))) - testRegexTokenizer(myRegExTokenizer, dataset) - } - - test("Tokenizer") { - val oldTokenizer = new Tokenizer() - .setInputCol("rawText") - .setOutputCol("tokens") - var dataset = sqlContext.createDataFrame( - sc.parallelize(Seq( - TextData("Test for tokenization.", Seq("test", "for", "tokenization.")), - TextData("Te,st. punct", Seq("te,st.", "", "punct")) - ))) - testTokenizer(oldTokenizer, dataset) + val dataset2 = sqlContext.createDataFrame(Seq( + TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct")) + )) + testRegexTokenizer(tokenizer, dataset2) } +} - def testTokenizer(t: Tokenizer, dataset: DataFrame): Unit = { - t.transform(dataset) - .select("tokens", "wantedTokens") - .collect().foreach { - case Row(tokens: Seq[Any], wantedTokens: Seq[Any]) => - assert(tokens === wantedTokens) - } - } +object RegexTokenizerSuite extends FunSuite { def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { t.transform(dataset) .select("tokens", "wantedTokens") - .collect().foreach { - case Row(tokens: Seq[Any], wantedTokens: Seq[Any]) => + .collect() + .foreach { + case Row(tokens, wantedTokens) => assert(tokens === wantedTokens) - } + } } - }