rebase pull #25

Merged
5 commits merged on Jul 2, 2021
46 changes: 46 additions & 0 deletions paddle/fluid/framework/fleet/box_wrapper.cc
@@ -60,6 +60,32 @@ void BasicAucCalculator::add_unlock_data(double pred, int label) {
++_table[label][pos];
}

void BasicAucCalculator::add_unlock_data(double pred, int label,
float sample_scale) {
PADDLE_ENFORCE_GE(pred, 0.0, platform::errors::PreconditionNotMet(
"pred should be greater than or equal to 0"));
PADDLE_ENFORCE_LE(pred, 1.0, platform::errors::PreconditionNotMet(
"pred should be less than or equal to 1"));
PADDLE_ENFORCE_EQ(
label * label, label,
platform::errors::PreconditionNotMet(
"label must be equal to 0 or 1, but its value is: %d", label));
int pos = std::min(static_cast<int>(pred * _table_size), _table_size - 1);
PADDLE_ENFORCE_GE(
pos, 0,
platform::errors::PreconditionNotMet(
"pos must be equal or greater than 0, but its value is: %d", pos));
PADDLE_ENFORCE_LT(
pos, _table_size,
platform::errors::PreconditionNotMet(
"pos must be less than table_size, but its value is: %d", pos));
_local_abserr += fabs(pred - label);
_local_sqrerr += (pred - label) * (pred - label);

_local_pred += pred * sample_scale;
_table[label][pos] += sample_scale;
}

void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
int batch_size,
const paddle::platform::Place& place) {
@@ -81,6 +107,26 @@ void BasicAucCalculator::add_data(const float* d_pred, const int64_t* d_label,
}
}
}

void BasicAucCalculator::add_sample_data(
const float* d_pred, const int64_t* d_label,
const std::vector<float>& d_sample_scale, int batch_size,
const paddle::platform::Place& place) {
thread_local std::vector<float> h_pred;
thread_local std::vector<int64_t> h_label;
h_pred.resize(batch_size);
h_label.resize(batch_size);
cudaMemcpy(h_pred.data(), d_pred, sizeof(float) * batch_size,
cudaMemcpyDeviceToHost);
cudaMemcpy(h_label.data(), d_label, sizeof(int64_t) * batch_size,
cudaMemcpyDeviceToHost);

std::lock_guard<std::mutex> lock(_table_mutex);
for (int i = 0; i < batch_size; ++i) {
add_unlock_data(h_pred[i], h_label[i], d_sample_scale[i]);
}
}

// add mask data
void BasicAucCalculator::add_mask_data(const float* d_pred,
const int64_t* d_label,
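The new overload above scales the running prediction sum and the per-bucket label histogram by sample_scale, so the statistics derived from _table become sample-weighted. For orientation, the sketch below shows how an AUC can be read back out of such a bucket table; it is a minimal standalone example under the usual table layout (_table[0] holds negatives, _table[1] positives), not the Paddle implementation, and bucketed_auc is a hypothetical name.

#include <vector>

// Recover AUC from a bucketed score table. With the sample_scale overload the
// buckets hold weighted counts, and the same trapezoid pass then yields a
// sample-weighted AUC.
double bucketed_auc(const std::vector<double>& neg,   // counts per bucket, label 0
                    const std::vector<double>& pos) { // counts per bucket, label 1
  double fp = 0.0, tp = 0.0, area = 0.0;
  // Walk buckets from the highest prediction downwards, tracing the ROC curve
  // and accumulating trapezoid areas.
  for (int i = static_cast<int>(neg.size()) - 1; i >= 0; --i) {
    double new_fp = fp + neg[i];
    double new_tp = tp + pos[i];
    area += (new_fp - fp) * (tp + new_tp) / 2.0;
    fp = new_fp;
    tp = new_tp;
  }
  if (fp <= 0.0 || tp <= 0.0) return 0.0;  // degenerate: a class is missing
  return area / (fp * tp);
}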
33 changes: 27 additions & 6 deletions paddle/fluid/framework/fleet/box_wrapper.h
@@ -68,13 +68,18 @@ class BasicAucCalculator {
void reset();
// add single data in CPU with LOCK, deprecated
void add_unlock_data(double pred, int label);
void add_unlock_data(double pred, int label, float sample_scale);
// add batch data
void add_data(const float* d_pred, const int64_t* d_label, int batch_size,
const paddle::platform::Place& place);
// add mask data
void add_mask_data(const float* d_pred, const int64_t* d_label,
const int64_t* d_mask, int batch_size,
const paddle::platform::Place& place);
// add sample data
void add_sample_data(const float* d_pred, const int64_t* d_label,
const std::vector<float>& d_sample_scale, int batch_size,
const paddle::platform::Place& place);
void compute();
int table_size() const { return _table_size; }
double bucket_error() const { return _bucket_error; }
@@ -669,9 +674,11 @@ class BoxWrapper {
MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int metric_phase, int bucket_size = 1000000,
bool mode_collect_in_gpu = false, int max_batch_size = 0)
bool mode_collect_in_gpu = false, int max_batch_size = 0,
const std::string& sample_scale_varname = "")
: label_varname_(label_varname),
pred_varname_(pred_varname),
sample_scale_varname_(sample_scale_varname),
metric_phase_(metric_phase) {
calculator = new BasicAucCalculator(mode_collect_in_gpu);
calculator->init(bucket_size, max_batch_size);
@@ -692,7 +699,19 @@ class BoxWrapper {
platform::errors::PreconditionNotMet(
"the predict data length should be consistent with "
"the label data length"));
calculator->add_data(pred_data, label_data, label_len, place);
std::vector<float> sample_scale_data;
if (!sample_scale_varname_.empty()) {
get_data<float>(exe_scope, sample_scale_varname_, &sample_scale_data);
PADDLE_ENFORCE_EQ(
label_len, sample_scale_data.size(),
platform::errors::PreconditionNotMet(
"lable size [%lu] and sample_scale_data[%lu] should be same",
label_len, sample_scale_data.size()));
calculator->add_sample_data(pred_data, label_data, sample_scale_data,
label_len, place);
} else {
calculator->add_data(pred_data, label_data, label_len, place);
}
}
template <class T = float>
static void get_data(const Scope* exe_scope, const std::string& varname,
@@ -728,6 +747,7 @@ class BoxWrapper {
protected:
std::string label_varname_;
std::string pred_varname_;
std::string sample_scale_varname_;
int metric_phase_;
BasicAucCalculator* calculator;
};
@@ -1050,12 +1070,13 @@ class BoxWrapper {
const std::string& mask_varname, int metric_phase,
const std::string& cmatch_rank_group, bool ignore_rank,
int bucket_size = 1000000, bool mode_collect_in_gpu = false,
int max_batch_size = 0) {
int max_batch_size = 0,
const std::string& sample_scale_varname = "") {
if (method == "AucCalculator") {
metric_lists_.emplace(
name,
new MetricMsg(label_varname, pred_varname, metric_phase, bucket_size,
mode_collect_in_gpu, max_batch_size));
name, new MetricMsg(label_varname, pred_varname, metric_phase,
bucket_size, mode_collect_in_gpu, max_batch_size,
sample_scale_varname));
} else if (method == "MultiTaskAucCalculator") {
metric_lists_.emplace(
name, new MultiTaskMetricMsg(label_varname, pred_varname,
16 changes: 8 additions & 8 deletions paddle/fluid/operators/bilinear_tensor_product_op.h
@@ -62,8 +62,8 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
Tensor weight_mat =
weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
math::GetBlas<DeviceContext, T>(dev_ctx).GEMM(
CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>());
CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, static_cast<T>(1), x->data<T>(),
weight_mat.data<T>(), static_cast<T>(0), left_mul.data<T>());
output_col_vec.device(place) =
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
}
@@ -145,8 +145,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_x) *
y_mat;
blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, static_cast<T>(1),
y_scale.data<T>(), weight_i.data<T>(), static_cast<T>(1), d_x->data<T>());
}

if (d_y || d_weight) {
@@ -155,14 +155,14 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
.broadcast(bcast_for_y);
x_scale_mat.device(place) = output_vec_y * x_mat;
if (d_y) {
blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, static_cast<T>(1),
x_scale.data<T>(), weight_i.data<T>(), static_cast<T>(1), d_y->data<T>());
}
if (d_weight) {
Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
framework::make_ddim({x_dim, y_dim}));
blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, static_cast<T>(1),
x_scale.data<T>(), y->data<T>(), static_cast<T>(0), d_weight_i.data<T>());
}
}
}
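These call sites rely on the standard GEMM contract, C = alpha * op(A) * op(B) + beta * C: the forward pass overwrites its output with beta = 0, while the d_x and d_y gradients accumulate across weight slices with beta = 1, which is why the scalars are now spelled static_cast<T>(1) and static_cast<T>(0) in the element type instead of int literals. A minimal row-major reference of that contract, for orientation only (gemm_ref is a hypothetical name, not a Paddle function):

// C[M,N] = alpha * A[M,K] * B[K,N] + beta * C[M,N], row-major, no transposes.
void gemm_ref(int M, int N, int K, float alpha, const float* A, const float* B,
              float beta, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}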
4 changes: 4 additions & 0 deletions paddle/fluid/operators/math/blas.h
@@ -93,6 +93,10 @@ class Blas {
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
T alpha, const T* A, const T* B, T beta, T* C) const;

template <typename T>
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
float alpha, const T* A, const T* B, float beta, float* C, int flag) const;

template <typename T>
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
Expand Down
95 changes: 95 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
@@ -184,6 +184,50 @@ struct CUBlas<double> {
}
};

template <>
struct CUBlas<int8_t> {
  // Example call for the int8_t path:
  //   CUBlas<int8_t>::GEMM_EX(
  //       &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_8I, ldb,
  //       A, CUDA_R_8I, lda, &h_beta, C, CUDA_R_32F, N, CUDA_R_32F);

// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
cublasOperation_t transa, cublasOperation_t transb, int m,
int n, int k, const void *alpha, const void *A,
cudaDataType_t Atype, int lda, const void *B,
cudaDataType_t Btype, int ldb, const void *beta, void *C,
cudaDataType_t Ctype, int ldc,
cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
    bool use_tensor_op_math = dev_ctx->tensor_core_available();
    // Tensor-op algorithms (CUBLAS_GEMM_DFALT_TENSOR_OP) are deliberately not
    // selected here; the default algorithm is kept for int8 GEMM even when
    // tensor cores are available.
    VLOG(5) << "use_tensor_op_math: "
            << (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000

dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasGemmEx(
handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
beta, C, Ctype, ldc, computeType, algo));
});
#else
PADDLE_THROW(platform::errors::Unimplemented(
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}
};

template <>
struct CUBlas<platform::float16> {
using float16 = platform::float16;
@@ -466,6 +510,57 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 8000
}

// int8_t matmul
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
float alpha, const int8_t *A,
const int8_t *B, float beta,
float *C, int flag) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(), 53,
platform::errors::InvalidArgument(
"cublas int8_t gemm requires GPU compute capability >= 53,"
"but received %d",
context_.GetComputeCapability()));

  float h_alpha = alpha;
  float h_beta = beta;

#if CUDA_VERSION >= 8000
  // Use cublasGemmEx so the int8 inputs are multiplied with fp32 accumulation
  // and written to an fp32 output; this also lets cuBLAS pick an accelerated
  // algorithm where one is available.
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  VLOG(3) << "call int8_t GEMM_EX";
CUBlas<int8_t>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B, CUDA_R_8I, ldb, A,
CUDA_R_8I, lda, &h_beta, C, CUDA_R_32F, N, CUDA_R_32F);
#else
  // cublasGemmEx requires CUDA >= 8.0 and there is no int8 fallback on older
  // toolkits, so report the limitation instead of silently doing nothing.
  PADDLE_THROW(platform::errors::Unimplemented(
      "int8_t GEMM requires cublasGemmEx, which is unavailable on CUDA < 8.0"));
#endif // CUDA_VERSION >= 8000
}


template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
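For orientation, the sketch below spells out the raw cuBLAS call that CUBlas<int8_t>::GEMM_EX issues for a row-major C = A * B with int8 inputs and an fp32 output. It is a hedged standalone example rather than Paddle code (int8_gemm is a hypothetical name); error checking and the dimension/alignment restrictions cuBLAS places on int8 GEMM are omitted.

#include <cublas_v2.h>
#include <cstdint>

// Row-major C[M,N] = A[M,K] * B[K,N] expressed as a column-major cuBLAS call
// with the operands swapped (C^T = B^T * A^T), mirroring the wrapper above:
// int8 inputs (CUDA_R_8I), fp32 accumulation and output (CUDA_R_32F).
void int8_gemm(cublasHandle_t handle, int M, int N, int K, const int8_t* d_A,
               const int8_t* d_B, float* d_C) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  // cuBLAS sees B as its first operand (leading dimension N) and A as its
  // second (leading dimension K); the fp32 output C has leading dimension N.
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
               d_B, CUDA_R_8I, N, d_A, CUDA_R_8I, K,
               &beta, d_C, CUDA_R_32F, N,
               CUDA_R_32F, CUBLAS_GEMM_DFALT);
}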