[Mixed Precision] Fix mixed precision to use Tensor V2
This PR includes fixes so the mixed-precision code paths use TensorV2 (e.g., the standalone `Initializer` enum, per-type `isValid()` overrides, and a corrected `Tensor::clone(type)`).

Resolves:

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <[email protected]>
jijoongmoon committed Aug 28, 2024
1 parent cac2af3 commit a0c3029
Showing 23 changed files with 104 additions and 95 deletions.
15 changes: 8 additions & 7 deletions api/ccapi/include/model.h
@@ -188,13 +188,14 @@ class Model {
* @details This function accepts vector of properties in the format -
* { std::string property_name, void * property_val, ...}
*/
virtual int train(const std::vector<std::string> &values = {},
std::function<bool(void *)> stop_cb =
[](void *stop_user_data) { return false; },
void *stop_user_data = nullptr,
std::function<void(void *)> epoch_complete_cb =
[](void *epoch_user_data) { return false; },
void *epoch_user_data = nullptr) = 0;
virtual int train(
const std::vector<std::string> &values = {},
std::function<bool(void *)> stop_cb =
[](void *stop_user_data) { return false; },
void *stop_user_data = nullptr,
std::function<void(void *)> epoch_complete_cb =
[](void *epoch_user_data) { return false; },
void *epoch_user_data = nullptr) = 0;

/**
* @brief Run Model train with callback function by user
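The `train()` change above is purely a re-indentation of the default callback arguments; the signature and defaults are unchanged. A minimal caller-side sketch of those callbacks (the property strings and setup are illustrative assumptions, not part of this commit):

```cpp
#include <model.h> // ccapi header; include path is an assumption

// Sketch: drive training with an early-stop flag and a per-epoch callback.
int run_training_sketch(ml::train::Model &model) {
  bool user_requested_stop = false;

  return model.train(
    {"epochs=10"},
    /* stop_cb: return true to request early termination */
    [](void *stop_user_data) {
      return *static_cast<bool *>(stop_user_data);
    },
    /* stop_user_data */ &user_requested_stop,
    /* epoch_complete_cb: invoked once per finished epoch */
    [](void *epoch_user_data) { /* e.g., log or save a checkpoint */ },
    /* epoch_user_data */ nullptr);
}
```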
13 changes: 7 additions & 6 deletions nntrainer/graph/network_graph.cpp
@@ -487,16 +487,17 @@ bool NetworkGraph::backwarding(
}
}
/** apply the gradient with the above global norm */
std::cout << "======================================= update gradient "
<< std::endl;
// std::cout << "======================================= update gradient "
// << std::endl;
for (auto w : lazy_weights) {
std::cout << w->getName() << " : ";
// std::cout << w->getName() << " : ";
lazy_apply_grad_op(*w, iteration);
}
nan_count++;

std::cout << "====================================== update gradient finished"
<< std::endl;
// std::cout << "====================================== update gradient
// finished"
// << std::endl;
/** @todo : handle as property : growth_interval : default --> 2000 */

if (nan_count > 2000) {
@@ -1643,7 +1644,7 @@ void NetworkGraph::requestOptimizerVariable(
w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN,
w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
Tensor::Initializer::ZEROS));
Initializer::ZEROS));
}
}
}
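The `backwarding()` hunk above silences the debug prints and keeps the `nan_count` bookkeeping; the `@todo` notes that the growth interval (default 2000) should become a property. For context, a typical dynamic loss-scaling scheme driven by such a counter looks roughly like this sketch; the names and factors are assumptions, not nntrainer code:

```cpp
// Illustrative dynamic loss-scaling state machine for mixed precision.
struct LossScalerSketch {
  float scale = 65536.0f;              // current loss scale
  unsigned int good_steps = 0;         // iterations since the last overflow
  unsigned int growth_interval = 2000; // cf. the @todo above
  float growth_factor = 2.0f;
  float backoff_factor = 0.5f;

  // Call once per iteration after scanning gradients for inf/NaN.
  void update(bool found_inf_or_nan) {
    if (found_inf_or_nan) {
      scale *= backoff_factor; // shrink the scale; skip this weight update
      good_steps = 0;
    } else if (++good_steps >= growth_interval) {
      scale *= growth_factor;  // grow the scale after a stable interval
      good_steps = 0;
    }
  }
};
```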
18 changes: 9 additions & 9 deletions nntrainer/layers/bn_layer.cpp
@@ -118,12 +118,12 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
1.0f, bias_decay, "beta", true);

wt_idx[BNParams::mu_b] =
context.requestTensor(dim, "moviing_mean_backup", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(dim, "moviing_mean_backup", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);

wt_idx[BNParams::var_b] = context.requestTensor(
dim, "moviing_variance_backup", Tensor::Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[BNParams::var_b] =
context.requestTensor(dim, "moviing_variance_backup", Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);

/**
* caches the deviation -> input - avg(input)
@@ -137,8 +137,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
}

wt_idx[BNParams::deviation] =
context.requestTensor(in_dim_, "deviation", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(in_dim_, "deviation", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
/** caches the inverse standard deviation */
wt_idx[BNParams::invstd] =
context.requestTensor(dim, "invstd", Initializer::NONE, false,
@@ -150,8 +150,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
* as the output of this layer need not be stored all the time.
*/
wt_idx[BNParams::t_full] =
context.requestTensor(in_dim_, "tensor_full", Tensor::Initializer::NONE,
false, TensorLifespan::CALC_DERIV_LIFESPAN);
context.requestTensor(in_dim_, "tensor_full", Initializer::NONE, false,
TensorLifespan::CALC_DERIV_LIFESPAN);
/**
* caches variance + epsilon as well.
*/
2 changes: 1 addition & 1 deletion nntrainer/layers/layer_context.h
@@ -227,7 +227,7 @@ class InitLayerContext {
* start from 0 and will always be incremental.
*/
unsigned int requestWeight(const TensorDim &dim, const TensorDim &dim_g,
const Tensor::Initializer init,
const Initializer init,
const WeightRegularizer reg, const float reg_const,
const float decay, const std::string &name,
bool trainable = true, unsigned int out_axis = 3) {
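`requestWeight()` now takes the standalone `Initializer` enum instead of the nested `Tensor::Initializer`. A minimal sketch of a layer requesting a weight through the updated API; the dimensions, enum values, and include path are illustrative assumptions:

```cpp
#include <layer_context.h> // nntrainer header; include path is an assumption

// Sketch: request a zero-initialized, unregularized trainable weight.
unsigned int request_weight_sketch(nntrainer::InitLayerContext &context) {
  nntrainer::TensorDim dim(1, 1, 1, 64);
  return context.requestWeight(dim, dim,
                               nntrainer::Initializer::ZEROS, // was Tensor::Initializer::ZEROS
                               nntrainer::WeightRegularizer::NONE,
                               /*reg_const=*/0.0f, /*decay=*/0.0f,
                               "sketch_weight", /*trainable=*/true);
}
```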
24 changes: 12 additions & 12 deletions nntrainer/layers/lstm.cpp
@@ -512,16 +512,16 @@ void LSTMLayer::finalize(InitLayerContext &context) {
const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit,
activation_tensor_type);

wt_idx[LSTMParams::hidden_state] = context.requestTensor(
hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::hidden_state] =
context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
true, TensorLifespan::ITERATION_LIFESPAN);
// cell_state_dim : [ batch_size, 1, max_timestep, unit ]
const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit,
activation_tensor_type);

wt_idx[LSTMParams::cell_state] = context.requestTensor(
cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::cell_state] =
context.requestTensor(cell_state_dim, "cell_state", Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);

// ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit,
@@ -594,18 +594,18 @@ void LSTMLayer::finalize(InitLayerContext &context) {
// reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep,
NUM_GATE * unit, activation_tensor_type);
wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor(
reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::reverse_ifgo] =
context.requestTensor(reverse_ifgo_dim, "reverse_ifgo", Initializer::NONE,
true, TensorLifespan::ITERATION_LIFESPAN);
}

if (dropout_rate > epsilon) {
// dropout_mask_dim = [ batch, 1, time_iteration, unit ]
const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit,
activation_tensor_type);
wt_idx[LSTMParams::dropout_mask] = context.requestTensor(
dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::dropout_mask] =
context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
}

if (context.getActivationDataType() == TensorDim::DataType::FP32) {
8 changes: 4 additions & 4 deletions nntrainer/layers/pooling2d_layer.cpp
@@ -126,15 +126,15 @@ void Pooling2DLayer::finalize(InitLayerContext &context) {
auto helper_dim = in_dim;
helper_dim.setDataType(ml::train::TensorDim::DataType::FP32);
pool_helper_idx =
context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
pool_helper_size.resize(helper_dim.batch() * helper_dim.channel());
} else {
auto helper_dim = out_dim;
helper_dim.setDataType(ml::train::TensorDim::DataType::FP32);
pool_helper_idx =
context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
}
}

5 changes: 5 additions & 0 deletions nntrainer/tensor/char_tensor.h
@@ -231,6 +231,11 @@ class CharTensor : public TensorBase {
* @return std::string of tensor data type (QINT8)
*/
std::string getStringDataType() const override { return "QINT8"; }

/**
* @copydoc Tensor::isValid()
*/
bool isValid() const override { return true; }; // NYI
};

} // namespace nntrainer
6 changes: 3 additions & 3 deletions nntrainer/tensor/float_tensor.cpp
@@ -150,7 +150,7 @@ void FloatTensor::setZero() {
// sscal(size(), 0, getData<float>(), 1);
/// @note we cannot use sscal, when we set zero. if the data is inf or
/// NaN, then the inf or NaN still remain.
memset(getData<float>(), 0, sizeof(float) * size());
memset((float *)getData(), 0, sizeof(float) * size());
} else {
/// @todo implement apply_i
// apply_i<float>([](float val) -> float { return 0; });
@@ -1210,8 +1210,8 @@ void FloatTensor::apply_broadcast(
return apply_broadcast_util(m, v_func, output, this->computeBroadcastInfo(m));
}

bool Tensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP32, getData<float>());
bool FloatTensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP32, (float *)getData());
}

} // namespace nntrainer
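`FloatTensor::isValid()` (previously mis-declared as `Tensor::isValid()`) validates the FP32 buffer via the `is_valid()` helper. Conceptually this is a finite-value scan, which is what the mixed-precision NaN/inf detection relies on; a standalone sketch of such a check (not the actual helper):

```cpp
#include <cmath>
#include <cstddef>

// Sketch: report false as soon as any element is NaN or +/-inf.
static bool is_valid_sketch(std::size_t len, const float *data) {
  for (std::size_t i = 0; i < len; ++i) {
    if (!std::isfinite(data[i]))
      return false;
  }
  return true;
}
```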
2 changes: 1 addition & 1 deletion nntrainer/tensor/float_tensor.h
@@ -511,7 +511,7 @@ class FloatTensor : public TensorBase {
/**
* @copydoc Tensor::isValid()
*/
bool Tensor::isValid() const;
bool isValid() const override;
};

} // namespace nntrainer
6 changes: 3 additions & 3 deletions nntrainer/tensor/half_tensor.cpp
@@ -149,7 +149,7 @@ void HalfTensor::setZero() {
// sscal(size(), 0, (_FP16 *)getData(), 1);
/// @note we cannot use sscal, when we set zero. if the data is inf or
/// NaN, then the inf or NaN still remain.
memset(getData<_FP16>(), 0, sizeof(_FP16) * size());
memset((_FP16 *)getData(), 0, sizeof(_FP16) * size());
} else {
/// @todo implement apply_i
// apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; });
@@ -1176,8 +1176,8 @@ void HalfTensor::apply_broadcast_util(
}
}

bool Tensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP16, getData<_FP16>());
bool HalfTensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP16, (_FP16 *)getData());
}

} // namespace nntrainer
2 changes: 1 addition & 1 deletion nntrainer/tensor/half_tensor.h
@@ -502,7 +502,7 @@ class HalfTensor : public TensorBase {
/**
* @copydoc Tensor::isValid()
*/
bool Tensor::isValid() const;
bool isValid() const override;
};

} // namespace nntrainer
1 change: 0 additions & 1 deletion nntrainer/tensor/hgemm/hgemm.cpp
@@ -195,4 +195,3 @@ void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
}
}
}

12 changes: 6 additions & 6 deletions nntrainer/tensor/manager.cpp
@@ -434,7 +434,7 @@ std::vector<Weight *> Manager::requestWeights(
*/
grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
dim_g, grad_exec_order, grad_ls,
Tensor::Initializer::ZEROS);
Initializer::ZEROS);

if (var->getDataType() != ml::train::TensorDim::DataType::FP32) {
TensorDim var32_dim(dim_v);
@@ -444,7 +444,7 @@

var32 = weight_pool.requestOrExtend(shared_name + ":var32", var32_dim,
var32_exec_order, var_ls,
Tensor::Initializer::ZEROS);
Initializer::ZEROS);
}
}
} else {
@@ -461,16 +461,16 @@
// if (Weight::isGradientClipByGlobalNorm(clip_by_global_norm))
// is_wgrad = false;
grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g,
grad_exec_order, grad_ls,
Tensor::Initializer::ZEROS, is_wgrad);
grad_exec_order, grad_ls, Initializer::ZEROS,
is_wgrad);
if (var->getDataType() != ml::train::TensorDim::DataType::FP32) {
TensorDim var32_dim(dim_v);
var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);
std::vector<unsigned int> var32_exec_order;
var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
var32 =
weight_pool.request(name + ":var32", var32_dim, var32_exec_order,
var_ls, Tensor::Initializer::ZEROS);
var_ls, Initializer::ZEROS);
}
}
}
@@ -690,7 +690,7 @@ bool Manager::isSecondLastAccess(const std::string &name,
std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
const std::vector<TensorDim> &dims, const std::string &name,
const std::string &suffix, const TensorLifespan &lifespan, bool is_grad_clip,
bool is_mixed_precision, Tensor::Initializer initializer) {
bool is_mixed_precision, Initializer initializer) {

std::vector<Tensor *> ret;
ret.reserve(dims.size());
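For weights whose data type is not FP32, `requestWeights()` also requests a zero-initialized `:var32` tensor. This is the usual mixed-precision "master copy" arrangement: the optimizer updates the FP32 master and the low-precision working copy is refreshed from it. A conceptual sketch with illustrative types (not nntrainer internals):

```cpp
#include <cstdint>
#include <vector>

// Sketch: per-weight storage in a mixed-precision setup.
struct MixedPrecisionWeightSketch {
  std::vector<uint16_t> var;  // FP16 working copy (raw bit pattern here)
  std::vector<float> var32;   // FP32 master copy owned by the optimizer
};
```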
13 changes: 8 additions & 5 deletions nntrainer/tensor/manager.h
@@ -141,18 +141,19 @@ class Manager {
/**
* @brief Constructor of Manager
*/
Manager(bool enable_swap, const std::string &swap_path = "",
Manager(bool enable_swap_, const std::string &swap_path = "",
unsigned int lookahead = 0, const std::string tensor_format_ = "NCHW",
const std::string tensor_dtype_ = "FP32-FP32",
ExecutionMode exec_mode_ = ExecutionMode::TRAIN) :
weight_pool(enable_swap, swap_path, "weight_pool"),
tensor_pool(enable_swap && (exec_mode_ == ExecutionMode::TRAIN), swap_path,
weight_pool(enable_swap_, swap_path, "weight_pool"),
tensor_pool(enable_swap_ && (exec_mode_ == ExecutionMode::TRAIN), swap_path,
"tensor_pool"),
enable_optimizations(true),
swap_lookahead(lookahead),
tensor_format(tensor_format_),
tensor_dtype(split(tensor_dtype_, getRegex("\\-"))),
exec_mode(exec_mode_) {}
exec_mode(exec_mode_),
enable_swap(enable_swap_) {}

/**
* @brief Construct a new Manager object (deleted)
@@ -228,7 +229,7 @@
const std::vector<TensorDim> &dims, const std::string &name,
const std::string &suffix, const TensorLifespan &lifespan,
bool is_grad_clip, bool is_mixed_type,
Tensor::Initializer initializer = Tensor::Initializer::NONE);
Initializer initializer = Initializer::NONE);

/**
* @brief Create tensors with the given spec
@@ -537,6 +538,8 @@

ExecutionMode exec_mode;

bool enable_swap;

/**
* @brief Finalize the given tensor pool
*
5 changes: 5 additions & 0 deletions nntrainer/tensor/short_tensor.h
@@ -231,6 +231,11 @@ class ShortTensor : public TensorBase {
* @return std::string of tensor data type (UINT16)
*/
std::string getStringDataType() const override { return "UINT16"; }

/**
* @copydoc Tensor::isValid()
*/
bool isValid() const override { return true; }; // NYI
};

} // namespace nntrainer
7 changes: 4 additions & 3 deletions nntrainer/tensor/tensor.cpp
@@ -944,10 +944,11 @@ Tensor Tensor::clone() const {
Tensor Tensor::clone(ml::train::TensorDim::DataType type) const {
if (getDataType() == type)
return clone();

Tensor output(getName(), getFormat(), type);
TensorDim dim = getDim();
dim.setDataType(type);
Tensor output(dim, true);
output.copyData(*this);
output.name = name;
output.setName(getName());
return output;
}

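`Tensor::clone(type)` now builds the output from a `TensorDim` whose data type is overridden, so the clone keeps the source dimensions, and the name is set via `setName()`. A usage sketch, e.g. to obtain an FP32 copy of an FP16 tensor:

```cpp
#include <tensor.h> // nntrainer header; include path is an assumption

// Sketch: same shape and contents, only the data type changes.
nntrainer::Tensor to_fp32_sketch(const nntrainer::Tensor &t) {
  return t.clone(ml::train::TensorDim::DataType::FP32);
}
```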
2 changes: 1 addition & 1 deletion nntrainer/tensor/tensor_base.h
@@ -655,7 +655,7 @@ class TensorBase {
/**
* @copydoc Tensor::isValid()
*/
bool isValid() const { return true; };
virtual bool isValid() const = 0;

static constexpr float epsilon = 1e-5;

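Making `TensorBase::isValid()` pure virtual forces every concrete tensor type to supply its own check: `FloatTensor` and `HalfTensor` scan for non-finite values, while `CharTensor` and `ShortTensor` currently return `true` as NYI stubs. The override pattern in simplified form (class names are placeholders, not the actual headers):

```cpp
// Sketch of the enforced override pattern.
struct TensorBaseSketch {
  virtual ~TensorBaseSketch() = default;
  virtual bool isValid() const = 0; // each tensor type must decide
};

struct FloatTensorSketch : TensorBaseSketch {
  bool isValid() const override { return /* scan the FP32 buffer */ true; }
};

struct CharTensorSketch : TensorBaseSketch {
  bool isValid() const override { return true; } // NYI stub, as in the diff
};
```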
6 changes: 3 additions & 3 deletions nntrainer/tensor/tensor_wrap_specs.h
@@ -75,9 +75,9 @@ enum class TensorLifespan {
* regularizer_constant, decay, clip gradient constant, need_gradient property,
* name, output axis of the tensor object and loss Scale Factor, is_mixed.
*/
typedef std::tuple<TensorDim, TensorDim, Tensor::Initializer, WeightRegularizer,
float, float, float, bool, const std::string, unsigned int,
float, bool>
typedef std::tuple<TensorDim, TensorDim, Initializer, WeightRegularizer, float,
float, float, bool, const std::string, unsigned int, float,
bool>
WeightSpec;

/**
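The `WeightSpec` tuple now carries the standalone `Initializer` as its third element. A sketch of unpacking a spec with `std::get`, following the field order documented in the comment above (the include path is an assumption):

```cpp
#include <tensor_wrap_specs.h> // nntrainer header; include path is an assumption
#include <string>
#include <tuple>

// Sketch: read a few fields of a WeightSpec by position.
void inspect_weight_spec_sketch(const nntrainer::WeightSpec &spec) {
  const nntrainer::TensorDim &dim = std::get<0>(spec); // variable dimension
  nntrainer::Initializer init = std::get<2>(spec);     // initializer
  const std::string &name = std::get<8>(spec);         // weight name
  (void)dim;
  (void)init;
  (void)name;
}
```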
