Skip to content

Commit

Permalink
Merge branch 'main' into unimg
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl authored Oct 22, 2024
2 parents 3e4c03d + aa2c82f commit 0fbe523
Show file tree
Hide file tree
Showing 17 changed files with 893 additions and 104 deletions.
1 change: 0 additions & 1 deletion include/custom_op/tensor_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,6 @@ class Tensor<std::string_view> : public TensorBase {
std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
};


template<typename ...Args>
class NamedArgumentDict{
public:
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime_extensions/pp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def pre_process(self, images):
return image_pre_process(self.processor, images)

@staticmethod
def to_numpy(result):
return tensor_result_get_at(result, 0)
def to_numpy(result, idx):
return tensor_result_get_at(result, idx)

def __del__(self):
if delete_object and self.processor:
Expand Down
2 changes: 1 addition & 1 deletion shared/api/c_api_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,5 @@ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawIm
*result = nullptr;
}

return {};
return status.Code();
}
23 changes: 0 additions & 23 deletions shared/api/c_api_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,3 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data
return extError_t();
}

extError_t ORTX_API_CALL OrtxGetTensorDataInt64(OrtxTensor* tensor, const int64_t** data, const int64_t** shape,
size_t* num_dims) {
const void* data_ptr{};
auto err = OrtxGetTensorData(tensor, &data_ptr, shape, num_dims);
*data = reinterpret_cast<const int64_t*>(data_ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
return err;
}

extError_t ORTX_API_CALL OrtxGetTensorDataFloat(OrtxTensor* tensor, const float** data, const int64_t** shape,
size_t* num_dims) {
const void* data_ptr{};
auto err = OrtxGetTensorData(tensor, &data_ptr, shape, num_dims);
*data = reinterpret_cast<const float*>(data_ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
return err;
}

extError_t ORTX_API_CALL OrtxGetTensorDataUint8(OrtxTensor* tensor, const uint8_t** data, const int64_t** shape,
size_t* num_dims) {
const void* data_ptr{};
auto err = OrtxGetTensorData(tensor, &data_ptr, shape, num_dims);
*data = reinterpret_cast<const uint8_t*>(data_ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
return err;
}
7 changes: 7 additions & 0 deletions shared/api/c_api_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once
#include <vector>
#include <fstream>
#include <variant>

#include "ortx_utils.h"
#include "file_sys.h"
Expand Down Expand Up @@ -258,4 +259,10 @@ std::tuple<std::unique_ptr<T[]>, size_t> LoadRawData(It begin, It end) {

return std::make_tuple(std::move(raw_data), n);
}

using AttrType =
std::variant<std::string, double, int64_t, std::vector<std::string>, std::vector<double>, std::vector<int64_t>>;
using AttrDict = std::unordered_map<std::string, AttrType>;
} // namespace ort_extensions

namespace ortx = ort_extensions;
20 changes: 15 additions & 5 deletions shared/api/image_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,36 @@
#include "vision/decode_image.hpp"
#include "image_processor.h"
#include "c_api_utils.hpp"

#include "image_transforms.hpp"
#include "image_transforms_phi_3.hpp"
#include "image_transforms_mllama.hpp"

using namespace ort_extensions;
using json = nlohmann::json;

namespace ort_extensions {
std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
ort_extensions::LoadRawImages(const std::initializer_list<const char*>& image_paths) {
LoadRawImages(const std::initializer_list<const char*>& image_paths) {
return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths.begin(), image_paths.end());
}

std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
LoadRawImages(const char* image_paths[], size_t num_images) {
return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths, image_paths + num_images);
}
} // namespace ort_extensions

using namespace ort_extensions;
using json = nlohmann::json;

Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
{"DecodeImage", []() { return CreateKernelInstance(&ort_extensions::DecodeImage::Compute); }},
{"Resize", []() { return CreateKernelInstance(&Resize::Compute); }},
{"Rescale", []() { return CreateKernelInstance(&Rescale::Compute); }},
{"Normalize", []() { return CreateKernelInstance(&Normalize::Compute); }},
{"CenterCrop", []() { return CreateKernelInstance(&CenterCrop::Compute); }},
{"ConvertRGB", []() { return CreateKernelInstance(convert_to_rgb); }},
{"Permute3D", []() { return CreateKernelInstance(&Permute3D::Compute); }},
{"Phi3ImageTransform", []() { return CreateKernelInstance(phi3_hd_transform); }},
{"Llama3ImageTransform", []() { return CreateKernelInstance(&Llama3ImageTransform::Compute); }},
};

OrtxStatus ImageProcessor::Init(std::string_view processor_def) {
Expand Down Expand Up @@ -179,7 +190,6 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_d
operations_.back()->ResetTensors(allocator_);
if (status.IsOk()) {
r.SetTensors(std::move(img_result));
// r.SetTensorTypes({kOrtxFloat, kOrtxInt64, kOrtxInt64});
}

return status;
Expand Down
5 changes: 4 additions & 1 deletion shared/api/image_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ namespace ort_extensions {
using ImageRawData = std::vector<uint8_t>;

std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
const std::initializer_list<const char*>& image_paths);
const std::initializer_list<const char*>& image_paths);

std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
const char* image_paths[], size_t num_images);

class ProcessorResult : public OrtxObjectImpl {
public:
Expand Down
154 changes: 113 additions & 41 deletions shared/api/image_transforms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,31 @@

#pragma once

#include "ocos.h"
#include "ext_status.h"
#include "op_def_struct.h"
#include "image_resample.h"

template <typename T>
void DumpTensorToFile(const ortc::Tensor<T>& tensor, const char* name) {
#if WIN32
auto tic = GetTickCount();
std::string dtype;
if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, std::byte>) {
dtype = "_u_";
} else {
dtype = "_f_";
}
dtype += std::to_string(tensor.Shape()[1]);
// use tic to be filename in a temp file name
auto filename = std::string("\\temp\\") + name + std::to_string(tic) + dtype + ".bin";
std::ofstream file(filename, std::ios::out | std::ios::binary);
if (file.is_open()) {
file.write(reinterpret_cast<const char*>(tensor.DataRaw()), tensor.SizeInBytes());
file.close();
}
#endif
}

inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 3ULL || dimensions[2] != 3) {
Expand All @@ -31,23 +53,13 @@ inline OrtxStatus convert_to_rgb(const ortc::Tensor<uint8_t>& input, ortc::Tenso
}

struct Resize {
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
for (const auto& [key, value] : attrs) {
if (key == "height") {
height_ = std::get<int64_t>(value);
} else if (key == "width") {
width_ = std::get<int64_t>(value);
} else if (key == "interpolation") {
interpolation_ = std::get<std::string>(value);
if (interpolation_ != "NEAREST" && interpolation_ != "LINEAR" && interpolation_ != "CUBIC") {
return {kOrtxErrorInvalidArgument, "[Resize]: Invalid interpolation method"};
}
} else {
return {kOrtxErrorInvalidArgument, "[Resize]: Invalid argument"};
}
}
return {};
static const std::unordered_map<std::string, int> InterpolationMethods() {
return {
{"NEAREST", IMAGING_TRANSFORM_NEAREST},
{"LINEAR", IMAGING_TRANSFORM_BILINEAR},
{"CUBIC", IMAGING_TRANSFORM_BICUBIC},
{"LANCZOS", IMAGING_TRANSFORM_LANCZOS}
};
}

OrtxStatus Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
Expand All @@ -72,48 +84,65 @@ struct Resize {
}
}

int interp = IMAGING_TRANSFORM_NEAREST;
if (interpolation_ == "NEAREST") {
interp = IMAGING_TRANSFORM_NEAREST;
} else if (interpolation_ == "LINEAR") {
interp = IMAGING_TRANSFORM_BILINEAR;
} else if (interpolation_ == "CUBIC") {
interp = IMAGING_TRANSFORM_BICUBIC;
} else if (interpolation_ == "LANCZOS") {
interp = IMAGING_TRANSFORM_LANCZOS;
} else {
return {kOrtxErrorInvalidArgument, "[Resize]: Invalid interpolation method"};
}
int interp = InterpolationMethods().at(interpolation_);
float box[4] = {0.0f, 0.0f, static_cast<float>(w), static_cast<float>(h)};
auto [height, width] = std::make_tuple(height_, width_);

float box[4] = {0.0f, 0.0f, static_cast<float>(width_), static_cast<float>(height_)};
if (keep_aspect_ratio_) {
double scale = (std::max)(static_cast<double>(width) / w, static_cast<double>(height) / h);
width = static_cast<int64_t>(w * scale);
height = static_cast<int64_t>(h * scale);
}

auto output_image = ImagingResample(rgb_image, static_cast<int>(width_), static_cast<int>(height_), interp, box);
// cv::resize(image, output_image, {static_cast<int32_t>(width_), static_cast<int32_t>(height_)}, 0.0, 0.0, interp);
auto output_image = ImagingResample(rgb_image, static_cast<int>(width), static_cast<int>(height), interp, box);
ImagingDelete(rgb_image);

auto* p_output_image = output.Allocate({height_, width_, c});
for (auto i = height_ - height_; i < height_; ++i) {
for (auto j = width_ - width_; j < width_; ++j) {
auto c0_index = i * width_ * c + j * c;
std::memcpy(p_output_image + c0_index, output_image->image[i] + j * 4, c);
auto* p_output_image = output.Allocate({height, width, c});
for (auto i = height - height; i < height; ++i) {
for (auto j = width - width; j < width; ++j) {
auto c0_index = i * width * c + j * c;
std::memcpy(p_output_image + c0_index, output_image->image[i] + j * 4, c);
}
}
// DumpTensor(output);

ImagingDelete(output_image);
return {};
}

template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
for (const auto& [key, value] : attrs) {
if (key == "height") {
height_ = std::get<int64_t>(value);
} else if (key == "width") {
width_ = std::get<int64_t>(value);
} else if (key == "keep_aspect_ratio") {
keep_aspect_ratio_ = std::get<int64_t>(value) != 0;
} else if (key == "interpolation") {
interpolation_ = std::get<std::string>(value);
if (InterpolationMethods().find(interpolation_) == InterpolationMethods().end()) {
return {kOrtxErrorInvalidArgument, "[Resize]: Invalid interpolation method"};
}
} else {
return {kOrtxErrorInvalidArgument, "[Resize]: Invalid argument"};
}
}
return {};
}

private:
int64_t height_{256};
int64_t width_{256};
bool keep_aspect_ratio_{true};
std::string interpolation_{"CUBIC"}; // LINEAR, NEAREST, CUBIC
};

struct Rescale {
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
for (const auto& [key, value] : attrs) {
if (key == "scale") {
if (key == "rescale_factor") {
scale_ = static_cast<float>(std::get<double>(value));
} else {
return {kOrtxErrorInvalidArgument, "[Rescale]: Invalid argument"};
Expand All @@ -139,7 +168,7 @@ struct Rescale {
for (int64_t k = 0; k < w; ++k) {
auto c0_index = j * w * c + k * c;
for (int64_t l = 0; l < c; ++l) {
p_output_image[c0_index + l] = input_data[c0_index + l] * scale_;
p_output_image[c0_index + l] = static_cast<float>(input_data[c0_index + l]) * scale_;
}
}
}
Expand Down Expand Up @@ -220,7 +249,6 @@ struct CenterCrop {
// s_h = torch.div((img_h - height), 2, rounding_mode='trunc')
// s_w = torch.div((img_w - width), 2, rounding_mode='trunc')
// x = img[:, :, s_h:s_h + height, s_w:s_w + width]

OrtxStatus Compute(const ortc::Tensor<uint8_t>& input, ortc::Tensor<uint8_t>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 3ULL) {
Expand Down Expand Up @@ -252,3 +280,47 @@ struct CenterCrop {
int64_t target_h_{224};
int64_t target_w_{224};
};

struct Permute3D {

OrtxStatus Compute(const ortc::Tensor<float>& input, ortc::Tensor<float>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 3ULL || dims_.size() != 3ULL) {
return {kOrtxErrorInvalidArgument, "[Permute]: Only 3D tensors are supported"};
}

auto* input_data = input.Data();
std::vector<int64_t> output_shape = {dimensions[dims_[0]], dimensions[dims_[1]], dimensions[dims_[2]]};
auto* p_output_image = output.Allocate(output_shape);

for (int64_t i = 0; i < dimensions[0]; ++i) {
for (int64_t j = 0; j < dimensions[1]; ++j) {
for (int64_t k = 0; k < dimensions[2]; ++k) {
auto c0_index = i * dimensions[1] * dimensions[2] + j * dimensions[2] + k;
auto c1_index = (dims_[0] == 0 ? i : (dims_[0] == 1 ? j : k)) * output_shape[1] * output_shape[2] +
(dims_[1] == 0 ? i : (dims_[1] == 1 ? j : k)) * output_shape[2] +
(dims_[2] == 0 ? i : (dims_[2] == 1 ? j : k));
p_output_image[c1_index] = input_data[c0_index];
}
}
}

return {};
}

template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
for (const auto& [key, value] : attrs) {
if (key == "dims") {
dims_ = std::get<std::vector<int64_t>>(value);
} else {
return {kOrtxErrorInvalidArgument, "[Permute]: Invalid argument"};
}
}

return {};
}

private:
std::vector<int64_t> dims_{1, 2, 0};
};
Loading

0 comments on commit 0fbe523

Please sign in to comment.