Skip to content

Commit

Permalink
Merge pull request apache#61 from tqchen/master
Browse files Browse the repository at this point in the history
Add one hot encoder
  • Loading branch information
tqchen committed Oct 21, 2015
2 parents bcc19fc + 4962811 commit 129e060
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 4 deletions.
12 changes: 11 additions & 1 deletion guide/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,17 @@ int main(void) {
for (index_t i = 0; i < choosed.size(0); ++i) {
printf("%.2f ", choosed[i]);
}
printf("\n ");
printf("\n");

rhs = one_hot_encode(index, 3);

for (index_t i = 0; i < lhs.size(0); ++i) {
for (index_t j = 0; j < lhs.size(1); ++j) {
printf("%.2f ", rhs[i][j]);
}
printf("\n");
}
printf("\n");

// shutdown tensor enigne after usage
ShutdownTensorEngine<cpu>();
Expand Down
1 change: 1 addition & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
#include "./extension/concat.h"
#include "./extension/implicit_gemm.h"
#include "./extension/choose.h"
#include "./extension/one_hot.h"
#endif // MSHADOW_EXTENSION_H_
3 changes: 0 additions & 3 deletions mshadow/extension/choose.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#define MSHADOW_EXTENSION_CHOOSE_H_

#include "../extension.h"
#include "../packet-inl.h"

namespace mshadow {
namespace expr {
Expand Down Expand Up @@ -89,5 +88,3 @@ struct ExpInfo<MatChooseRowElementExp<SrcExp, IndexExp, DType> > {
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_CHOOSE_H_


87 changes: 87 additions & 0 deletions mshadow/extension/one_hot.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*!
* Copyright (c) 2014 by Contributors
* \file one_hot.h
* \brief Create one-hot indicator array based on the index.
* \author Tianqi Chen
*/
#ifndef MSHADOW_EXTENSION_ONE_HOT_H_
#define MSHADOW_EXTENSION_ONE_HOT_H_

#include "../extension.h"


namespace mshadow {
namespace expr {
/*!
* \brief Create a one-hot indicator array.
* \tparam IndexExp type of index expression
* \tparam DType the type of elements
*/
template<typename IndexExp, typename DType>
struct OneHotEncodeExp:
public Exp<OneHotEncodeExp<IndexExp, DType>,
DType, type::kChainer> {
/*! \brief index operand */
const IndexExp &index_;
/*! \brief number of choices we can have. */
index_t num_choices_;
/*! \brief constructor */
OneHotEncodeExp(const IndexExp &index, index_t num_choices)
: index_(index), num_choices_(num_choices) {}
};

template<typename IndexExp,
typename IDType, int e1>
inline OneHotEncodeExp<IndexExp, default_real_t>
one_hot_encode(const Exp<IndexExp, IDType, e1> &index, index_t num_choices) {
TypeCheckPass<ExpInfo<IndexExp>::kDim == 1>
::Error_Expression_Does_Not_Meet_Dimension_Req();
return OneHotEncodeExp<IndexExp, default_real_t>(index.self(), num_choices);
}

//----------------------
// Execution plan
//----------------------
template<typename IndexExp, typename DType>
struct Plan<OneHotEncodeExp<IndexExp, DType>, DType> {
public:
explicit Plan(const OneHotEncodeExp<IndexExp, DType> &e)
: index_(MakePlan(e.index_)) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
index_t idx = static_cast<index_t>(index_.Eval(0, y));
return static_cast<DType>(x == idx);
}

private:
expr::Plan<IndexExp, DType> index_;
};

template<typename IndexExp, typename DType>
inline Plan<OneHotEncodeExp<IndexExp, DType>, DType>
MakePlan(const OneHotEncodeExp<IndexExp, DType> &exp) {
return Plan<OneHotEncodeExp<IndexExp, DType>, DType>(exp);
}

template<int dim, typename IndexExp, typename DType>
struct ShapeCheck<dim, OneHotEncodeExp<IndexExp, DType> > {
inline static Shape<dim>
Check(const OneHotEncodeExp<IndexExp, DType> &t) {
CHECK(dim == 2)
<< "OneHotEncodeExp only support 2 dimension output";
Shape<1> shape = ShapeCheck<1, IndexExp>::Check(t.index_);
Shape<dim> ret;
ret[0] = shape[0];
ret[1] = t.num_choices_;
return ret;
}
};

template<typename IndexExp, typename DType>
struct ExpInfo<OneHotEncodeExp<IndexExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_ONE_HOT_H_

0 comments on commit 129e060

Please sign in to comment.