Skip to content

Commit

Permalink
Merge pull request #349 from robertknight/u8-gather
Browse files Browse the repository at this point in the history
Support u8 tensors in `Gather` operator
  • Loading branch information
robertknight authored Sep 6, 2024
2 parents c960da9 + d647dab commit 702efa3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion rten-examples/src/rmbg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let std_dev = [1.0, 1.0, 1.0];
normalize_image(normalized_image.view_mut(), mean, std_dev);

let [_, orig_height, orig_width] = image.shape().try_into()?;
let [_, orig_height, orig_width] = image.shape();

let mut normalized_image = normalized_image.into_dyn();
normalized_image.insert_axis(0); // Add batch dim
Expand Down
1 change: 1 addition & 0 deletions src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ impl Operator for Gather {
match input {
Input::Int32Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::FloatTensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::UInt8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
_ => Err(OpError::UnsupportedType),
}
}
Expand Down

0 comments on commit 702efa3

Please sign in to comment.