Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle-Inference]: fix concat slice #39096

Merged
merged 7 commits on
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions paddle/fluid/inference/tensorrt/convert/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@ class ConcatOpConverter : public OpConverter {
itensors.push_back(engine_->GetITensor(input_name));
}
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));

if (axis == -1) {
axis = (engine_->GetITensor(op_desc.Input("X").front())->getDimensions())
.nbDims -
1;
} else {
if (!engine_->with_dynamic_shape()) {
axis = axis - 1; // Remove batch dim
}
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(),
itensors.size());
if (!engine_->with_dynamic_shape()) {
axis = axis - 1; // Remove batch dim
}
layer->setAxis(axis);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "concat", {output_name}, test_mode);
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/slice_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class SliceOpConverter : public OpConverter {

nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss() && engine_->with_ernie()) {
if (engine_->use_oss() && engine_->with_ernie() &&
input_dims.nbDims == 4) {
std::vector<nvinfer1::ITensor*> plugin_inputs;
if (engine_->with_interleaved()) {
auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
Expand All @@ -81,7 +82,7 @@ class SliceOpConverter : public OpConverter {
engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
out_scale);
shuffler_slice->setName(
("SpecialSlice_interleaved: Shuffle: (Output: " + output_name +
("SpecialSlice_interleaved: transpose: (Output: " + output_name +
")")
.c_str());
plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
int axis = BOOST_GET_CONST(int, desc.GetAttr("axis"));
if (with_dynamic_shape) {
if (axis < 0) return false;
} else {
if (axis <= 0) return false;
if (!with_dynamic_shape) {
if (axis == 0) return false;
}
auto concat_inputs = desc.Inputs();
if (concat_inputs.find("AxisTensor") != concat_inputs.end()) {
Expand Down
20 changes: 11 additions & 9 deletions paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,32 +113,34 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
const int32_t* cu_seqlens, T* output) {
const int hidden = blockDim.x;
const int hidden = blockDim.x * gridDim.y;
const int batch = blockIdx.x;
const int local_idx = blockIdx.y * blockDim.y + threadIdx.x;

output[batch * hidden + threadIdx.x] =
slice_input[cu_seqlens[batch] * hidden + threadIdx.x];
output[batch * hidden + local_idx] =
slice_input[cu_seqlens[batch] * hidden + local_idx];
}

int SpecialSlicePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims; // (sum(S), 768, 1, 1)
auto out_dims = output_desc[0].dims; // (batch, 768, 1, 1)
auto input_dims = input_desc[0].dims; // (sum(S), hidden, 1, 1)
auto out_dims = output_desc[0].dims; // (batch, hidden, 1, 1)

assert(input_desc[0].type == nvinfer1::DataType::kHALF);
assert(hidden % 128 == 0);

const int32_t hidden = input_dims.d[1];
const int num_blocks = out_dims.d[0]; // batch size
const int num_threads = hidden;
constexpr int num_threads = 128;
const dim3 blocks(out_dims.d[0], hidden / num_threads);

const half* slice_input = static_cast<const half*>(inputs[0]);
const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
half* output = static_cast<half*>(outputs[0]);

SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
slice_input, cu_seqlens, output);
SpecialSliceKernel<<<blocks, num_threads, 0, stream>>>(slice_input,
cu_seqlens, output);

return cudaGetLastError() != cudaSuccess;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def generate_input3(attrs: List[Dict[str, Any]], batch):
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.zeros([1]).astype(np.int32)

for dims in [1, 2, 3, 4]:
for dims in [2, 3, 4]:
for num_input in [0, 1]:
for batch in [1, 2, 4]:
for axis in [-1, 0, 1, 2, 3]:
Expand Down Expand Up @@ -277,12 +277,9 @@ def clear_dynamic_shape():

def generate_trt_nodes_num(attrs, dynamic_shape):
if dynamic_shape == True:
if attrs[0]['axis'] >= 0:
return 1, 4
else:
return 0, 5
return 1, 4
else:
if attrs[0]['axis'] > 0:
if attrs[0]['axis'] != 0:
return 1, 4
else:
return 0, 5
Expand Down