Skip to content

Commit

Permalink
Feature extraction C API for whipser model (#755)
Browse files Browse the repository at this point in the history
* Feature extraction C API for whipser model

* Update the docs

* Update the docs2

* refine the code

* fix some issues

* fix the Linux build

* fix more data consistency issue

* More code refinements
  • Loading branch information
wenbingl authored Jul 11, 2024
1 parent 95d65e4 commit 8153bc1
Show file tree
Hide file tree
Showing 28 changed files with 1,505 additions and 515 deletions.
13 changes: 11 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ if (MSVC)
endif()
message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}")

# DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
# Remove this definition once the conda msvcp140.dll dll is updated.
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()

if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
Expand Down Expand Up @@ -442,7 +446,9 @@ endif()
if(OCOS_ENABLE_BERT_TOKENIZER)
# Bert
set(_HAS_TOKENIZER ON)
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*"
"operators/tokenizer/bert_tokenizer.*"
"operators/tokenizer/bert_tokenizer_decoder.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()

Expand Down Expand Up @@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE)
endif()

target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:noexcep_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")

target_link_libraries(ortcustomops PUBLIC ocos_operators)

if(_BUILD_SHARED_LIBRARY)
Expand All @@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY)
standardize_output_folder(extensions_shared)

if(LINUX OR ANDROID)
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS
" -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
# strip if not a debug build
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s")
Expand Down
19 changes: 19 additions & 0 deletions docs/c_api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# ONNXRuntime Extensions C ABI

ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile the ONNXRuntime Extensions as either a static library or a dynamic library to access these APIs.

The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available:

- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLM models.
- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodels.
- [`ortx_extraction.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model.

## ABI QuickStart

Most APIs accept raw data inputs such as audio, image compressed binary formats, or UTF-8 encoded text for tokenization.

**Tokenization:** You can create a tokenizer object using `OrtxCreateTokenizer` and then use the object to tokenize a text or decode the token ID into the text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c).

**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75).

**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16).
File renamed without changes.
75 changes: 75 additions & 0 deletions include/ortx_extractor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// C ABI header file for the onnxruntime-extensions tokenization module

#pragma once

#include "ortx_utils.h"

typedef OrtxObject OrtxFeatureExtractor;
typedef OrtxObject OrtxRawAudios;
typedef OrtxObject OrtxTensorResult;

#ifdef __cplusplus
extern "C" {
#endif

/**
* @brief Creates a feature extractor object.
*
* This function creates a feature extractor object based on the provided feature definition.
*
* @param[out] extractor Pointer to a pointer to the created feature extractor object.
* @param[in] fe_def The feature definition used to create the feature extractor.
*
* @return An error code indicating the result of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def);

/**
* Loads a collection of audio files into memory.
*
* This function loads a collection of audio files specified by the `audio_paths` array
* into memory and returns a pointer to the loaded audio data in the `audios` parameter.
*
* @param audios A pointer to a pointer that will be updated with the loaded audio data.
* The caller is responsible for freeing the memory allocated for the audio data.
* @param audio_paths An array of strings representing the paths to the audio files to be loaded.
* @param num_audios The number of audio files to be loaded.
*
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios);

/**
* @brief Creates an array of raw audio objects.
*
* This function creates an array of raw audio objects based on the provided data and sizes.
*
* @param audios Pointer to the variable that will hold the created raw audio objects.
* @param data Array of pointers to the audio data.
* @param sizes Array of pointers to the sizes of the audio data.
* @param num_audios Number of audio objects to create.
*
* @return extError_t Error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios);

/**
* @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor.
*
* This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct,
* and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using
* the specified feature extractor and stores the result in the provided log_mel pointer.
*
* @param extractor The feature extractor to use for calculating the log mel spectrogram.
* @param audio The raw audio data to process.
* @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored.
* @return An extError_t value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel);

#ifdef __cplusplus
}
#endif
36 changes: 16 additions & 20 deletions include/ortx_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
typedef OrtxObject OrtxProcessor;
typedef OrtxObject OrtxRawImages;
typedef OrtxObject OrtxImageProcessorResult;

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch
extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images,
size_t* num_images_loaded);


/**
* @brief Preprocesses the given raw images using the specified processor.
* @brief Creates raw images from the provided data.
*
* This function creates raw images from the provided data. The raw images are stored in the `images` parameter.
*
* @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images.
* @param data Array of pointers to the data for each image.
* @param sizes Array of pointers to the sizes of each image.
* @param num_images Number of images to create.
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images);

/**
* @brief Pre-processes the given raw images using the specified processor.
*
* This function applies preprocessing operations on the raw images using the provided processor.
* The result of the preprocessing is stored in the `OrtxImageProcessorResult` object.
Expand All @@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
* @return An `extError_t` value indicating the success or failure of the preprocessing operation.
*/
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result);

/**
* @brief Retrieves the image processor result at the specified index.
*
* @param result Pointer to the OrtxImageProcessorResult structure to store the result.
* @param index The index of the result to retrieve.
* @return extError_t The error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor);

/** \brief Clear the outputs of the processor
*
* \param processor The processor object
* \param result The result object to clear
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result);
OrtxTensorResult** result);

#ifdef __cplusplus
}
Expand Down
19 changes: 17 additions & 2 deletions include/ortx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ typedef enum {
kOrtxKindDetokenizerCache = 0x778B,
kOrtxKindProcessor = 0x778C,
kOrtxKindRawImages = 0x778D,
kOrtxKindImageProcessorResult = 0x778E,
kOrtxKindTensorResult = 0x778E,
kOrtxKindProcessorResult = 0x778F,
kOrtxKindTensor = 0x7790,
kOrtxKindFeatureExtractor = 0x7791,
kOrtxKindRawAudios = 0x7792,
kOrtxKindEnd = 0x7999
} extObjectKind_t;

// all object managed by the library should be 'derived' from this struct
// which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C
typedef struct {
int ext_kind_;
extObjectKind_t ext_kind_;
} OrtxObject;

typedef OrtxObject OrtxTensor;
typedef OrtxObject OrtxTensorResult;

// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
Expand Down Expand Up @@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);

/**
* @brief Retrieves the tensor at the specified index from the given tensor result.
*
* This function allows you to access a specific tensor from a tensor result object.
*
* @param result The tensor result object from which to retrieve the tensor.
* @param index The index of the tensor to retrieve.
* @param tensor A pointer to a variable that will hold the retrieved tensor.
* @return An error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);

/** \brief Get the data from the tensor
*
* \param tensor The tensor object
Expand Down
78 changes: 24 additions & 54 deletions onnxruntime_extensions/_torch_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._ortapi2 import make_onnx_model
from ._cuops import SingleOpGraph
from ._hf_cvt import HFTokenizerConverter
from .util import remove_unused_initializers
from .util import remove_unused_initializers, mel_filterbank


class _WhisperHParams:
Expand All @@ -30,53 +30,15 @@ class _WhisperHParams:
N_FRAMES = N_SAMPLES // HOP_LENGTH


def _mel_filterbank(
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
"""
Compute a Mel-filterbank. The filters are stored in the rows, the columns,
and it is Slaney normalized mel-scale filterbank.
"""
fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)

# the centers of the frequency bins for the DFT
freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

mel = np.linspace(min_mel, max_mel, n_mels + 2)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel

# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region

log_t = mel >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel))
mel_bins = freqs

mel_spacing = np.diff(mel_bins)

ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1)
for i in range(n_mels):
left = -ramps[i] / mel_spacing[i]
right = ramps[i + 2] / mel_spacing[i + 1]

# intersect them with each other and zero
fbank[i] = np.maximum(0, np.minimum(left, right))

energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
fbank *= energy_norm[:, np.newaxis]
return fbank


class CustomOpStftNorm(torch.autograd.Function):
@staticmethod
def symbolic(g, self, n_fft, hop_length, window):
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_n_fft = g.op('Constant', value_t=torch.tensor(
n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(
hop_length, dtype=torch.int64))
t_frame_size = g.op(
'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)

@staticmethod
Expand All @@ -97,7 +59,7 @@ def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, n_fft=_WhisperHParams.N_FFT,
self.n_fft = n_fft
self.window = torch.hann_window(n_fft)
self.mel_filters = torch.from_numpy(
_mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))

def forward(self, audio_pcm: torch.Tensor):
stft_norm = CustomOpStftNorm.apply(audio_pcm,
Expand All @@ -112,7 +74,8 @@ def forward(self, audio_pcm: torch.Tensor):
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1],
self.n_samples // self.hop_length - spec_shape[2],
self.n_samples // self.hop_length -
spec_shape[2],
dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
Expand Down Expand Up @@ -165,15 +128,20 @@ def _to_onnx_stft(onnx_model, n_fft):
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
name='slice_1'),
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
make_node('Constant', inputs=[], outputs=[
'const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=[
'const1_output_0'], name='const1', value_int=1),
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
name='gather_4', axis=3),
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
name='gather_5', axis=3),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[
'mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[
'mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[
stft_norm_node.output[0]], name='add0'),
]
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
new_stft_nodes.extend(replaced_nodes)
Expand Down Expand Up @@ -253,9 +221,11 @@ def post_processing(self, **kwargs):
del g.node[:]
g.node.extend(nodes)

inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
inputs = [onnx.helper.make_tensor_value_info(
"sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
del g.input[:]
g.input.extend(inputs)
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(
onnx.TensorProto.STRING, ['N', 'text']))

return make_onnx_model(g, opset_version=self.opset_version)
10 changes: 4 additions & 6 deletions operators/audio/audio.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@

#include "ocos.h"
#ifdef ENABLE_DR_LIBS
#include "audio_decoder.hpp"
#include "audio_decoder.h"
#endif // ENABLE_DR_LIBS

FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef ENABLE_DR_LIBS
,
CustomCpuStructV2("AudioDecoder", AudioDecoder)
CustomCpuStructV2("AudioDecoder", AudioDecoder),
#endif
);
[]() { return nullptr; });

return op_loader.GetCustomOps();
};
Loading

0 comments on commit 8153bc1

Please sign in to comment.