Skip to content

Commit

Permalink
Add partial support of accumarray to unblock
Browse files Browse the repository at this point in the history
  • Loading branch information
ntjohnson1 committed Dec 31, 2023
1 parent d061d17 commit 79bfe4d
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion src/utils/ndarray_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ndarray::{ArrayView, Ix1};
use ndarray::{s, Array, ArrayView, Axis, Ix1, Ix2, IxDyn};
use ndarray_linalg::Norm;

/// Gets largest value (in aboslute terms) but return original value
Expand Down Expand Up @@ -41,6 +41,22 @@ pub fn p_norm(array: ArrayView<f64, Ix1>, p: i64) -> f64 {
}
}

/// Matches numpy_groupies Form 4
/// only supports sum for now
pub fn aggregate(
group_idx: ArrayView<usize, Ix2>,
weights: ArrayView<f64, Ix1>,
) -> Array<f64, IxDyn> {
let result_shape = group_idx
.map_axis(Axis(0), |view| *view.iter().max().unwrap() + 1)
.to_vec();
let mut result = Array::<f64, IxDyn>::zeros(IxDyn(&result_shape));
for (i, group) in group_idx.axis_iter(Axis(0)).enumerate() {
result[&group.to_vec()[..]] += weights[i];
}
result
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -63,4 +79,14 @@ mod tests {
assert!(sign(3.0) == 1.0);
assert!(sign(0.0) == 0.0);
}

#[test]
fn test_aggregate() {
let group: Array<usize, Ix2> = array![[0, 1], [0, 1], [2, 3]];
let weights: Array<f64, Ix1> = array![1., 1., 5.];
let expected_result: Array<f64, IxDyn> =
array![[0., 2., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 5.],].into_dyn();
let result = aggregate((&group).into(), (&weights).into());
assert!(result.abs_diff_eq(&expected_result, 1e-8));
}
}

0 comments on commit 79bfe4d

Please sign in to comment.