Skip to content

Commit

Permalink
Prevent leaking pool_mr handle
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Jul 9, 2020
1 parent 6abd4c0 commit 2bdbc23
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
14 changes: 11 additions & 3 deletions tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,16 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
return gbm;
}

#ifndef XGBOOST_USE_CUDA
void SetUpRMMResource() {}
#endif // XGBOOST_USE_CUDA
#if !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1
class RMMAllocator {};

void DeleteRMMResource(RMMAllocator* r) {
delete r;
}

RMMAllocatorPtr SetUpRMMResource() {
return RMMAllocatorPtr(nullptr, DeleteRMMResource);
}
#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1

} // namespace xgboost
20 changes: 15 additions & 5 deletions tests/cpp/helpers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,23 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix(bool with_la
return m;
}

void SetUpRMMResource() {
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
using cuda_mr_t = rmm::mr::cuda_memory_resource;
using pool_mr_t = rmm::mr::pool_memory_resource<cuda_mr_t>;
using cuda_mr_t = rmm::mr::cuda_memory_resource;
using pool_mr_t = rmm::mr::pool_memory_resource<cuda_mr_t>;
class RMMAllocator {
public:
std::unique_ptr<pool_mr_t> handle;
};

void DeleteRMMResource(RMMAllocator* r) {
delete r;
}

RMMAllocatorPtr SetUpRMMResource() {
auto cuda_mr = std::make_unique<cuda_mr_t>();
auto pool_mr = std::make_unique<pool_mr_t>(cuda_mr.release());
rmm::mr::set_default_resource(pool_mr.release());
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
rmm::mr::set_default_resource(pool_mr.get());
return RMMAllocatorPtr(new RMMAllocator{std::move(pool_mr)}, DeleteRMMResource);
}
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
} // namespace xgboost
5 changes: 4 additions & 1 deletion tests/cpp/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <fstream>
#include <cstdio>
#include <string>
#include <memory>
#include <vector>
#include <sys/stat.h>
#include <sys/types.h>
Expand Down Expand Up @@ -352,7 +353,9 @@ inline int Next(DataIterHandle self) {
return static_cast<CudaArrayIterForTest*>(self)->Next();
}

void SetUpRMMResource();
class RMMAllocator;
using RMMAllocatorPtr = std::unique_ptr<RMMAllocator, void(*)(RMMAllocator*)>;
RMMAllocatorPtr SetUpRMMResource();

} // namespace xgboost
#endif
3 changes: 2 additions & 1 deletion tests/cpp/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <string>
#include <memory>
#include <vector>

#include "helpers.h"

int main(int argc, char ** argv) {
xgboost::SetUpRMMResource();
auto rmm_alloc = xgboost::SetUpRMMResource();
xgboost::Args args {{"verbosity", "2"}};
xgboost::ConsoleLogger::Configure(args);

Expand Down

0 comments on commit 2bdbc23

Please sign in to comment.