Skip to content

Commit

Permalink
Merge branch 'one_hot'
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanfeydy committed Sep 27, 2019
2 parents fac7ae5 + 8e12ba9 commit 26b4829
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/api/math-operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ Constants and padding/concatenation operations:
``Extract(f, M, D)`` extract sub-vector from vector ``f`` (M is starting index, D is dimension of sub-vector)
``ExtractT(f, M, D)`` insert vector ``f`` in a larger vector of zeros (M is starting index, D is dimension of output)
``Concat(f, g)`` concatenation of vectors ``f`` and ``g``
``OneHot(f, D)`` encodes a (rounded) scalar value as a one-hot vector of dimension D
====================== =========================================================================================================

Elementary dot products:
Expand Down
43 changes: 43 additions & 0 deletions keops/core/formulas/maths/OneHot.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include <sstream>
#include <assert.h>

#include "core/formulas/constants/Zero.h"
#include "core/autodiff/UnaryOp.h"
#include "core/pre_headers.h"

namespace keops {

//////////////////////////////////////////////////////////////
//// ONE-HOT REPRESENTATION : OneHot<F,DIM> ////
//////////////////////////////////////////////////////////////

template< class F, int DIM_ >
struct OneHot : UnaryOp< OneHot, F, DIM_ > {
static const int DIM = DIM_;

static_assert(F::DIM == 1, "One-hot representation is only supported for scalar formulas.");
static_assert(DIM_ >= 1, "A one-hot vector should have length >= 1.");

static void PrintIdString(::std::stringstream &str) {
str << "OneHot";
}

// N.B.: This may not be the most efficient implementation,
// with unnecessary casts, etc.
static HOST_DEVICE INLINE
void Operation(__TYPE__ *out, __TYPE__ *outF) {
#pragma unroll
for (int k = 0; k < DIM; k++)
out[k] = (round(outF[0]) == k) ? 1 : 0 ;
}

// There is no gradient to accumulate on V, whatever V.
template < class V, class GRADIN >
using DiffT = Zero<V::DIM>;
};

#define OneHot(f,n) KeopsNS<OneHot<decltype(InvKeopsNS(f)),n>>()

}
1 change: 1 addition & 0 deletions keops/core/formulas/maths/Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Concatenation and matrix-vector products:
* Elem<F,N> : extract Nth element of F
* Extract<F,START,LENGTH> : extract a number LENGTH starting at index START
* MatVecMult<FA,FB> : matrix-vector product (FA::DIM must be a muliple of FB::DIM)
* OneHot<F,D> : represents a scalar formula (rounded to an integer) as a one-hot vector of dimension D
* VecMatMult<FA,FB> : vector-matrix product (FB::DIM must be a muliple of FA::DIM)
* TensorProd<FA,FB> : tensor product (output is of dimension FA::DIM*FB::DIM)
* TensorDot<FA,FB,DA,DB,CA,CB> : tensor dot as in numpy. FA and FB are formulas and DA, DB, CA and CB are
Expand Down
1 change: 1 addition & 0 deletions keops/keops_includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "core/formulas/maths/TensorDot.h"
#include "core/formulas/maths/TensorProd.h"
#include "core/formulas/maths/VecMatMult.h"
#include "core/formulas/maths/OneHot.h"


// import all operations on vector implementations
Expand Down
16 changes: 16 additions & 0 deletions pykeops/common/lazy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,22 @@ def __getitem__(self, key):
else:
raise ValueError("LazyTensors only support indexing with integers and vanilla python slices.")


def one_hot(self, D):
r"""
Encodes a (rounded) scalar value as a one-hot vector of dimension D.
``x.one_hot(D)`` returns a :class:`LazyTensor` that encodes, symbolically,
a vector of length D whose round(x)-th coordinate is equal to 1, and the other ones to zero.
"""
if type(D) is not int:
raise ValueError("One-hot encoding expects an integer dimension of the output vector.")
if self.ndim != 1:
raise ValueError("One-hot encoding is only supported for scalar formulas.")

return self.unary("OneHot", dimres=D, opt_arg=D)


def concat(self, other):
r"""
Concatenation of two :class:`LazyTensor` - a binary operation.
Expand Down

0 comments on commit 26b4829

Please sign in to comment.