Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine io_convert and op_convert #10461

Merged
merged 1 commit into from
May 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/fluid/inference/tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
add_subdirectory(convert)
5 changes: 3 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_tensorrt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include <cuda.h>
#include "paddle/fluid/platform/enforce.h"

Expand Down Expand Up @@ -50,7 +50,7 @@ class DefaultInputConverter : public EngineInputConverter {
}
};

REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter);
REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);

} // namespace tensorrt
} // namespace inference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class EngineInputConverter {
static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type);
auto* converter = Registry<EngineInputConverter>::Lookup(
in_op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
Expand Down
38 changes: 14 additions & 24 deletions paddle/fluid/inference/tensorrt/convert/op_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"

namespace paddle {
namespace inference {
Expand All @@ -32,34 +33,23 @@ class OpConverter {
OpConverter() {}
virtual void operator()(const framework::OpDesc& op) {}

void Execute(const framework::OpDesc& op, TensorRTEngine* engine) {
void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
std::string type = op.Type();
auto it = converters_.find(type);
PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]",
type);
it->second->SetEngine(engine);
(*it->second)(op);
}

static OpConverter& Global() {
static auto* x = new OpConverter;
return *x;
}

template <typename T>
void Register(const std::string& key) {
converters_[key] = new T;
auto* it = Registry<OpConverter>::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine);
(*it)(op);
}

// convert fluid op to tensorrt layer
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Global().Execute(op, engine);
OpConverter::Run(op, engine);
}

// convert fluid block to tensorrt network
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
for (auto op : block.AllOps()) {
OpConverter::Global().Execute(*op, engine);
OpConverter::Run(*op, engine);
}
}

Expand All @@ -78,12 +68,12 @@ class OpConverter {
framework::Scope* scope_{nullptr};
};

#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter { \
trt_##op_type__##_converter() { \
OpConverter::Global().Register<Converter__>(#op_type__); \
} \
}; \
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter { \
trt_##op_type__##_converter() { \
Registry<OpConverter>::Register<Converter__>(#op_type__); \
} \
}; \
trt_##op_type__##_converter trt_##op_type__##_converter__;

} // namespace tensorrt
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {

void compare(float input, float expect) {
void Compare(float input, float expect) {
framework::Scope scope;
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
Expand Down Expand Up @@ -85,8 +85,8 @@ void compare(float input, float expect) {
}

TEST(OpConverter, ConvertRelu) {
compare(1, 1); // relu(1) = 1
compare(-5, 0); // relu(-5) = 0
Compare(1, 1); // relu(1) = 1
Compare(-5, 0); // relu(-5) = 0
}

} // namespace tensorrt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"

#include <gtest/gtest.h>

Expand All @@ -34,7 +34,7 @@ TEST_F(EngineInputConverterTester, DefaultCPU) {
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);

cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
&stream);
}

Expand All @@ -44,7 +44,7 @@ TEST_F(EngineInputConverterTester, DefaultGPU) {
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);

cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
&stream);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {

TEST(BlockConverter, ConvertBlock) {
TEST(OpConverter, ConvertBlock) {
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* mul_op = block->AppendOp();
Expand Down
11 changes: 9 additions & 2 deletions paddle/fluid/inference/utils/singleton.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <string>
#include <unordered_map>
#include "paddle/fluid/platform/enforce.h"

Expand Down Expand Up @@ -49,9 +50,15 @@ struct Registry {
items_[name] = new ItemChild;
}

static ItemParent* Lookup(const std::string& name) {
static ItemParent* Lookup(const std::string& name,
const std::string& default_name = "") {
auto it = items_.find(name);
if (it == items_.end()) return nullptr;
if (it == items_.end()) {
if (default_name == "")
return nullptr;
else
return items_.find(default_name)->second;
}
return it->second;
}

Expand Down