forked from wang-xinyu/tensorrtx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ibnnet.h
45 lines (36 loc) · 1.09 KB
/
ibnnet.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#pragma once
#include "utils.h"
#include "holder.h"
#include "layers.h"
#include "InferenceEngine.h"
#include <chrono>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
extern Logger gLogger;
using namespace trtxapi;
namespace trt {

/// Backbone variant selected when building the IBNNet engine.
/// (Kept as an unscoped enum so existing uses such as `trt::A` keep compiling.)
enum IBN {
    A,    // resnet50-ibna
    B,    // resnet50-ibnb
    NONE  // resnet50
};

/// Wrapper around a TensorRT engine for the ResNet50 / IBN-Net variants in `IBN`.
///
/// Workflow exposed by this interface:
///   - serializeEngine(): create the engine (via createEngine) and serialize it;
///   - deserializeEngine(): load a serialized engine;
///   - inference() / getOutput(): run a batch and read the result.
///
/// Copy/move semantics: the engine is owned through a std::unique_ptr, so the
/// class is non-copyable. Because the destructor is user-declared, the implicit
/// move operations are not generated either, so instances are also non-movable.
class IBNNet {
public:
    /// @param enginecfg engine configuration; a copy is kept in `_engineCfg`
    ///                  (NOTE(review): constructor body is out of view — confirm).
    /// @param ibn       backbone variant to build (see `IBN`).
    IBNNet(trt::EngineConfig &enginecfg, const IBN ibn);

    ~IBNNet() = default;

    // Spelled out explicitly: copying is already impossible because of the
    // unique_ptr member; declaring it makes that intent visible.
    IBNNet(const IBNNet &) = delete;
    IBNNet &operator=(const IBNNet &) = delete;

    /// Create the engine and serialize it.
    /// @return presumably true on success — TODO confirm against the .cpp.
    bool serializeEngine();

    /// Load a previously serialized engine.
    /// @return presumably true on success — TODO confirm against the .cpp.
    bool deserializeEngine();

    /// Run inference on a batch of images (supports batch inference).
    /// @param input one cv::Mat per batch entry. It is taken by non-const
    ///              reference; whether it is modified is not visible here.
    ///              NOTE(review): presumably input.size() must not exceed the
    ///              configured batch size — confirm.
    /// @return presumably true on success — TODO confirm against the .cpp.
    bool inference(std::vector<cv::Mat> &input);

    /// Pointer to the output buffer.
    /// NOTE(review): ownership and lifetime are not visible here; presumably
    /// owned by the inference engine and valid until the next inference() call.
    float *getOutput();

    /// CUDA device id used by this instance.
    int getDeviceID();

private:
    /// Build the network and engine with the given TensorRT builder/config.
    /// NOTE(review): returns a raw ICudaEngine*; whether the caller takes
    /// ownership is not visible here.
    ICudaEngine *createEngine(IBuilder *builder, IBuilderConfig *config);

    /// Convert `img` into float input data written at `data`.
    /// @param stride NOTE(review): presumably the element offset between
    ///               planes/images in `data` — confirm in the .cpp.
    void preprocessing(const cv::Mat &img, float *const data, const std::size_t stride);

private:
    trt::EngineConfig _engineCfg;                                   // configuration copy
    std::unique_ptr<trt::InferenceEngine> _inferEngine{nullptr};    // owned engine
    std::string _ibn;                                               // presumably textual form of the IBN variant
    DataType _dt{DataType::kFLOAT};                                 // network data type, defaults to FP32
};

}  // namespace trt