Skip to content

Commit

Permalink
Move to centralized cuDNN handle
Browse files Browse the repository at this point in the history
  • Loading branch information
slayton58 committed Jul 23, 2015
1 parent 438add5 commit 38d6baf
Show file tree
Hide file tree
Showing 21 changed files with 68 additions and 78 deletions.
6 changes: 6 additions & 0 deletions include/caffe/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ class Caffe {
inline static curandGenerator_t curand_generator() {
return Get().curand_generator_;
}
#ifdef USE_CUDNN
inline static cudnnHandle_t cudnn_handle() { return Get().cudnn_handle_; }
#endif
#endif

// Returns the mode: running on CPU or GPU.
Expand Down Expand Up @@ -164,6 +167,9 @@ class Caffe {
#ifndef CPU_ONLY
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
#ifdef USE_CUDNN
cudnnHandle_t cudnn_handle_;
#endif
#endif
shared_ptr<RNG> random_generator_;

Expand Down
1 change: 0 additions & 1 deletion include/caffe/common_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
Expand Down
3 changes: 0 additions & 3 deletions include/caffe/neuron_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,6 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
Expand Down Expand Up @@ -514,7 +513,6 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
Expand Down Expand Up @@ -599,7 +597,6 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
Expand Down
5 changes: 0 additions & 5 deletions include/caffe/vision_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t* handle_;
cudaStream_t* stream_;

// algorithms for forward and backwards convolutions
cudnnConvolutionFwdAlgo_t *fwd_algo_;
Expand Down Expand Up @@ -404,7 +402,6 @@ class CuDNNLRNLayer : public LRNLayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t handle_;
cudnnLRNDescriptor_t norm_desc_;
cudnnTensorDescriptor_t bottom_desc_, top_desc_;

Expand All @@ -431,7 +428,6 @@ class CuDNNLCNLayer : public LRNLayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t handle_;
cudnnLRNDescriptor_t norm_desc_;
cudnnTensorDescriptor_t bottom_desc_, top_desc_;

Expand Down Expand Up @@ -516,7 +512,6 @@ class CuDNNPoolingLayer : public PoolingLayer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_, top_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
cudnnPoolingMode_t mode_;
Expand Down
20 changes: 18 additions & 2 deletions src/caffe/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ void* Caffe::RNG::generator() {
#else // Normal GPU + CPU Caffe.

Caffe::Caffe()
: cublas_handle_(NULL), curand_generator_(NULL), random_generator_(),
: cublas_handle_(NULL), curand_generator_(NULL),
#ifdef USE_CUDNN
cudnn_handle_(NULL),
#endif
random_generator_(),
mode_(Caffe::CPU), solver_count_(1), root_solver_(true) {
// Try to create a cublas handler, and report an error if failed (but we will
// keep the program running as one might just want to run CPU code).
Expand All @@ -112,13 +116,21 @@ Caffe::Caffe()
!= CURAND_STATUS_SUCCESS) {
LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
}
#ifdef USE_CUDNN
if (cudnnCreate(&cudnn_handle_) != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "Cannot create cuDNN handle. cuDNN won't be available.";
}
#endif
}

Caffe::~Caffe() {
if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
if (curand_generator_) {
CURAND_CHECK(curandDestroyGenerator(curand_generator_));
}
#ifdef USE_CUDNN
if (cudnn_handle_) CUDNN_CHECK(cudnnDestroy(cudnn_handle_));
#endif
}

void Caffe::set_random_seed(const unsigned int seed) {
Expand Down Expand Up @@ -157,6 +169,10 @@ void Caffe::SetDevice(const int device_id) {
CURAND_RNG_PSEUDO_DEFAULT));
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(Get().curand_generator_,
cluster_seedgen()));
#ifdef USE_CUDNN
if (Get().cublas_handle_) CUDNN_CHECK(cudnnDestroy(Get().cudnn_handle_));
CUDNN_CHECK(cudnnCreate(&Get().cudnn_handle_));
#endif
}

void Caffe::DeviceQuery() {
Expand Down Expand Up @@ -300,7 +316,7 @@ void MemoryHandler::Init() {
size_t free_mem, used_mem;
CUDA_CHECK(cudaMemGetInfo(&free_mem, &used_mem));

devs[i].size = size_t(0.8*free_mem);
devs[i].size = size_t(0.95*free_mem);
devs[i].numStreams = 0;
devs[i].streams = NULL;
}
Expand Down
32 changes: 11 additions & 21 deletions src/caffe/layers/cudnn_conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace caffe {
// Set to three for the benefit of the backward pass, which
// can use separate streams for calculating the gradient w.r.t.
// bias, filter weights, and bottom data for each group independently
#define CUDNN_STREAMS_PER_GROUP 3
#define CUDNN_STREAMS_PER_GROUP 1

cudnnConvolutionFwdAlgo_t
GetCuDNNFwdAlgo(ConvolutionParameter_CuDNNFwdAlgorithm algo) {
Expand Down Expand Up @@ -68,9 +68,6 @@ template <typename Dtype>
void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
ConvolutionLayer<Dtype>::LayerSetUp(bottom, top);
// Initialize CUDA streams and cuDNN.
stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
// Initialize algorithm arrays
fwd_algo_ = new cudnnConvolutionFwdAlgo_t[bottom.size()];
bwd_filter_algo_= new cudnnConvolutionBwdFilterAlgo_t[bottom.size()];
Expand All @@ -94,11 +91,7 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
}

for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
CUDA_CHECK(cudaStreamCreate(&stream_[g]));
CUDNN_CHECK(cudnnCreate(&handle_[g]));
CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g]));
workspace[g] = NULL;
MemoryHandler::registerStream(stream_[g]);
}

// Set the indexing parameters.
Expand Down Expand Up @@ -172,7 +165,7 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(

// choose forward and backward algorithms + workspace(s)
if (!this->layer_param_.convolution_param().has_cudnnfwdalgo()) {
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[0],
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(Caffe::cudnn_handle(),
bottom_descs_[i],
filter_desc_,
conv_descs_[i],
Expand All @@ -185,7 +178,7 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
this->layer_param_.convolution_param().cudnnfwdalgo());
}

CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[0],
CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(Caffe::cudnn_handle(),
bottom_descs_[i],
filter_desc_,
conv_descs_[i],
Expand All @@ -203,7 +196,8 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
//
// choose backward algorithm for filter
if (!this->layer_param_.convolution_param().has_cudnnbwdfilteralgo()) {
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle_[0],
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
Caffe::cudnn_handle(),
bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_limit_bytes, &bwd_filter_algo_[i]) );
Expand All @@ -212,13 +206,15 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
this->layer_param_.convolution_param().cudnnbwdfilteralgo());
}
// get workspace for backwards filter algorithm
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_[0],
CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
Caffe::cudnn_handle(),
bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
bwd_filter_algo_[i], &workspace_bwd_filter_sizes_[i]));

// choose backward algo for data
if (!this->layer_param_.convolution_param().has_cudnnbwddataalgo()) {
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle_[0],
CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
Caffe::cudnn_handle(),
filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_limit_bytes, &bwd_data_algo_[i]));
Expand All @@ -228,7 +224,8 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
}

// get workspace size
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle_[0],
CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
Caffe::cudnn_handle(),
filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
bwd_data_algo_[i], &workspace_bwd_data_sizes_[i]) );
}
Expand Down Expand Up @@ -294,14 +291,7 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
}
cudnnDestroyFilterDescriptor(filter_desc_);

for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
cudaStreamDestroy(stream_[g]);
cudnnDestroy(handle_[g]);
}

cudaFree(workspaceData);
delete [] stream_;
delete [] handle_;
delete [] fwd_algo_;
delete [] bwd_filter_algo_;
delete [] bwd_data_algo_;
Expand Down
29 changes: 13 additions & 16 deletions src/caffe/layers/cudnn_conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
#ifdef USE_CNMEM
MemoryHandler::mallocGPU(&workspace[0], workspace_fwd_sizes_[i],
stream_[0]);
MemoryHandler::mallocGPU(&workspace[0], workspace_fwd_sizes_[i]);
#endif
// Filters.
// CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
CUDNN_CHECK(cudnnConvolutionForward(handle_[0],
CUDNN_CHECK(cudnnConvolutionForward(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
bottom_descs_[i], bottom_data + bottom_offset_ * g,
filter_desc_, weight + weight_offset_ * g,
Expand All @@ -37,13 +36,13 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
top_descs_[i], top_data + top_offset_ * g));

#ifdef USE_CNMEM
MemoryHandler::freeGPU(workspace[0], stream_[0]);
MemoryHandler::freeGPU(workspace[0]);
workspace[0] = NULL;
#endif
// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
CUDNN_CHECK(cudnnAddTensor(handle_[0], CUDNN_ADD_SAME_C,
CUDNN_CHECK(cudnnAddTensor(Caffe::cudnn_handle(), CUDNN_ADD_SAME_C,
cudnn::dataType<Dtype>::one,
bias_desc_, bias_data + bias_offset_ * g,
cudnn::dataType<Dtype>::one,
Expand All @@ -54,7 +53,7 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
sync_conv_groups<<<1, 1>>>();
CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy));
}
}

Expand All @@ -79,7 +78,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
for (int g = 0; g < this->group_; g++) {
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0],
CUDNN_CHECK(cudnnConvolutionBackwardBias(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
top_descs_[i], top_diff + top_offset_ * g,
cudnn::dataType<Dtype>::one,
Expand All @@ -89,12 +88,11 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
#ifdef USE_CNMEM
MemoryHandler::mallocGPU(&workspace[0], workspace_bwd_filter_sizes_[i],
stream_[0]);
MemoryHandler::mallocGPU(&workspace[0], workspace_bwd_filter_sizes_[i]);
#endif
const Dtype* bottom_data = bottom[i]->gpu_data();
CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
handle_[0],
Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
bottom_descs_[i], bottom_data + bottom_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
Expand All @@ -103,7 +101,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
cudnn::dataType<Dtype>::one,
filter_desc_, weight_diff + weight_offset_ * g));
#ifdef USE_CNMEM
MemoryHandler::freeGPU(workspace[0], stream_[0]);
MemoryHandler::freeGPU(workspace[0]);
workspace[0] = NULL;
#endif
}
Expand All @@ -115,11 +113,10 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
#ifdef USE_CNMEM
MemoryHandler::mallocGPU(&workspace[0], workspace_bwd_data_sizes_[i],
stream_[0]);
MemoryHandler::mallocGPU(&workspace[0], workspace_bwd_data_sizes_[i]);
#endif
CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
handle_[0],
Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
filter_desc_, weight + weight_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
Expand All @@ -128,7 +125,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
cudnn::dataType<Dtype>::zero,
bottom_descs_[i], bottom_diff + bottom_offset_ * g));
#ifdef USE_CNMEM
MemoryHandler::freeGPU(workspace[0], stream_[0]);
MemoryHandler::freeGPU(workspace[0]);
workspace[0] = NULL;
#endif
}
Expand All @@ -137,7 +134,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
sync_conv_groups<<<1, 1>>>();
CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy));
}
}

Expand Down
7 changes: 3 additions & 4 deletions src/caffe/layers/cudnn_lcn_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ void CuDNNLCNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
LRNLayer<Dtype>::LayerSetUp(bottom, top);

CUDNN_CHECK(cudnnCreate(&handle_));
CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
Expand Down Expand Up @@ -66,11 +65,11 @@ CuDNNLCNLayer<Dtype>::~CuDNNLCNLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }

cudnnDestroyTensorDescriptor(bottom_desc_);
cudnnDestroyTensorDescriptor(top_desc_);
CUDNN_CHECK(cudnnDestroyTensorDescriptor(bottom_desc_));
CUDNN_CHECK(cudnnDestroyTensorDescriptor(top_desc_));

// destroy LRN handle
cudnnDestroy(handle_);
CUDNN_CHECK(cudnnDestroyLRNDescriptor(norm_desc_));

// free temp buffers
if (tempData1 != NULL) cudaFree(tempData1);
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/layers/cudnn_lcn_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ void CuDNNLCNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
#endif

CUDNN_CHECK(cudnnDivisiveNormalizationForward(
handle_, norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS,
Caffe::cudnn_handle(), norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS,
cudnn::dataType<Dtype>::one,
bottom_desc_, bottom_data,
NULL, // srcMeansData
Expand Down Expand Up @@ -51,7 +51,7 @@ void CuDNNLCNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
#endif

CUDNN_CHECK(cudnnDivisiveNormalizationBackward(
handle_, norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS,
Caffe::cudnn_handle(), norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS,
cudnn::dataType<Dtype>::one,
bottom_desc_, bottom_data,
NULL, top_diff, // NULL - srcMeansData
Expand Down
Loading

0 comments on commit 38d6baf

Please sign in to comment.