
Commit e262bac
Added unit tests in Scala

Also changed RegexTokenizer so that it extends the Tokenizer class instead of UnaryTransformer.
It might be interesting to create a Tokenizer trait that could be used by all tokenizers.
Augustin Borsu committed Mar 18, 2015
1 parent cd6642e commit e262bac
Showing 2 changed files with 99 additions and 6 deletions.
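
The commit message floats the idea of a shared Tokenizer trait. As a point of reference, here is a minimal sketch of how the String => Seq[String] contract could be factored out; the names StringTokenizer and SimpleWhitespaceTokenizer are hypothetical and are not part of this commit or of Spark ML.

// Hypothetical sketch only: one way a shared trait could expose the common
// String => Seq[String] contract so that all tokenizers share a type.
trait StringTokenizer {
  /** Turns one input string into a sequence of tokens. */
  def tokenize(text: String): Seq[String]
}

// Example implementation mirroring the lowercase/whitespace behaviour that the
// existing Tokenizer shows in the test suite below.
class SimpleWhitespaceTokenizer extends StringTokenizer {
  override def tokenize(text: String): Seq[String] =
    text.toLowerCase.split("\\s").toSeq
}
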
11 changes: 5 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -40,7 +40,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
   override protected def outputDataType: DataType = new ArrayType(StringType, false)
 }

-
 /**
  * :: AlphaComponent ::
  * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
@@ -51,10 +50,10 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
  * lowercase = false, minTokenLength = 1
  */
 @AlphaComponent
-class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
+class RegexTokenizer extends Tokenizer {

   /**
-   * param for minimum token length
+   * param for minimum token length, default is one to avoid returning empty strings
    * @group param
    */
   val minTokenLength = new IntParam(this, "minLength", "minimum token length", Some(1))
@@ -66,7 +65,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
   def getMinTokenLength: Int = get(minTokenLength)

   /**
-   * param sets regex as matching gaps(true) or tokens (false)
+   * param sets regex as splitting on gaps(true) or matching tokens (false)
    * @group param
    */
   val gaps = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens", Some(false))
@@ -78,7 +77,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
   def getGaps: Boolean = get(gaps)

   /**
-   * param sets regex used by tokenizer
+   * param sets regex pattern used by tokenizer
    * @group param
    */
   val pattern = new Param(this, "pattern",
@@ -95,7 +94,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
     val re = paramMap(pattern)
     val tokens = if(paramMap(gaps)) str.split(re).toList else (re.r.findAllIn(str)).toList

-    tokens.filter(_.length >= paramMap(minTokenLength))
+    tokens.filter(_.length >= paramMap(minTokenLength)).toSeq
   }

   override protected def validateInputType(inputType: DataType): Unit = {
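
The tokenizing function in the last hunk above distinguishes two regex modes: with gaps set, the pattern describes the separators and the input is split on it; otherwise the pattern describes the tokens themselves and every match is collected. Below is a standalone sketch of that logic; the patterns are illustrative examples, since the transformer's actual default pattern is not visible in this hunk.

// Standalone illustration of the two regex modes, using the same inputs as the tests below.
object RegexModesExample {
  def main(args: Array[String]): Unit = {
    val text = "Te,st. punct"

    // gaps = true: the regex describes separators, so split on it.
    val byGaps = text.split("\\s").toList        // List("Te,st.", "punct")

    // gaps = false: the regex describes the tokens, so collect every match.
    // Illustrative pattern: runs of letters, or runs of non-letter, non-space characters.
    val byMatch = "\\p{L}+|[^\\p{L}\\s]+".r.findAllIn(text).toList
    // List("Te", ",", "st", ".", "punct")

    println(byGaps)
    println(byMatch)
  }
}
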
94 changes: 94 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


case class TextData(rawText : String,wantedTokens: Seq[String])
class TokenizerSuite extends FunSuite with MLlibTestSparkContext {

  @transient var sqlContext: SQLContext = _
  @transient var dataset: DataFrame = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("RegexTokenizer"){
    var myRegExTokenizer = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")

    dataset = sqlContext.createDataFrame(
      sc.parallelize(List(
        TextData("Test for tokenization.",List("Test","for","tokenization",".")),
        TextData("Te,st. punct",List("Te",",","st",".","punct"))
      )))
    testTokenizer(myRegExTokenizer,dataset)

    dataset = sqlContext.createDataFrame(
      sc.parallelize(List(
        TextData("Test for tokenization.",List("Test","for","tokenization")),
        TextData("Te,st. punct",List("punct"))
      )))
    myRegExTokenizer.asInstanceOf[RegexTokenizer]
      .setMinTokenLength(3)
    testTokenizer(myRegExTokenizer,dataset)

    myRegExTokenizer.asInstanceOf[RegexTokenizer]
      .setPattern("\\s")
      .setGaps(true)
      .setMinTokenLength(0)
    dataset = sqlContext.createDataFrame(
      sc.parallelize(List(
        TextData("Test for tokenization.",List("Test","for","tokenization.")),
        TextData("Te,st. punct",List("Te,st.","","punct"))
      )))
    testTokenizer(myRegExTokenizer,dataset)
  }

  test("Tokenizer"){
    val oldTokenizer = new Tokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
    dataset = sqlContext.createDataFrame(
      sc.parallelize(List(
        TextData("Test for tokenization.",List("test","for","tokenization.")),
        TextData("Te,st. punct",List("te,st.","","punct"))
      )))
    testTokenizer(oldTokenizer,dataset)
  }

  def testTokenizer(t: Tokenizer,dataset: DataFrame){
    t.transform(dataset)
      .select("tokens","wantedTokens")
      .collect().foreach{
        case Row(tokens: Seq[String], wantedTokens: Seq[String]) =>
          assert(tokens.length == wantedTokens.length)
          tokens.zip(wantedTokens).foreach(x => assert(x._1 == x._2))
        case _ =>
          println()
          assert(false)
      }
  }
}
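
For completeness, the setters exercised by the suite can also be chained in application code. A spark-shell style sketch, assuming an already initialized SparkContext sc and SQLContext sqlContext and reusing the suite's TextData case class; this is an illustration, not an official example.

// Usage sketch mirroring the parameters exercised by TokenizerSuite above.
// Assumes `sc` and `sqlContext` already exist and that TextData is in scope.
val sentences = sqlContext.createDataFrame(sc.parallelize(List(
  TextData("Test for tokenization.", List("Test", "for", "tokenization", "."))
)))

val regexTokenizer = new RegexTokenizer()
regexTokenizer
  .setMinTokenLength(1)  // keep tokens of length >= 1, i.e. drop empty strings
  .setGaps(false)        // treat the pattern as describing tokens, not separators

regexTokenizer
  .setInputCol("rawText")
  .setOutputCol("tokens")
  .transform(sentences)
  .select("tokens", "wantedTokens")
  .show()
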
