[SPARK-12159][ML] Add user guide section for IndexToString transformer
Documentation regarding the `IndexToString` label transformer with code snippets in Scala/Java/Python.

Author: BenFradet <[email protected]>

Closes #10166 from BenFradet/SPARK-12159.

(cherry picked from commit 06746b3)
Signed-off-by: Joseph K. Bradley <[email protected]>
BenFradet authored and jkbradley committed Dec 8, 2015
1 parent 7e45feb commit 3e31e7e
Showing 4 changed files with 268 additions and 16 deletions.
104 changes: 88 additions & 16 deletions docs/ml-features.md
@@ -835,10 +835,10 @@ dctDf.select("featuresDCT").show(3);
`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies.
So the most frequent label gets index `0`.
If the input column is numeric, we cast it to string and index the string
values. When downstream pipeline components such as `Estimator` or
`Transformer` make use of this string-indexed label, you must set the input
column of the component to this string-indexed column name. In many cases,
you can set the input column with `setInputCol`.

**Examples**
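
For orientation, a minimal Scala sketch of the behavior described above (hedged: it assumes a `sqlContext` in scope, as the other snippets on this page do):

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

// "a" is the most frequent label, so it gets index 0.0,
// followed by "c" (1.0) and "b" (2.0).
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")

val indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}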
@@ -951,9 +951,78 @@ indexed.show()
</div>
</div>


## IndexToString

Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices
back to a column containing the original labels as strings. The common use case
is to produce indices from labels with `StringIndexer`, train a model with those
indices, and retrieve the original labels from the column of predicted indices
with `IndexToString`. However, you are free to supply your own labels.
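
For the custom-labels case, here is a minimal hedged sketch in Scala; it assumes the transformer's `labels` parameter (set via `setLabels`), which takes precedence over the input column's metadata, and a `sqlContext` in scope:

{% highlight scala %}
import org.apache.spark.ml.feature.IndexToString

val indexedDF = sqlContext.createDataFrame(
  Seq((0, 0.0), (1, 2.0), (2, 1.0))
).toDF("id", "categoryIndex")

// Explicit labels: index 0.0 -> "first", 1.0 -> "second", 2.0 -> "third".
// When labels are supplied, the column metadata is not consulted.
val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")
  .setLabels(Array("first", "second", "third"))

converter.transform(indexedDF).select("id", "originalCategory").show()
{% endhighlight %}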

**Examples**

Building on the `StringIndexer` example, let's assume we have the following
DataFrame with columns `id` and `categoryIndex`:

~~~~
id | categoryIndex
----|---------------
0 | 0.0
1 | 2.0
2 | 1.0
3 | 0.0
4 | 0.0
5 | 1.0
~~~~

Applying `IndexToString` with `categoryIndex` as the input column and
`originalCategory` as the output column, we are able to retrieve our original
labels (they will be inferred from the column's metadata):

~~~~
id | categoryIndex | originalCategory
----|---------------|-----------------
0 | 0.0 | a
1 | 2.0 | b
2 | 1.0 | c
3 | 0.0 | a
4 | 0.0 | a
5 | 1.0 | c
~~~~
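
The label recovery works because `StringIndexerModel` attaches ML attribute metadata describing the index-to-label mapping to its output column, and `IndexToString` reads it back. A hedged sketch of inspecting that metadata (assuming the `org.apache.spark.ml.attribute` API and the `indexed` DataFrame produced by `StringIndexer`):

{% highlight scala %}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}

// `indexed` is the output of StringIndexerModel.transform.
val attr = Attribute
  .fromStructField(indexed.schema("categoryIndex"))
  .asInstanceOf[NominalAttribute]

// Labels in index order (descending frequency): a, c, b.
println(attr.values.get.mkString(", "))
{% endhighlight %}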

<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [IndexToString Scala docs](api/scala/index.html#org.apache.spark.ml.feature.IndexToString)
for more details on the API.

{% include_example scala/org/apache/spark/examples/ml/IndexToStringExample.scala %}

</div>

<div data-lang="java" markdown="1">

Refer to the [IndexToString Java docs](api/java/org/apache/spark/ml/feature/IndexToString.html)
for more details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaIndexToStringExample.java %}

</div>

<div data-lang="python" markdown="1">

Refer to the [IndexToString Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IndexToString)
for more details on the API.

{% include_example python/ml/index_to_string_example.py %}

</div>
</div>

## OneHotEncoder

[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
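
One detail worth noting, shown in a hedged Scala sketch below: the 1.6 encoder has a `dropLast` parameter (default `true`, assuming the `setDropLast` setter), under which the last category is encoded as the all-zeros vector so that each vector has at most a single one-value.

{% highlight scala %}
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

val df = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

val indexed = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
  .transform(df)

// dropLast = false keeps an explicit slot for every category; the default
// (true) drops the last slot and encodes that category as all zeros.
val encoded = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
  .setDropLast(false)
  .transform(indexed)

encoded.select("id", "categoryVec").show()
{% endhighlight %}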

<div class="codetabs">
<div data-lang="scala" markdown="1">
@@ -979,10 +1048,11 @@ val indexer = new StringIndexer()
  .fit(df)
val indexed = indexer.transform(df)

val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
val encoded = encoder.transform(indexed)
encoded.select("id", "categoryVec").show()
{% endhighlight %}
</div>

@@ -1015,7 +1085,7 @@ JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
  RowFactory.create(5, "c")
));
StructType schema = new StructType(new StructField[]{
  new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
  new StructField("category", DataTypes.StringType, false, Metadata.empty())
});
DataFrame df = sqlContext.createDataFrame(jrdd, schema);
@@ -1029,6 +1099,7 @@ OneHotEncoder encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec");
DataFrame encoded = encoder.transform(indexed);
encoded.select("id", "categoryVec").show();
{% endhighlight %}
</div>

@@ -1054,6 +1125,7 @@ model = stringIndexer.fit(df)
indexed = model.transform(df)
encoder = OneHotEncoder(dropLast=False, inputCol="categoryIndex", outputCol="categoryVec")
encoded = encoder.transform(indexed)
encoded.select("id", "categoryVec").show()
{% endhighlight %}
</div>
</div>
@@ -1582,7 +1654,7 @@ from pyspark.mllib.linalg import Vectors

data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)]
df = sqlContext.createDataFrame(data, ["vector"])
transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]),
                                 inputCol="vector", outputCol="transformedVector")
transformer.transform(df).show()

@@ -1837,15 +1909,15 @@ for more details on the API.
sub-array of the original features. It is useful for extracting features from a vector column.

`VectorSlicer` accepts a vector column with specified indices, then outputs a new vector column
whose values are selected via those indices. There are two types of indices:

1. Integer indices that represent the indices into the vector, `setIndices()`;

2. String indices that represent the names of features in the vector, `setNames()`.
*This requires the vector column to have an `AttributeGroup` since the implementation matches on
the name field of an `Attribute`.*

Specification by integer and string are both acceptable. Moreover, you can use integer indices and
string names simultaneously. At least one feature must be selected. Duplicate features are not
allowed, so there can be no overlap between selected indices and names. Note that if names of
features are selected, an exception will be thrown when empty input attributes are encountered.
@@ -1858,9 +1930,9 @@ followed by the selected names (in the order given).
Suppose that we have a DataFrame with the column `userFeatures`:

~~~
userFeatures
------------------
 [0.0, 10.0, 0.5]
~~~

`userFeatures` is a vector column that contains three user features. Assuming that the first column
@@ -1874,7 +1946,7 @@ column named `features`:
[0.0, 10.0, 0.5] | [10.0, 0.5]
~~~

Suppose also that we have potential input attributes for `userFeatures`, i.e.
`["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them.

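A hedged Scala sketch of the two selection modes described above, assuming the `org.apache.spark.ml.attribute` API to attach the `f1`-`f3` names, and `sc`/`sqlContext` in scope:

{% highlight scala %}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

// Name the three features so that setNames() can match on them.
val defaultAttr = NumericAttribute.defaultAttr
val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName)
val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]])

val rows = sc.parallelize(Seq(Row(Vectors.dense(0.0, 10.0, 0.5))))
val dataset = sqlContext.createDataFrame(rows, StructType(Array(attrGroup.toStructField())))

// Select by name; setIndices(Array(1, 2)) would pick the same two features.
val slicer = new VectorSlicer()
  .setInputCol("userFeatures")
  .setOutputCol("features")
  .setNames(Array("f2", "f3"))

slicer.transform(dataset).show()
{% endhighlight %}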
75 changes: 75 additions & 0 deletions examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;

// $example on$
import java.util.Arrays;

import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$

public class JavaIndexToStringExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlContext = new SQLContext(jsc);

    // $example on$
    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
      RowFactory.create(0, "a"),
      RowFactory.create(1, "b"),
      RowFactory.create(2, "c"),
      RowFactory.create(3, "a"),
      RowFactory.create(4, "a"),
      RowFactory.create(5, "c")
    ));
    StructType schema = new StructType(new StructField[]{
      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("category", DataTypes.StringType, false, Metadata.empty())
    });
    DataFrame df = sqlContext.createDataFrame(jrdd, schema);

    StringIndexerModel indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df);
    DataFrame indexed = indexer.transform(df);

    IndexToString converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory");
    DataFrame converted = converter.transform(indexed);
    converted.select("id", "originalCategory").show();
    // $example off$
    jsc.stop();
  }
}
45 changes: 45 additions & 0 deletions examples/src/main/python/ml/index_to_string_example.py
@@ -0,0 +1,45 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

from pyspark import SparkContext
# $example on$
from pyspark.ml.feature import IndexToString, StringIndexer
# $example off$
from pyspark.sql import SQLContext

if __name__ == "__main__":
    sc = SparkContext(appName="IndexToStringExample")
    sqlContext = SQLContext(sc)

    # $example on$
    df = sqlContext.createDataFrame(
        [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
        ["id", "category"])

    stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
    model = stringIndexer.fit(df)
    indexed = model.transform(df)

    converter = IndexToString(inputCol="categoryIndex", outputCol="originalCategory")
    converted = converter.transform(indexed)

    converted.select("id", "originalCategory").show()
    # $example off$

    sc.stop()
60 changes: 60 additions & 0 deletions examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
package org.apache.spark.examples.ml

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.ml.feature.{StringIndexer, IndexToString}
// $example off$

object IndexToStringExample {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("IndexToStringExample")
    val sc = new SparkContext(conf)

    val sqlContext = SQLContext.getOrCreate(sc)

    // $example on$
    val df = sqlContext.createDataFrame(Seq(
      (0, "a"),
      (1, "b"),
      (2, "c"),
      (3, "a"),
      (4, "a"),
      (5, "c")
    )).toDF("id", "category")

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .fit(df)
    val indexed = indexer.transform(df)

    val converter = new IndexToString()
      .setInputCol("categoryIndex")
      .setOutputCol("originalCategory")

    val converted = converter.transform(indexed)
    converted.select("id", "originalCategory").show()
    // $example off$
    sc.stop()
  }
}
// scalastyle:on println
