
[ML][FEATURE] SPARK-5566: RegEx Tokenizer #4504

Closed · wants to merge 33 commits
Changes from all commits (33 commits)
01cd26f
RegExTokenizer
Feb 10, 2015
9547e9d
RegEx Tokenizer
Feb 10, 2015
9e07a78
Merge remote-tracking branch 'upstream/master'
Feb 11, 2015
9f8685a
RegexTokenizer
Feb 11, 2015
11ca50f
Merge remote-tracking branch 'upstream/master'
Feb 12, 2015
196cd7a
Merge remote-tracking branch 'upstream/master'
Feb 16, 2015
2e89719
Merge remote-tracking branch 'upstream/master'
Feb 17, 2015
77ff9ca
Merge remote-tracking branch 'upstream/master'
Feb 18, 2015
7f930bb
Merge remote-tracking branch 'upstream/master'
Mar 1, 2015
f6a5002
Merge remote-tracking branch 'upstream/master'
Mar 2, 2015
19f9e53
Merge remote-tracking branch 'upstream/master'
Mar 2, 2015
9082fc3
Removed stopwords parameters and updated doc
Mar 3, 2015
d3ef6d3
Added doc to RegexTokenizer
Mar 3, 2015
cb9c9a7
Merge remote-tracking branch 'upstream/master'
Mar 13, 2015
201a107
Merge remote-tracking branch 'upstream/master'
Mar 17, 2015
132b00b
Changed matching to gaps and removed case folding
Mar 17, 2015
cd6642e
Changed regex to pattern
Mar 17, 2015
e262bac
Added unit tests in scala
Mar 18, 2015
b66313f
Modified the pattern Param so it is compiled when given to the Tokenizer
Mar 19, 2015
38b95a1
Added Java unit test for RegexTokenizer
Mar 19, 2015
6a85982
Style corrections
Mar 19, 2015
daf685e
Merge remote-tracking branch 'upstream/master'
sagacifyTestUser Mar 20, 2015
12dddb4
Merge remote-tracking branch 'upstream/master'
sagacifyTestUser Mar 23, 2015
148126f
Added return type to public functions
sagacifyTestUser Mar 23, 2015
e88d7b8
change pattern to a StringParameter; update tests
mengxr Mar 23, 2015
2338da5
Merge remote-tracking branch 'upstream/master'
sagacifyTestUser Mar 24, 2015
f96526d
Merge remote-tracking branch 'apache/master' into SPARK-5566
mengxr Mar 24, 2015
9651aec
update test
mengxr Mar 24, 2015
556aa27
Merge branch 'aborsu985-master' into SPARK-5566
mengxr Mar 24, 2015
a164800
remove tabs
mengxr Mar 24, 2015
5f09434
Merge remote-tracking branch 'upstream/master'
sagacifyTestUser Mar 25, 2015
cb07021
Merge branch 'SPARK-5566' of git://github.com/mengxr/spark into mengx…
sagacifyTestUser Mar 25, 2015
716d257
Merge branch 'mengxr-SPARK-5566'
sagacifyTestUser Mar 25, 2015
66 changes: 65 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param}
import org.apache.spark.sql.types.{DataType, StringType, ArrayType}

/**
@@ -39,3 +39,67 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {

override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

/**
* :: AlphaComponent ::
* A regex-based tokenizer that extracts tokens either by repeatedly matching the regex (default)
* or by using it to split the text (set gaps to true). Tokens shorter than minTokenLength are
* filtered out. It returns an array of strings that can be empty.
* The default parameters are pattern = "\\p{L}+|[^\\p{L}\\s]+", gaps = false, minTokenLength = 1.
*/
@AlphaComponent
class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {

/**
* param for minimum token length, default is one to avoid returning empty strings
* @group param
*/
Review comment (Contributor): Please append : IntParam to minTokenLength. See SPARK-6428. Please also update other method declarations.

val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1))

/** @group setParam */
def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)

/** @group getParam */
def getMinTokenLength: Int = get(minTokenLength)

/**
* param for whether the regex splits on gaps (true) or matches tokens (false)
* @group param
*/
val gaps: BooleanParam = new BooleanParam(
this, "gaps", "Set regex to match gaps or tokens", Some(false))

/** @group setParam */
def setGaps(value: Boolean): this.type = set(gaps, value)

/** @group getParam */
def getGaps: Boolean = get(gaps)

/**
* param for the regex pattern used by the tokenizer
* @group param
*/
val pattern: Param[String] = new Param(
this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))

/** @group setParam */
def setPattern(value: String): this.type = set(pattern, value)

/** @group getParam */
def getPattern: String = get(pattern)

override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>
val re = paramMap(pattern).r
val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
val minLength = paramMap(minTokenLength)
tokens.filter(_.length >= minLength)
}

override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}

override protected def outputDataType: DataType = new ArrayType(StringType, false)
}
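
For context, a minimal usage sketch of the RegexTokenizer API introduced in the diff above. The Sentence case class, the example rows, and an in-scope SparkContext named sc are assumptions for illustration; the setter names and defaults come from the patch itself.

// A minimal usage sketch (not part of the patch); `sc` is an existing SparkContext.
import org.apache.spark.ml.feature.RegexTokenizer
import org.apache.spark.sql.SQLContext

case class Sentence(rawText: String)

val sqlContext = new SQLContext(sc)
val sentences = sqlContext.createDataFrame(Seq(
  Sentence("Test for tokenization."),
  Sentence("Te,st. punct")))

val tokenizer = new RegexTokenizer()
  .setInputCol("rawText")              // string column to tokenize
  .setOutputCol("tokens")              // output column holding the Seq[String] of tokens
  .setPattern("\\p{L}+|[^\\p{L}\\s]+") // default: match letter runs or runs of non-letter, non-space chars
  .setGaps(false)                      // false: the regex matches tokens; true: it splits on gaps
  .setMinTokenLength(1)                // drop tokens shorter than this

tokenizer.transform(sentences).select("rawText", "tokens").show()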
New file: JavaTokenizerSuite.java (package org.apache.spark.ml.feature), 71 additions
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature;

import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class JavaTokenizerSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;

@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
jsql = new SQLContext(jsc);
}

@After
public void tearDown() {
jsc.stop();
jsc = null;
}

@Test
public void regexTokenizer() {
RegexTokenizer myRegExTokenizer = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")
.setPattern("\\s")
.setGaps(true)
.setMinTokenLength(3);

JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Lists.newArrayList(
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
));
DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);

Row[] pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens")
.collect();

for (Row r : pairs) {
Assert.assertEquals(r.get(0), r.get(1));
}
}
}
New file: RegexTokenizer Scala test suite (package org.apache.spark.ml.feature), 85 additions
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) {
/** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */
def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq)
}

Review comment (Contributor): remove extra empty line

class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.RegexTokenizerSuite._

@transient var sqlContext: SQLContext = _

override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
Review comment (Contributor): initialize dataset here to avoid duplicate code. If they are not the same in tests, please remove dataset from TokenizerSuite and define local variables in each test.

}

test("RegexTokenizer") {
val tokenizer = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")

val dataset0 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
))
testRegexTokenizer(tokenizer, dataset0)

val dataset1 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")),
TokenizerTestData("Te,st. punct", Seq("punct"))
))

tokenizer.setMinTokenLength(3)
testRegexTokenizer(tokenizer, dataset1)

tokenizer
.setPattern("\\s")
.setGaps(true)
.setMinTokenLength(0)
val dataset2 = sqlContext.createDataFrame(Seq(
TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct"))
))
testRegexTokenizer(tokenizer, dataset2)
}
}

object RegexTokenizerSuite extends FunSuite {

def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
.select("tokens", "wantedTokens")
.collect()
.foreach {
case Row(tokens, wantedTokens) =>
assert(tokens === wantedTokens)
}
}
}
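
As one possible reading of the review comment in beforeAll above (initialize the dataset there to avoid duplicate code), a hedged sketch of how the suite could share the first test's rows. The dataset0 member and this exact refactor are illustrative, not part of the patch.

// Hypothetical fragment of RegexTokenizerSuite: build the shared dataset once in beforeAll.
// MLlibTestSparkContext is assumed to provide `sc`, as in the suite above.
@transient var sqlContext: SQLContext = _
@transient var dataset0: DataFrame = _

override def beforeAll(): Unit = {
  super.beforeAll()
  sqlContext = new SQLContext(sc)
  // Same rows the first test currently builds inline.
  dataset0 = sqlContext.createDataFrame(Seq(
    TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
    TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))))
}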