diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index 135afb3f..8fc3105b 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -8,12 +8,12 @@ #include "cuda/fast_gelu.h" #include "cuda/mul_sigmoid.h" #include "cuda/negxplus1.h" +#include "cuda/replace_zero.h" #include "cuda/scatter_nd_of_shape.h" #include "cuda/transpose_cast.h" #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { - using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput; using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput; @@ -24,7 +24,6 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast; #endif - static OrtOpLoader op_loader( []() { return nullptr; } #ifdef USE_CUDA @@ -36,6 +35,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid), CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero), CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape), #if ORT_API_VERSION >= 16 @@ -47,6 +47,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid), CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero), CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape), CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type), CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type) diff --git a/operators/cuda/replace_zero.h b/operators/cuda/replace_zero.h new file mode 100644 index 00000000..e7974739 --- /dev/null +++ b/operators/cuda/replace_zero.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ocos.h" +#include "replace_zero_impl.cuh" +#include "ortx_common.h" + +namespace contrib { + +/** +* Y = ReplaceZero(X, by=c) is equivalent to: +* +* Y = X.copy() +* X[X == 0] = c +* +* This operation usually appears when a tensor is updated with an operator Equal and Where. +* This kernel avoids the creation of one null tensor. +*/ +template +struct ReplaceZero { + template + OrtxStatus OnModelAttach(const TDict& dict) { + float default_value=0; + by_ = dict.TryToGetAttributeWithDefault("by", default_value); + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& input, + ortc::Tensor& output) const { + const T* input_data = input.Data(); + auto input_shape = input.Shape(); + T* output_data = output.Allocate(input_shape); + auto input_length = input.NumberOfElement(); + if (0 == input_length) { + return {}; + } + + LaunchReplaceZeroKernel(reinterpret_cast(ctx->GetCudaStream()), + input_length, + input_data, + output_data, + by_); + return {}; + } + + private: + float by_; +}; + +} // namespace contrib \ No newline at end of file diff --git a/operators/cuda/replace_zero_impl.cu b/operators/cuda/replace_zero_impl.cu new file mode 100644 index 00000000..43952c30 --- /dev/null +++ b/operators/cuda/replace_zero_impl.cu @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "device_prop.cuh" +#include "utils.cuh" +#include "replace_zero_impl.cuh" +#include "cuda_type.h" + +#ifndef CUDA_LONG +#define CUDA_LONG int32_t +#endif + +using namespace Ort::Custom; + +template +__device__ __inline__ T _replace_zero(const T x, const T by) { + return x == (T)0 ? by : x; +} + +template <> +__device__ __inline__ half _replace_zero(const half x, const half by) { +#if __CUDA_ARCH__ < 700 + return __half2float(x) == 0 ? by : x; +#else + return x == (half)0 ? by : x; +#endif +} + +template +__global__ void ReplaceZeroKernel(T* output_data, const T* input_data, CUDA_LONG N, const T by) { + CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; + if (id >= N) + return; + output_data[id] = _replace_zero(input_data[id], by); +} + +template +T _cast(float value) { return (T)value; } + +template <> +half _cast(float value) { return __float2half(value); } + +template +cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) { + if (input_length == 0) + return cudaGetLastError(); + using TT = typename contrib::CudaT::MappedType; + + CUDA_LONG N = static_cast(input_length); + + const int num_threads_per_block = 256; + const int num_elements_per_thread = (N + num_threads_per_block - 1) / num_threads_per_block; + + TT cby = _cast(by); + ReplaceZeroKernel<<>>( + reinterpret_cast(output_data), reinterpret_cast(input_data), N, cby); + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const float* input_data, float* output_data, float by) { + return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by); +} + +template <> +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const ortc::MFloat16* input_data, ortc::MFloat16* output_data, float by) { + return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by); +} diff --git a/operators/cuda/replace_zero_impl.cuh b/operators/cuda/replace_zero_impl.cuh new file mode 100644 index 00000000..7d975d4d --- /dev/null +++ b/operators/cuda/replace_zero_impl.cuh @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +template +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by); diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index db40f612..43233a26 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -652,6 +652,66 @@ def test_transpose_cast_cuda(self): self._transpose_cast_cuda(TensorProto.FLOAT) self._transpose_cast_cuda(TensorProto.FLOAT16) + def _replace_zero_cuda(self, itype): + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + model1 = helper.make_model( + helper.make_graph( + [ + helper.make_node("Equal", ["X", "zero"], ["cond"]), + helper.make_node("Where", ["cond", "cst", "X"], ["Y"]), + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None, None])], + [helper.make_tensor_value_info("Y", itype, [None, None, None])], + [ + numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"), + numpy_helper.from_array(np.array([1.67], dtype=dtype), name="cst"), + ], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + "ReplaceZero", + ["X"], + ["Y"], + by=1.67, + domain="ai.onnx.contrib", + ) + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None, None])], + [helper.make_tensor_value_info("Y", itype, [None, None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype) + + feeds1 = dict(X=x) + ref = ReferenceEvaluator(model1) + expected = ref.run(None, feeds1)[0] + + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds1)[0] + assert_allclose(expected, got, atol=1e-5) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_replace_zero_cuda(self): + self._replace_zero_cuda(TensorProto.FLOAT) + self._replace_zero_cuda(TensorProto.FLOAT16) + if __name__ == "__main__": unittest.main(verbosity=2)