Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature extraction C API for whipser model #755

Merged
merged 10 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ if (MSVC)
endif()
message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}")

# DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
# Remove this definition once the conda msvcp140.dll dll is updated.
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()

if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
Expand Down Expand Up @@ -442,7 +446,9 @@ endif()
if(OCOS_ENABLE_BERT_TOKENIZER)
# Bert
set(_HAS_TOKENIZER ON)
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*"
"operators/tokenizer/bert_tokenizer.*"
"operators/tokenizer/bert_tokenizer_decoder.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()

Expand Down Expand Up @@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE)
endif()

target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:noexcep_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")

target_link_libraries(ortcustomops PUBLIC ocos_operators)

if(_BUILD_SHARED_LIBRARY)
Expand All @@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY)
standardize_output_folder(extensions_shared)

if(LINUX OR ANDROID)
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS
" -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
# strip if not a debug build
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s")
Expand Down
19 changes: 19 additions & 0 deletions docs/c_api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ONNXRuntime Extensions C ABI

ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile the ONNXRuntime Extensions as either a static library or a dynamic library to access these APIs.

The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available:

- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLM models.
- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodels.
- [`ortx_extraction.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model.

## ABI QuickStart

Most APIs accept raw data inputs such as audio, image compressed binary formats, or UTF-8 encoded text for tokenization.

**Tokenization:** You can create a tokenizer object using `OrtxCreateTokenizer` and then use the object to tokenize a text or decode the token ID into the text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c).

**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75).

**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16).
File renamed without changes.
75 changes: 75 additions & 0 deletions include/ortx_extractor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// C ABI header file for the onnxruntime-extensions tokenization module

#pragma once

#include "ortx_utils.h"

typedef OrtxObject OrtxFeatureExtractor;
typedef OrtxObject OrtxRawAudios;
typedef OrtxObject OrtxTensorResult;

#ifdef __cplusplus
extern "C" {
#endif

/**
* @brief Creates a feature extractor object.
*
* This function creates a feature extractor object based on the provided feature definition.
*
* @param[out] extractor Pointer to a pointer to the created feature extractor object.
* @param[in] fe_def The feature definition used to create the feature extractor.
*
* @return An error code indicating the result of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def);

/**
* Loads a collection of audio files into memory.
*
* This function loads a collection of audio files specified by the `audio_paths` array
* into memory and returns a pointer to the loaded audio data in the `audios` parameter.
*
* @param audios A pointer to a pointer that will be updated with the loaded audio data.
* The caller is responsible for freeing the memory allocated for the audio data.
* @param audio_paths An array of strings representing the paths to the audio files to be loaded.
* @param num_audios The number of audio files to be loaded.
*
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios);

/**
* @brief Creates an array of raw audio objects.
*
* This function creates an array of raw audio objects based on the provided data and sizes.
*
* @param audios Pointer to the variable that will hold the created raw audio objects.
* @param data Array of pointers to the audio data.
* @param sizes Array of pointers to the sizes of the audio data.
* @param num_audios Number of audio objects to create.
*
* @return extError_t Error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios);

/**
* @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor.
*
* This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct,
* and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using
* the specified feature extractor and stores the result in the provided log_mel pointer.
*
* @param extractor The feature extractor to use for calculating the log mel spectrogram.
* @param audio The raw audio data to process.
* @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored.
* @return An extError_t value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel);

#ifdef __cplusplus
}
#endif
36 changes: 16 additions & 20 deletions include/ortx_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
typedef OrtxObject OrtxProcessor;
typedef OrtxObject OrtxRawImages;
typedef OrtxObject OrtxImageProcessorResult;

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch
extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images,
size_t* num_images_loaded);


/**
* @brief Preprocesses the given raw images using the specified processor.
* @brief Creates raw images from the provided data.
*
* This function creates raw images from the provided data. The raw images are stored in the `images` parameter.
*
* @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images.
* @param data Array of pointers to the data for each image.
* @param sizes Array of pointers to the sizes of each image.
* @param num_images Number of images to create.
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images);

/**
* @brief Pre-processes the given raw images using the specified processor.
*
* This function applies preprocessing operations on the raw images using the provided processor.
* The result of the preprocessing is stored in the `OrtxImageProcessorResult` object.
Expand All @@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
* @return An `extError_t` value indicating the success or failure of the preprocessing operation.
*/
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result);

/**
* @brief Retrieves the image processor result at the specified index.
*
* @param result Pointer to the OrtxImageProcessorResult structure to store the result.
* @param index The index of the result to retrieve.
* @return extError_t The error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor);

/** \brief Clear the outputs of the processor
*
* \param processor The processor object
* \param result The result object to clear
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result);
OrtxTensorResult** result);

#ifdef __cplusplus
}
Expand Down
19 changes: 17 additions & 2 deletions include/ortx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ typedef enum {
kOrtxKindDetokenizerCache = 0x778B,
kOrtxKindProcessor = 0x778C,
kOrtxKindRawImages = 0x778D,
kOrtxKindImageProcessorResult = 0x778E,
kOrtxKindTensorResult = 0x778E,
kOrtxKindProcessorResult = 0x778F,
kOrtxKindTensor = 0x7790,
kOrtxKindFeatureExtractor = 0x7791,
kOrtxKindRawAudios = 0x7792,
kOrtxKindEnd = 0x7999
} extObjectKind_t;

// all object managed by the library should be 'derived' from this struct
// which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C
typedef struct {
int ext_kind_;
extObjectKind_t ext_kind_;
} OrtxObject;

typedef OrtxObject OrtxTensor;
typedef OrtxObject OrtxTensorResult;

// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
Expand Down Expand Up @@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);

/**
* @brief Retrieves the tensor at the specified index from the given tensor result.
*
* This function allows you to access a specific tensor from a tensor result object.
*
* @param result The tensor result object from which to retrieve the tensor.
* @param index The index of the tensor to retrieve.
* @param tensor A pointer to a variable that will hold the retrieved tensor.
* @return An error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);

/** \brief Get the data from the tensor
*
* \param tensor The tensor object
Expand Down
78 changes: 24 additions & 54 deletions onnxruntime_extensions/_torch_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._ortapi2 import make_onnx_model
from ._cuops import SingleOpGraph
from ._hf_cvt import HFTokenizerConverter
from .util import remove_unused_initializers
from .util import remove_unused_initializers, mel_filterbank


class _WhisperHParams:
Expand All @@ -30,53 +30,15 @@ class _WhisperHParams:
N_FRAMES = N_SAMPLES // HOP_LENGTH


def _mel_filterbank(
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
"""
Compute a Mel-filterbank. The filters are stored in the rows, the columns,
and it is Slaney normalized mel-scale filterbank.
"""
fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)

# the centers of the frequency bins for the DFT
freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

mel = np.linspace(min_mel, max_mel, n_mels + 2)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel

# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region

log_t = mel >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel))
mel_bins = freqs

mel_spacing = np.diff(mel_bins)

ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1)
for i in range(n_mels):
left = -ramps[i] / mel_spacing[i]
right = ramps[i + 2] / mel_spacing[i + 1]

# intersect them with each other and zero
fbank[i] = np.maximum(0, np.minimum(left, right))

energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
fbank *= energy_norm[:, np.newaxis]
return fbank


class CustomOpStftNorm(torch.autograd.Function):
@staticmethod
def symbolic(g, self, n_fft, hop_length, window):
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_n_fft = g.op('Constant', value_t=torch.tensor(
n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(
hop_length, dtype=torch.int64))
t_frame_size = g.op(
'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)

@staticmethod
Expand All @@ -97,7 +59,7 @@ def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, n_fft=_WhisperHParams.N_FFT,
self.n_fft = n_fft
self.window = torch.hann_window(n_fft)
self.mel_filters = torch.from_numpy(
_mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))

def forward(self, audio_pcm: torch.Tensor):
stft_norm = CustomOpStftNorm.apply(audio_pcm,
Expand All @@ -112,7 +74,8 @@ def forward(self, audio_pcm: torch.Tensor):
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1],
self.n_samples // self.hop_length - spec_shape[2],
self.n_samples // self.hop_length -
spec_shape[2],
dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
Expand Down Expand Up @@ -165,15 +128,20 @@ def _to_onnx_stft(onnx_model, n_fft):
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
name='slice_1'),
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
make_node('Constant', inputs=[], outputs=[
'const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=[
'const1_output_0'], name='const1', value_int=1),
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
name='gather_4', axis=3),
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
name='gather_5', axis=3),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[
'mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[
'mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[
stft_norm_node.output[0]], name='add0'),
]
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
new_stft_nodes.extend(replaced_nodes)
Expand Down Expand Up @@ -253,9 +221,11 @@ def post_processing(self, **kwargs):
del g.node[:]
g.node.extend(nodes)

inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
inputs = [onnx.helper.make_tensor_value_info(
"sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
del g.input[:]
g.input.extend(inputs)
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(
onnx.TensorProto.STRING, ['N', 'text']))

return make_onnx_model(g, opset_version=self.opset_version)
10 changes: 4 additions & 6 deletions operators/audio/audio.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@

#include "ocos.h"
#ifdef ENABLE_DR_LIBS
#include "audio_decoder.hpp"
#include "audio_decoder.h"
#endif // ENABLE_DR_LIBS

FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef ENABLE_DR_LIBS
,
CustomCpuStructV2("AudioDecoder", AudioDecoder)
CustomCpuStructV2("AudioDecoder", AudioDecoder),
#endif
);
[]() { return nullptr; });

return op_loader.GetCustomOps();
};
Loading
Loading