diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index ab22c26e41b04..fab89c42b1aca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -40,7 +40,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
   override protected def outputDataType: DataType = new ArrayType(StringType, false)
 }
 
-
 /**
  * :: AlphaComponent ::
  * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
@@ -51,10 +50,10 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
  * lowercase = false, minTokenLength = 1
  */
 @AlphaComponent
-class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
+class RegexTokenizer extends Tokenizer {
 
   /**
-   * param for minimum token length
+   * param for minimum token length, default is one to avoid returning empty strings
    * @group param
    */
   val minTokenLength = new IntParam(this, "minLength", "minimum token length", Some(1))
@@ -66,7 +65,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getMinTokenLength: Int = get(minTokenLength)
 
   /**
-   * param sets regex as matching gaps(true) or tokens (false)
+   * param sets regex as splitting on gaps(true) or matching tokens (false)
    * @group param
    */
   val gaps = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens", Some(false))
@@ -78,7 +77,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getGaps: Boolean = get(gaps)
 
   /**
-   * param sets regex used by tokenizer
+   * param sets regex pattern used by tokenizer
    * @group param
    */
   val pattern = new Param(this, "pattern",
@@ -95,7 +94,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
     val re = paramMap(pattern)
     val tokens = if(paramMap(gaps)) str.split(re).toList else (re.r.findAllIn(str)).toList
-    tokens.filter(_.length >= paramMap(minTokenLength))
+    tokens.filter(_.length >= paramMap(minTokenLength)).toSeq
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
new file mode 100644
index 0000000000000..c19ea225940ee
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+case class TextData(rawText: String, wantedTokens: Seq[String])
+class TokenizerSuite extends FunSuite with MLlibTestSparkContext {
+
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("RegexTokenizer") {
+    var myRegExTokenizer = new RegexTokenizer()
+      .setInputCol("rawText")
+      .setOutputCol("tokens")
+
+    dataset = sqlContext.createDataFrame(
+      sc.parallelize(List(
+        TextData("Test for tokenization.", List("Test", "for", "tokenization", ".")),
+        TextData("Te,st. punct", List("Te", ",", "st", ".", "punct"))
+      )))
+    testTokenizer(myRegExTokenizer, dataset)
+
+    dataset = sqlContext.createDataFrame(
+      sc.parallelize(List(
+        TextData("Test for tokenization.", List("Test", "for", "tokenization")),
+        TextData("Te,st. punct", List("punct"))
+      )))
+    myRegExTokenizer.asInstanceOf[RegexTokenizer]
+      .setMinTokenLength(3)
+    testTokenizer(myRegExTokenizer, dataset)
+
+    myRegExTokenizer.asInstanceOf[RegexTokenizer]
+      .setPattern("\\s")
+      .setGaps(true)
+      .setMinTokenLength(0)
+    dataset = sqlContext.createDataFrame(
+      sc.parallelize(List(
+        TextData("Test for tokenization.", List("Test", "for", "tokenization.")),
+        TextData("Te,st. punct", List("Te,st.", "", "punct"))
+      )))
+    testTokenizer(myRegExTokenizer, dataset)
+  }
+
+  test("Tokenizer") {
+    val oldTokenizer = new Tokenizer()
+      .setInputCol("rawText")
+      .setOutputCol("tokens")
+    dataset = sqlContext.createDataFrame(
+      sc.parallelize(List(
+        TextData("Test for tokenization.", List("test", "for", "tokenization.")),
+        TextData("Te,st. punct", List("te,st.", "", "punct"))
+      )))
+    testTokenizer(oldTokenizer, dataset)
+  }
+
+  def testTokenizer(t: Tokenizer, dataset: DataFrame) {
+    t.transform(dataset)
+      .select("tokens", "wantedTokens")
+      .collect().foreach {
+        case Row(tokens: Seq[String], wantedTokens: Seq[String]) =>
+          assert(tokens.length == wantedTokens.length)
+          tokens.zip(wantedTokens).foreach(x => assert(x._1 == x._2))
+        case _ =>
+          println()
+          assert(false)
+      }
+  }
+}
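For reference, a minimal usage sketch of the RegexTokenizer touched by this patch, pieced together from the test suite above. It assumes the suite's setup (a SparkContext sc, a SQLContext sqlContext, and the TextData case class providing a rawText column) and configures the whitespace-splitting behavior of the third test case explicitly, since the default value of the pattern param lies outside the hunks shown here.

import org.apache.spark.ml.feature.RegexTokenizer
import org.apache.spark.sql.Row

// Input DataFrame with a "rawText" column, built the same way as in TokenizerSuite.
val df = sqlContext.createDataFrame(sc.parallelize(List(
  TextData("Test for tokenization.", List("Test", "for", "tokenization.")),
  TextData("Te,st. punct", List("Te,st.", "", "punct")))))

// setInputCol/setOutputCol are inherited from UnaryTransformer[String, Seq[String], Tokenizer],
// so the chained value is statically typed as Tokenizer; as in the suite, cast back to
// RegexTokenizer before calling its own setters.
val tokenizer = new RegexTokenizer()
  .setInputCol("rawText")
  .setOutputCol("tokens")
tokenizer.asInstanceOf[RegexTokenizer]
  .setPattern("\\s")      // the regex describes the separator ...
  .setGaps(true)          // ... so the input is split on gaps instead of matching tokens
  .setMinTokenLength(0)   // keep zero-length tokens (the default of 1 drops them)

tokenizer.transform(df)
  .select("tokens")
  .collect()
  .foreach { case Row(tokens: Seq[_]) => println(tokens.mkString("[", ", ", "]")) }
// First row prints: [Test, for, tokenization.]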