[ Tensor ] UInt Tensor : UInt8 / UInt16 / UInt32
- This commit resolves #2733
- This commit implements a UIntTensor class template based on the ShortTensor class (see the sketch after the file summary below).
- Based on the template, this commit supports UInt8 / UInt16 / UInt32.
- Implement the UIntTensor template
- Implement UInt8Tensor / UInt16Tensor / UInt32Tensor (ShortTensor is replaced)
- Add unit tests for UInt8 and UInt32

Self evaluation:
Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Eunju Yang <[email protected]>
EunjuYang committed Sep 13, 2024
1 parent 9eb4c85 commit 5220bb1
Showing 9 changed files with 412 additions and 125 deletions.
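
Note: the new uint_tensor.h is collapsed in this view. As a rough, self-contained sketch of the template-plus-alias pattern the commit message describes — only the alias names UInt8Tensor / UInt16Tensor / UInt32Tensor come from the diff; all members here are illustrative stand-ins, not the real interface:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

/// Illustrative stand-in for the UIntTensor template; the real class in
/// uint_tensor.h derives from the tensor base classes and is not shown
/// in this excerpt.
template <typename T> class UIntTensor {
public:
  explicit UIntTensor(std::vector<T> data) : data_(std::move(data)) {}
  std::size_t size() const { return data_.size(); }
  T &operator[](std::size_t i) { return data_[i]; }

private:
  std::vector<T> data_;
};

// One concrete tensor type per unsigned width, replacing ShortTensor.
using UInt8Tensor = UIntTensor<uint8_t>;
using UInt16Tensor = UIntTensor<uint16_t>;
using UInt32Tensor = UIntTensor<uint32_t>;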
13 changes: 9 additions & 4 deletions api/ccapi/include/tensor_dim.h
@@ -49,12 +49,14 @@ class TensorDim {

/**
* @brief Tensor Data Type.
* Currently support QINT4, QINT8, UINT16, FP16 & FP32
* Currently support QINT4, QINT8, UINT8, UINT16, UINT32, FP16 & FP32
*/
enum class DataType {
QINT4, /** quantized int 4*/
QINT8, /** quantized int 8*/
UINT8, /** unsigned int 8 bit */
UINT16, /** unsigned int 16 bit */
UINT32, /** unsigned int 32 bit */
FP16, /** half precision */
FP32 /** single precision */
};
@@ -112,7 +114,8 @@ class TensorDim {
* @brief Creator of TensorDim with Format & DataType
*
* @param fm format NCHW | HNWC
* @param d_type DataType QINT4 | QINT8 | UINT16 | FP16 | FP32
* @param d_type DataType QINT4 | QINT8 | UINT8 | UINT16 | UINT32 | FP16 |
* FP32
* @param eff_dim_flag_ effective dimension flag (1 means it's effective)
* @param dyn_dim_flag_ dynamic dimension flag (1 means it's unspecified)
*/
@@ -215,7 +218,8 @@ class TensorDim {
* @param h height
* @param w width
* @param fm format NCHW | HNWC
* @param d_type Data Type QINT4 | QINT8 | UINT16 | FP16 | FP32
* @param d_type DataType QINT4 | QINT8 | UINT8 | UINT16 | UINT32 | FP16 |
* FP32
* @param eff_dim_flag_ dimension bit flag to calculate the dynamic
* dimension, rightmost is width
*/
@@ -244,7 +248,8 @@ class TensorDim {
*
* @param shape shape of format
* @param fm format NCHW | HNWC
* @param d_type data type QINT4 | QINT8 | UINT16 | FP16 | FP32
* @param d_type DataType QINT4 | QINT8 | UINT8 | UINT16 | UINT32 | FP16 |
* FP32
* @param order data storage order ROW_MAJOR | COL_MAJOR
*/
TensorDim(const std::string &shape, TensorDim::Format fm,
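
For reference, a sketch of selecting the new data types through the string-shape constructor declared above. The "batch:channel:height:width" shape syntax and the omitted trailing arguments are assumptions; this excerpt does not show the defaults.

// Sketch only: assumes the trailing order parameter is defaulted and the
// shape string follows the b:c:h:w convention used elsewhere in nntrainer.
ml::train::TensorDim u8_dim("1:1:4:4", ml::train::TensorDim::Format::NCHW,
                            ml::train::TensorDim::DataType::UINT8);
ml::train::TensorDim u32_dim("1:1:4:4", ml::train::TensorDim::Format::NCHW,
                             ml::train::TensorDim::DataType::UINT32);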
2 changes: 1 addition & 1 deletion debian/nntrainer-dev.install
@@ -11,7 +11,7 @@
/usr/include/nntrainer/tensor.h
/usr/include/nntrainer/tensor_base.h
/usr/include/nntrainer/char_tensor.h
/usr/include/nntrainer/short_tensor.h
/usr/include/nntrainer/uint_tensor.h
/usr/include/nntrainer/float_tensor.h
/usr/include/nntrainer/tensor_wrap_specs.h
/usr/include/nntrainer/blas_interface.h
4 changes: 2 additions & 2 deletions nntrainer/tensor/meson.build
@@ -9,7 +9,7 @@ tensor_sources = [
'tensor_base.cpp',
'float_tensor.cpp',
'char_tensor.cpp',
'short_tensor.cpp',
'uint_tensor.cpp',
'tensor_dim.cpp',
'var_grad.cpp',
'weight.cpp',
@@ -29,7 +29,7 @@ tensor_headers = [
'tensor_base.h',
'float_tensor.h',
'char_tensor.h',
'short_tensor.h',
'uint_tensor.h',
'weight.h',
'var_grad.h',
'tensor_wrap_specs.h',
64 changes: 51 additions & 13 deletions nntrainer/tensor/tensor.cpp
@@ -12,8 +12,8 @@
#include <char_tensor.h>
#include <float_tensor.h>
#include <lazy_tensor.h>
#include <short_tensor.h>
#include <tensor.h>
#include <uint_tensor.h>

#ifdef ENABLE_FP16
#include <half_tensor.h>
@@ -33,9 +33,15 @@ Tensor::Tensor(std::string name_, Tformat fm, Tdatatype d_type) {
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
} else if (d_type == Tdatatype::UINT8) {
itensor = std::shared_ptr<UInt8Tensor>(new UInt8Tensor(name_, fm),
std::default_delete<UInt8Tensor>());
} else if (d_type == Tdatatype::UINT16) {
itensor = std::shared_ptr<ShortTensor>(new ShortTensor(name_, fm),
std::default_delete<ShortTensor>());
itensor = std::shared_ptr<UInt16Tensor>(
new UInt16Tensor(name_, fm), std::default_delete<UInt16Tensor>());
} else if (d_type == Tdatatype::UINT32) {
itensor = std::shared_ptr<UInt32Tensor>(
new UInt32Tensor(name_, fm), std::default_delete<UInt32Tensor>());
} else if (d_type == Tdatatype::QINT8) {
itensor = std::shared_ptr<CharTensor>(new CharTensor(name_, fm),
std::default_delete<CharTensor>());
@@ -63,10 +69,18 @@ Tensor::Tensor(const TensorDim &d, bool alloc_now, Initializer init,
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
} else if (d.getDataType() == Tdatatype::UINT8) {
itensor =
std::shared_ptr<UInt8Tensor>(new UInt8Tensor(d, alloc_now, init, name),
std::default_delete<UInt8Tensor>());
} else if (d.getDataType() == Tdatatype::UINT16) {
itensor =
std::shared_ptr<ShortTensor>(new ShortTensor(d, alloc_now, init, name),
std::default_delete<ShortTensor>());
std::shared_ptr<UInt16Tensor>(new UInt16Tensor(d, alloc_now, init, name),
std::default_delete<UInt16Tensor>());
} else if (d.getDataType() == Tdatatype::UINT32) {
itensor =
std::shared_ptr<UInt32Tensor>(new UInt32Tensor(d, alloc_now, init, name),
std::default_delete<UInt32Tensor>());
} else if (d.getDataType() == Tdatatype::QINT8) {
itensor =
std::shared_ptr<CharTensor>(new CharTensor(d, alloc_now, init, name),
@@ -92,9 +106,15 @@ Tensor::Tensor(const TensorDim &d, const void *buf) {
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
} else if (d.getDataType() == Tdatatype::UINT8) {
itensor = std::shared_ptr<UInt8Tensor>(new UInt8Tensor(d, buf),
std::default_delete<UInt8Tensor>());
} else if (d.getDataType() == Tdatatype::UINT16) {
itensor = std::shared_ptr<ShortTensor>(new ShortTensor(d, buf),
std::default_delete<ShortTensor>());
itensor = std::shared_ptr<UInt16Tensor>(
new UInt16Tensor(d, buf), std::default_delete<UInt16Tensor>());
} else if (d.getDataType() == Tdatatype::UINT32) {
itensor = std::shared_ptr<UInt32Tensor>(
new UInt32Tensor(d, buf), std::default_delete<UInt32Tensor>());
} else if (d.getDataType() == Tdatatype::QINT8) {
itensor = std::shared_ptr<CharTensor>(new CharTensor(d, buf),
std::default_delete<CharTensor>());
@@ -117,9 +137,15 @@ Tensor::Tensor(const Tensor &rhs) {
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
} else if (rhs.getDataType() == Tdatatype::UINT8) {
itensor = std::shared_ptr<UInt8Tensor>(new UInt8Tensor(*rhs.itensor),
std::default_delete<UInt8Tensor>());
} else if (rhs.getDataType() == Tdatatype::UINT16) {
itensor = std::shared_ptr<ShortTensor>(new ShortTensor(*rhs.itensor),
std::default_delete<ShortTensor>());
itensor = std::shared_ptr<UInt16Tensor>(
new UInt16Tensor(*rhs.itensor), std::default_delete<UInt16Tensor>());
} else if (rhs.getDataType() == Tdatatype::UINT32) {
itensor = std::shared_ptr<UInt32Tensor>(
new UInt32Tensor(*rhs.itensor), std::default_delete<UInt32Tensor>());
} else if (rhs.getDataType() == Tdatatype::QINT8) {
itensor = std::shared_ptr<CharTensor>(new CharTensor(*rhs.itensor),
std::default_delete<CharTensor>());
@@ -137,9 +163,15 @@ Tensor &Tensor::operator=(const Tensor &rhs) {
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
} else if (rhs.getDataType() == Tdatatype::UINT8) {
itensor = std::shared_ptr<UInt8Tensor>(new UInt8Tensor(*rhs.itensor),
std::default_delete<UInt8Tensor>());
} else if (rhs.getDataType() == Tdatatype::UINT16) {
itensor = std::shared_ptr<ShortTensor>(new ShortTensor(*rhs.itensor),
std::default_delete<ShortTensor>());
itensor = std::shared_ptr<UInt16Tensor>(
new UInt16Tensor(*rhs.itensor), std::default_delete<UInt16Tensor>());
} else if (rhs.getDataType() == Tdatatype::UINT32) {
itensor = std::shared_ptr<UInt32Tensor>(
new UInt32Tensor(*rhs.itensor), std::default_delete<UInt32Tensor>());
} else if (rhs.getDataType() == Tdatatype::QINT8) {
itensor = std::shared_ptr<CharTensor>(new CharTensor(*rhs.itensor),
std::default_delete<CharTensor>());
@@ -163,9 +195,15 @@ bool Tensor::operator==(const Tensor &rhs) const {
"Error: HalfTensor cannot be created or used when FP16 is not enabled. "
"Please check if the tensor data type is set properly.");
#endif
} else if (getDataType() == Tdatatype::UINT8) {
return *std::dynamic_pointer_cast<UInt8Tensor>(itensor) ==
*std::dynamic_pointer_cast<UInt8Tensor>(rhs.itensor);
} else if (getDataType() == Tdatatype::UINT16) {
return *std::dynamic_pointer_cast<ShortTensor>(itensor) ==
*std::dynamic_pointer_cast<ShortTensor>(rhs.itensor);
return *std::dynamic_pointer_cast<UInt16Tensor>(itensor) ==
*std::dynamic_pointer_cast<UInt16Tensor>(rhs.itensor);
} else if (getDataType() == Tdatatype::UINT32) {
return *std::dynamic_pointer_cast<UInt32Tensor>(itensor) ==
*std::dynamic_pointer_cast<UInt32Tensor>(rhs.itensor);
} else if (getDataType() == Tdatatype::QINT8) {
return *std::dynamic_pointer_cast<CharTensor>(itensor) ==
*std::dynamic_pointer_cast<CharTensor>(rhs.itensor);
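
In short, each constructor and the comparison above dispatch on the TensorDim data type, so selecting one of the new unsigned types routes to the matching UIntTensor instantiation. A usage sketch, assuming the Tformat / Tdatatype aliases used throughout this file are visible to callers:

#include <tensor.h>

void make_uint_tensors() {
  // Each call lands in the matching else-if branch shown above.
  nntrainer::Tensor u8("u8", nntrainer::Tformat::NCHW,
                       nntrainer::Tdatatype::UINT8);   // -> UInt8Tensor
  nntrainer::Tensor u16("u16", nntrainer::Tformat::NCHW,
                        nntrainer::Tdatatype::UINT16); // -> UInt16Tensor
  nntrainer::Tensor u32("u32", nntrainer::Tformat::NCHW,
                        nntrainer::Tdatatype::UINT32); // -> UInt32Tensor
}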
68 changes: 65 additions & 3 deletions nntrainer/tensor/tensor.h
@@ -27,8 +27,8 @@
#include <char_tensor.h>
#include <float_tensor.h>
#include <nntrainer_log.h>
#include <short_tensor.h>
#include <tensor_base.h>
#include <uint_tensor.h>

#ifdef ENABLE_FP16
#include <half_tensor.h>
@@ -231,15 +231,46 @@ class Tensor {
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
#endif

/**
* @brief Constructor of Tensor
* @param[in] d data for the Tensor. It needs to set format properly.
* @param[in] t_type Tensor Type
*/
Tensor(std::vector<std::vector<std::vector<std::vector<uint8_t>>>> const &d,
ml::train::TensorDim::TensorType t_type) {
itensor = std::shared_ptr<UInt8Tensor>(new UInt8Tensor(d, t_type.format),
std::default_delete<UInt8Tensor>());
}

/**
* @brief Constructor of Tensor
* @note This constructor copies vector again. needs refactoring
* @param[in] d data for the Tensor. It needs to set format properly.
* @param[in] t_type Tensor Type
*/
Tensor(std::vector<std::vector<std::vector<uint8_t>>> const &d,
ml::train::TensorDim::TensorType t_type) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};

/**
* @brief Constructor of Tensor
* @note This constructor copies vector again. needs refactoring
* @param[in] d data for the Tensor with batch size one
* @param[in] t_type Tensor Type
*/
Tensor(std::vector<std::vector<uint8_t>> const &d,
ml::train::TensorDim::TensorType t_type) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};

/**
* @brief Constructor of Tensor
* @param[in] d data for the Tensor. It needs to set format properly.
* @param[in] t_type Tensor Type
*/
Tensor(std::vector<std::vector<std::vector<std::vector<uint16_t>>>> const &d,
ml::train::TensorDim::TensorType t_type) {
itensor = std::shared_ptr<ShortTensor>(new ShortTensor(d, t_type.format),
std::default_delete<ShortTensor>());
itensor = std::shared_ptr<UInt16Tensor>(
new UInt16Tensor(d, t_type.format), std::default_delete<UInt16Tensor>());
}

/**
@@ -262,6 +293,37 @@
ml::train::TensorDim::TensorType t_type) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};

/**
* @brief Constructor of Tensor
* @param[in] d data for the Tensor. It needs to set format properly.
* @param[in] t_type Tensor Type
*/
Tensor(std::vector<std::vector<std::vector<std::vector<uint32_t>>>> const &d,
ml::train::TensorDim::TensorType t_type) {
itensor = std::shared_ptr<UInt32Tensor>(
new UInt32Tensor(d, t_type.format), std::default_delete<UInt32Tensor>());
}

/**
* @brief Constructor of Tensor
* @note This constructor copies vector again. needs refactoring
* @param[in] d data for the Tensor. It needs to set format properly.
* @param[in] t_type Tensor Type
*/
Tensor(std::vector<std::vector<std::vector<uint32_t>>> const &d,
ml::train::TensorDim::TensorType t_type) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};

/**
* @brief Constructor of Tensor
* @note This constructor copies vector again. needs refactoring
* @param[in] d data for the Tensor with batch size one
* @param[in] t_type Tensor Type
*/
Tensor(std::vector<std::vector<uint32_t>> const &d,
ml::train::TensorDim::TensorType t_type) :
Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};

/**
* @brief Constructor of Tensor
* @param[in] d data for the Tensor. It needs to set format properly.
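
A sketch of exercising the new nested-vector constructors added above. The brace-initialized TensorType is an assumption; the TensorType definition is not part of this diff.

#include <cstdint>
#include <vector>

#include <tensor.h>

void build_from_vectors() {
  // 1x1x2x2 batch:channel:height:width payload, illustrative values only.
  std::vector<std::vector<std::vector<std::vector<uint32_t>>>> data = {
    {{{1u, 2u}, {3u, 4u}}}};

  // Assumes TensorType can be brace-constructed from (format, data type).
  nntrainer::Tensor t(
    data, {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::UINT32});
}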