Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add infer_request functions #111

Merged
merged 1 commit into from
May 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions crates/openvino/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use crate::tensor::Tensor;
use crate::{cstr, drop_using_function, try_unsafe, util::Result};
use openvino_sys::{
ov_infer_request_free, ov_infer_request_get_output_tensor_by_index,
ov_infer_request_get_tensor, ov_infer_request_infer,
ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_tensor,
ov_infer_request_cancel, ov_infer_request_free, ov_infer_request_get_input_tensor,
ov_infer_request_get_output_tensor, ov_infer_request_get_output_tensor_by_index,
ov_infer_request_get_tensor, ov_infer_request_infer, ov_infer_request_set_input_tensor,
ov_infer_request_set_input_tensor_by_index, ov_infer_request_set_output_tensor,
ov_infer_request_set_output_tensor_by_index, ov_infer_request_set_tensor,
ov_infer_request_start_async, ov_infer_request_t, ov_infer_request_wait_for,
};

/// See [`InferRequest`](https://docs.openvino.ai/2023.3/api/c_cpp_api/group__ov__infer__request__c__api.html).
/// See [`InferRequest`](https://docs.openvino.ai/2024/api/c_cpp_api/group__ov__infer__request__c__api.html).
pub struct InferRequest {
// Owning raw pointer to the underlying C `ov_infer_request_t`.
// NOTE(review): presumably released via `ov_infer_request_free` through the
// `drop_using_function` helper imported above — the Drop impl is outside this view; confirm.
ptr: *mut ov_infer_request_t,
}
Expand Down Expand Up @@ -43,6 +45,21 @@ impl InferRequest {
Ok(Tensor::from_ptr(tensor))
}

/// Retrieve the input [`Tensor`] of a model that has exactly one input.
pub fn get_input_tensor(&self) -> Result<Tensor> {
    // Out-parameter filled in by the C API.
    let mut raw = std::ptr::null_mut();
    try_unsafe!(ov_infer_request_get_input_tensor(
        self.ptr,
        std::ptr::addr_of_mut!(raw)
    ))?;
    Ok(Tensor::from_ptr(raw))
}

/// Assign `tensor` as the sole input of a single-input model.
pub fn set_input_tensor(&mut self, tensor: &Tensor) -> Result<()> {
    let raw = tensor.as_ptr();
    try_unsafe!(ov_infer_request_set_input_tensor(self.ptr, raw))
}

/// Assign an input [`Tensor`] to the model by its index.
pub fn set_input_tensor_by_index(&mut self, index: usize, tensor: &Tensor) -> Result<()> {
try_unsafe!(ov_infer_request_set_input_tensor_by_index(
Expand All @@ -64,11 +81,43 @@ impl InferRequest {
Ok(Tensor::from_ptr(tensor))
}

/// Retrieve the output [`Tensor`] of a model that has exactly one output.
pub fn get_output_tensor(&self) -> Result<Tensor> {
    // Out-parameter filled in by the C API.
    let mut raw = std::ptr::null_mut();
    try_unsafe!(ov_infer_request_get_output_tensor(
        self.ptr,
        std::ptr::addr_of_mut!(raw)
    ))?;
    Ok(Tensor::from_ptr(raw))
}

/// Assign `tensor` as the sole output of a single-output model.
pub fn set_output_tensor(&mut self, tensor: &Tensor) -> Result<()> {
    let raw = tensor.as_ptr();
    try_unsafe!(ov_infer_request_set_output_tensor(self.ptr, raw))
}

/// Assign `tensor` as the model output at position `index`.
pub fn set_output_tensor_by_index(&mut self, index: usize, tensor: &Tensor) -> Result<()> {
    let raw = tensor.as_ptr();
    try_unsafe!(ov_infer_request_set_output_tensor_by_index(
        self.ptr, index, raw
    ))
}

/// Execute the inference request.
///
/// This is the synchronous entry point; see [`Self::infer_async`] (backed by
/// `ov_infer_request_start_async`) for the non-blocking variant.
pub fn infer(&mut self) -> Result<()> {
try_unsafe!(ov_infer_request_infer(self.ptr))
}

/// Cancel the inference request.
///
/// NOTE(review): presumably aborts an in-flight asynchronous inference started
/// with [`Self::infer_async`] — confirm against the OpenVINO C API docs for
/// `ov_infer_request_cancel`.
pub fn cancel(&mut self) -> Result<()> {
try_unsafe!(ov_infer_request_cancel(self.ptr))
}

/// Execute the inference request asynchronously.
pub fn infer_async(&mut self) -> Result<()> {
try_unsafe!(ov_infer_request_start_async(self.ptr))
Expand Down
Loading