Skip to content

Commit

Permalink
Merge branch 'softmax-fix-junzhang' into 'main'
Browse files — browse the repository at this point in the history
Fix softmax tensor declaration in ctor

See merge request dl/hugectr/hugectr!1513
  • Loading branch information
minseokl committed Dec 15, 2023
2 parents 94dbb7e + bec7a6b commit 5679966
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions HugeCTR/src/layers/softmax_layer.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <linalg/binary_op.cuh>
#include <linalg/reduce.cuh>
#include <linalg/unary_op.cuh>
#include <network_buffer_channels.hpp>
#include <utils.hpp>
namespace HugeCTR {

Expand All @@ -36,14 +37,15 @@ SoftmaxLayer<T>::SoftmaxLayer(const core23::Tensor& input_tensor,
dims_ = input_tensor.shape().dims();
hidden_size_ = input_tensor.shape().size(dims_ - 1);
n_rows_ = len_ / hidden_size_;
workspace23_ =
core23::Tensor({(int64_t)n_rows_}, core23::DataType(core23::ToScalarType<T>::value));
identity23_ =
core23::Tensor({(int64_t)hidden_size_}, core23::DataType(core23::ToScalarType<T>::value));
softmax_out23_ =
core23::Tensor(input_tensor.shape(), core23::DataType(core23::ToScalarType<T>::value));
core23::BufferParams buf_p{.channel = GetBlobsBufferChannel()};
auto param = (input_tensor.my_params().buffer_params(buf_p));
workspace23_ = core23::Tensor(
param.shape({(int64_t)n_rows_}).data_type(core23::DataType(core23::ToScalarType<T>::value)));
identity23_ = core23::Tensor(param.shape({(int64_t)hidden_size_})
.data_type(core23::DataType(core23::ToScalarType<T>::value)));
softmax_out23_ = core23::Tensor(param.shape(input_tensor.shape())
.data_type(core23::DataType(core23::ToScalarType<T>::value)));
}

template <typename T>
void SoftmaxLayer<T>::initialize() {
CudaDeviceContext context(get_device_id());
Expand Down

0 comments on commit 5679966

Please sign in to comment.