Fix C4459 warning in custom_op_lite.h #751

Merged: 5 commits, Jun 25, 2024
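This PR silences MSVC warning C4459 ("declaration of 'identifier' hides global declaration"), a level-4 warning that fires when a parameter or local variable shadows a namespace-scope name. In this header, constructor parameters named api were hiding a namespace-scope api declaration, so every constructor taking const OrtW::CustomOpApi& api tripped the warning. The patch renames the parameters (custom_op_api for OrtW::CustomOpApi, ort_api for OrtApi) and routes the constructor bodies through the stored api_ member. A minimal repro of the warning, using illustrative names rather than the real onnxruntime-extensions declarations:

    // cl /W4 repro.cc  ->  warning C4459 on the Storage constructor
    struct Api {
      void DoThing() const {}
    };

    Api api;  // namespace-scope ("global") declaration

    struct Storage {
      // C4459: the parameter name 'api' hides the global 'api' above
      explicit Storage(const Api& api) : api_(api) { api.DoThing(); }
      const Api& api_;
    };

    int main() {
      Storage s(api);
      s.api_.DoThing();
      return 0;
    }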
147 changes: 73 additions & 74 deletions include/custom_op/custom_op_lite.h
@@ -16,19 +16,19 @@ namespace Custom {
 
 class OrtKernelContextStorage : public ITensorStorage {
  public:
-  OrtKernelContextStorage(const OrtW::CustomOpApi& api,
+  OrtKernelContextStorage(const OrtW::CustomOpApi& custom_op_api,
                           OrtKernelContext& ctx,
                           size_t indice,
-                          bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
+                          bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
     if (is_input) {
-      auto input_count = api.KernelContext_GetInputCount(&ctx);
+      auto input_count = api_.KernelContext_GetInputCount(&ctx);
       if (indice >= input_count) {
         ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
       }
-      const_value_ = api.KernelContext_GetInput(&ctx, indice);
-      auto* info = api.GetTensorTypeAndShape(const_value_);
-      shape_ = api.GetTensorShape(info);
-      api.ReleaseTensorTypeAndShapeInfo(info);
+      const_value_ = api_.KernelContext_GetInput(&ctx, indice);
+      auto* info = api_.GetTensorTypeAndShape(const_value_);
+      shape_ = api_.GetTensorShape(info);
+      api_.ReleaseTensorTypeAndShapeInfo(info);
     }
   }

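The hunk above shows the pattern repeated through the rest of the file: the parameter is renamed so it no longer hides the global, and since the member initializer list binds api_ before the constructor body runs, the body can call through api_ and avoid the contested name entirely. A sketch of that shape with illustrative types (not the real header):

    struct Api {
      int Count() const { return 1; }
    };

    Api api;  // the global the old parameter name collided with

    struct Storage {
      explicit Storage(const Api& custom_op_api) : api_(custom_op_api) {
        // api_ is already bound here, so the body never mentions a name
        // that could shadow ::api.
        count_ = api_.Count();
      }
      const Api& api_;
      int count_ = 0;
    };

    int main() { return Storage(api).count_ == 1 ? 0 : 1; }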
@@ -66,18 +66,18 @@ class OrtKernelContextStorage : public ITensorStorage {
   std::optional<std::vector<int64_t>> shape_;
 };
 
-static std::string get_mem_type(const OrtW::CustomOpApi& api,
-                                OrtKernelContext& ctx,
-                                size_t indice,
-                                bool is_input){
+static std::string get_mem_type(const OrtW::CustomOpApi& custom_op_api,
+                                OrtKernelContext& ctx,
+                                size_t indice,
+                                bool is_input) {
   std::string output = "Cpu";
   if (is_input) {
-    const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice);
+    const OrtValue* const_value = custom_op_api.KernelContext_GetInput(&ctx, indice);
     const OrtMemoryInfo* mem_info = {};
-    api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
+    custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
     if (mem_info) {
       const char* mem_type = nullptr;
-      api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
+      custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
       if (mem_type) {
         output = mem_type;
       }
@@ -88,29 +88,29 @@ static std::string get_mem_type(const OrtW::CustomOpApi& api,
 
 template <typename T>
 class OrtTensor : public Tensor<T> {
-public:
-  OrtTensor(const OrtW::CustomOpApi& api,
-            OrtKernelContext& ctx,
-            size_t indice,
-            bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(api, ctx, indice, is_input)),
-                             mem_type_(get_mem_type(api, ctx, indice, is_input)) {
+ public:
+  OrtTensor(const OrtW::CustomOpApi& custom_op_api,
+            OrtKernelContext& ctx,
+            size_t indice,
+            bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(custom_op_api, ctx, indice, is_input)),
+                             mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
   }
 
   bool IsCpuTensor() const {
     return mem_type_ == "Cpu";
   }
 
-private:
+ private:
   std::string mem_type_ = "Cpu";
 };
 
 class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
  public:
   using strings = std::vector<std::string>;
-  OrtStringTensorStorage(const OrtW::CustomOpApi& api,
+  OrtStringTensorStorage(const OrtW::CustomOpApi& custom_op_api,
                          OrtKernelContext& ctx,
                          size_t indice,
-                         bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
+                         bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
     if (is_input) {
       auto input_count = api_.KernelContext_GetInputCount(&ctx_);
       if (indice >= input_count) {
@@ -197,10 +197,10 @@ class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
 class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> {
  public:
   using strings = std::vector<std::string_view>;
-  OrtStringViewTensorStorage(const OrtW::CustomOpApi& api,
+  OrtStringViewTensorStorage(const OrtW::CustomOpApi& custom_op_api,
                              OrtKernelContext& ctx,
                              size_t indice,
-                             bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
+                             bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
     if (is_input) {
       auto input_count = api_.KernelContext_GetInputCount(&ctx_);
       if (indice >= input_count) {
@@ -275,57 +275,56 @@ class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view>
 
 // to make the metaprogramming magic happy.
 template <>
-class OrtTensor<std::string> : public Tensor<std::string>{
-public:
-  OrtTensor(const OrtW::CustomOpApi& api,
-            OrtKernelContext& ctx,
-            size_t indice,
-            bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(api, ctx, indice, is_input)),
-                             mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
+class OrtTensor<std::string> : public Tensor<std::string> {
+ public:
+  OrtTensor(const OrtW::CustomOpApi& custom_op_api,
+            OrtKernelContext& ctx,
+            size_t indice,
+            bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(custom_op_api, ctx, indice, is_input)),
+                             mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}
 
   bool IsCpuTensor() const {
     return mem_type_ == "Cpu";
   }
 
-private:
+ private:
   std::string mem_type_ = "Cpu";
 };
 
 template <>
-class OrtTensor<std::string_view> : public Tensor<std::string_view>{
-public:
-  OrtTensor(const OrtW::CustomOpApi& api,
-            OrtKernelContext& ctx,
-            size_t indice,
-            bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(api, ctx, indice, is_input)),
-                             mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
+class OrtTensor<std::string_view> : public Tensor<std::string_view> {
+ public:
+  OrtTensor(const OrtW::CustomOpApi& custom_op_api,
+            OrtKernelContext& ctx,
+            size_t indice,
+            bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(custom_op_api, ctx, indice, is_input)),
+                             mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}
 
   bool IsCpuTensor() const {
     return mem_type_ == "Cpu";
   }
 
-private:
+ private:
   std::string mem_type_ = "Cpu";
 };
 
 using TensorPtr = std::unique_ptr<Custom::Arg>;
 using TensorPtrs = std::vector<TensorPtr>;
 
-
 using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
 using TensorBasePtrs = std::vector<TensorBasePtr>;
 
 // Represent variadic input or output
 struct Variadic : public Arg {
-  Variadic(const OrtW::CustomOpApi& api,
+  Variadic(const OrtW::CustomOpApi& custom_op_api,
            OrtKernelContext& ctx,
            size_t indice,
-           bool is_input) : api_(api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(api, ctx, indice, is_input)) {
+           bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
 #if ORT_API_VERSION < 14
     ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
 #endif
     if (is_input) {
-      auto input_count = api.KernelContext_GetInputCount(&ctx_);
+      auto input_count = api_.KernelContext_GetInputCount(&ctx_);
       for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
         auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
         auto* info = api_.GetTensorTypeAndShape(const_value);
@@ -334,40 +333,40 @@ struct Variadic : public Arg {
         TensorBasePtr tensor;
         switch (type) {
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
-            tensor = std::make_unique<Custom::OrtTensor<bool>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<bool>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
-            tensor = std::make_unique<Custom::OrtTensor<float>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<float>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
-            tensor = std::make_unique<Custom::OrtTensor<double>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<double>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
-            tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
-            tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
-            tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
-            tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
-            tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
-            tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
-            tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
-            tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api_, ctx, ith_input, true);
             break;
           case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
-            tensor = std::make_unique<Custom::OrtTensor<std::string>>(api, ctx, ith_input, true);
+            tensor = std::make_unique<Custom::OrtTensor<std::string>>(api_, ctx, ith_input, true);
             break;
           default:
             ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
@@ -395,7 +394,7 @@ struct Variadic : public Arg {
   size_t Size() const {
     return tensors_.size();
   }
- 
+
   const TensorBasePtr& operator[](size_t indice) const {
     return tensors_.at(indice);
   }
@@ -412,11 +411,11 @@ struct Variadic : public Arg {
 
 class OrtGraphKernelContext : public KernelContext {
  public:
-  OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
+  OrtGraphKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
     OrtMemoryInfo* info;
-    OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
-    OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_));
-    api.ReleaseMemoryInfo(info);
+    OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
+    OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &allocator_));
+    api_.ReleaseMemoryInfo(info);
   }
 
   virtual ~OrtGraphKernelContext() {
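A side note on the create/use/release sequence above: if OrtW::ThrowOnError throws between CreateCpuMemoryInfo and ReleaseMemoryInfo, the OrtMemoryInfo leaks. Not part of this PR, but a hedged sketch of an exception-safe variant, assuming only the OrtApi calls already used in this file:

    #include <functional>
    #include <memory>

    using MemoryInfoPtr = std::unique_ptr<OrtMemoryInfo, std::function<void(OrtMemoryInfo*)>>;

    // Returns a guard that releases the info even if a later call throws.
    // The OrtApi table is long-lived, so capturing its address is safe here.
    inline MemoryInfoPtr MakeCpuMemoryInfo(const OrtApi& ort_api) {
      OrtMemoryInfo* info = nullptr;
      OrtW::ThrowOnError(ort_api, ort_api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator,
                                                              OrtMemType::OrtMemTypeDefault, &info));
      return MemoryInfoPtr(info, [api = &ort_api](OrtMemoryInfo* p) { api->ReleaseMemoryInfo(p); });
    }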
@@ -458,31 +457,31 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext {
  public:
   static const int cuda_resource_ver = 1;
 
-  OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
-    api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
+  OrtGraphCudaKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
+    api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
     if (!cuda_stream_) {
       ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
     }
-    api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
+    api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
     if (!cublas_) {
       ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
     }
     void* resource = nullptr;
-    OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
+    OrtStatusPtr result = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
     if (result) {
       ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
     }
     memcpy(&device_id_, &resource, sizeof(int));
 
     OrtMemoryInfo* info;
-    OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
-    OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
-    api.ReleaseMemoryInfo(info);
+    OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
+    OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
+    api_.ReleaseMemoryInfo(info);
 
     OrtMemoryInfo* cuda_mem_info;
-    OrtW::ThrowOnError(api, api.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
-    OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
-    api.ReleaseMemoryInfo(cuda_mem_info);
+    OrtW::ThrowOnError(api_, api_.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
+    OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
+    api_.ReleaseMemoryInfo(cuda_mem_info);
   }
 
   virtual ~OrtGraphCudaKernelContext() {
@@ -944,7 +943,7 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
 
 class OrtAttributeReader {
  public:
-  OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) {
+  OrtAttributeReader(const OrtApi& ort_api, const OrtKernelInfo& info) : base_kernel_(ort_api, info) {
   }
 
   template <class T>
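For anyone auditing similar headers: C4459 is a level-4 warning, so it stays silent at MSVC's default warning levels. A sketch of two ways to keep it from regressing (illustrative; this PR itself only renames identifiers, and the repository's actual build flags may differ):

    // Per translation unit:
    #ifdef _MSC_VER
    #pragma warning(error : 4459)  // promote "declaration hides global declaration" to an error
    #endif

    // Or on the compiler command line:
    //   cl /W4 /we4459 file.cc   // enable level-4 warnings and make C4459 an error
    //   cl /w14459 file.cc       // or report C4459 even at lower warning levels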