Skip to content

Commit

Permalink
Fix serde for TopNComputer (#2313)
Browse files Browse the repository at this point in the history
* Fix serde for TopNComputer

The top hits aggregation made the TopNComputer serializable, but the
buffer capacity must be restored on deserialization, because pushing
elements checks against it (capacity == 0 is not allowed).

* use serde from deser

* remove pub, clippy
  • Loading branch information
PSeitz authored Feb 7, 2024
1 parent 88a3275 commit 0e16ed9
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/aggregation/metric/top_hits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl RetrievalFields {
return Ok(vec![field.to_owned()]);
}

let pattern = globbed_string_to_regex(&field)?;
let pattern = globbed_string_to_regex(field)?;
let fields = reader
.iter_columns()?
.map(|(name, _)| {
Expand Down
60 changes: 53 additions & 7 deletions src/collector/top_score_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::marker::PhantomData;
use std::sync::Arc;

use columnar::ColumnValues;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use super::Collector;
Expand Down Expand Up @@ -720,17 +721,43 @@ impl SegmentCollector for TopScoreSegmentCollector {
///
/// For TopN == 0, it will be relatively expensive.
#[derive(Clone, Serialize, Deserialize)]
pub struct TopNComputer<Score, DocId, const REVERSE_ORDER: bool = true> {
#[serde(from = "TopNComputerDeser<Score, D, REVERSE_ORDER>")]
pub struct TopNComputer<Score, D, const REVERSE_ORDER: bool = true> {
/// The buffer reverses sort order to get top-semantics instead of bottom-semantics
buffer: Vec<ComparableDoc<Score, DocId, REVERSE_ORDER>>,
buffer: Vec<ComparableDoc<Score, D, REVERSE_ORDER>>,
top_n: usize,
pub(crate) threshold: Option<Score>,
}
// Intermediate struct for TopNComputer for deserialization, to fix vec capacity.
//
// `TopNComputer` is deserialized through this type (via its
// `#[serde(from = "TopNComputerDeser<...>")]` attribute). The fields mirror
// those of `TopNComputer` one-to-one; the `From` conversion then adjusts the
// capacity of `buffer` based on `top_n` before building the real computer.
#[derive(Deserialize)]
struct TopNComputerDeser<Score, D, const REVERSE_ORDER: bool> {
// Deserialized documents; the capacity of this vec is whatever the
// deserializer happened to allocate.
buffer: Vec<ComparableDoc<Score, D, REVERSE_ORDER>>,
// Number of top elements to keep; used to derive the expected buffer capacity.
top_n: usize,
// Threshold carried over unchanged into `TopNComputer::threshold`.
threshold: Option<Score>,
}

impl<Score, D, const R: bool> From<TopNComputerDeser<Score, D, R>> for TopNComputer<Score, D, R> {
    /// Builds a `TopNComputer` from its deserialized form, re-establishing the
    /// buffer capacity of `2 * max(top_n, 1)`.
    ///
    /// The capacity of a `Vec` is not part of its serialized form, so after
    /// deserialization it is whatever the deserializer allocated. Pushing
    /// elements checks against the capacity (a capacity of 0 is not allowed),
    /// so it is adjusted here.
    fn from(mut value: TopNComputerDeser<Score, D, R>) -> Self {
        // `max(1)` guarantees a non-zero capacity even for `top_n == 0`.
        let expected_cap = value.top_n.max(1) * 2;
        if value.buffer.capacity() < expected_cap {
            // `reserve_exact(additional)` guarantees `capacity >= len + additional`,
            // i.e. it counts from the length, not from the current capacity.
            // `len <= capacity < expected_cap` here, so the subtraction cannot underflow.
            let additional = expected_cap - value.buffer.len();
            value.buffer.reserve_exact(additional);
        } else {
            // Drop excess capacity; this never shrinks below the current length.
            value.buffer.shrink_to(expected_cap);
        }

        TopNComputer {
            buffer: value.buffer,
            top_n: value.top_n,
            threshold: value.threshold,
        }
    }
}

impl<Score, DocId, const R: bool> TopNComputer<Score, DocId, R>
impl<Score, D, const R: bool> TopNComputer<Score, D, R>
where
Score: PartialOrd + Clone,
DocId: Ord + Clone,
D: Serialize + DeserializeOwned + Ord + Clone,
{
/// Create a new `TopNComputer`.
/// Internally it will allocate a buffer of size `2 * top_n`.
Expand All @@ -746,7 +773,7 @@ where
/// Push a new document to the top n.
/// If the document is below the current threshold, it will be ignored.
#[inline]
pub fn push(&mut self, feature: Score, doc: DocId) {
pub fn push(&mut self, feature: Score, doc: D) {
if let Some(last_median) = self.threshold.clone() {
if feature < last_median {
return;
Expand Down Expand Up @@ -783,7 +810,7 @@ where
}

/// Returns the top n elements in sorted order.
pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, DocId, R>> {
pub fn into_sorted_vec(mut self) -> Vec<ComparableDoc<Score, D, R>> {
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
Expand All @@ -794,7 +821,7 @@ where
/// Returns the top n elements in stored order.
/// Useful if you do not need the elements in sorted order,
/// for example when merging the results of multiple segments.
pub fn into_vec(mut self) -> Vec<ComparableDoc<Score, DocId, R>> {
pub fn into_vec(mut self) -> Vec<ComparableDoc<Score, D, R>> {
if self.buffer.len() > self.top_n {
self.truncate_top_n();
}
Expand Down Expand Up @@ -833,6 +860,25 @@ mod tests {
crate::assert_nearly_equals!(result.0, expected.0);
}
}
#[test]
fn test_topn_computer_serde() {
    // Round-trip a freshly created computer through JSON.
    let original: TopNComputer<u32, u32> = TopNComputer::new(1);
    let json = serde_json::to_string(&original).unwrap();
    let mut restored: TopNComputer<u32, u32> = serde_json::from_str(&json).unwrap();

    // The deserialized computer must accept pushes; all docs share the same feature.
    for doc in [5u32, 0u32, 7u32] {
        restored.push(1u32, doc);
    }

    // With top_n == 1 only a single document remains.
    assert_eq!(
        restored.into_sorted_vec(),
        &[ComparableDoc {
            feature: 1u32,
            doc: 0u32,
        }]
    );
}

#[test]
fn test_empty_topn_computer() {
Expand Down
5 changes: 4 additions & 1 deletion src/index/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,10 @@ impl Index {
}

/// Sets a custom executor backed by an outer, shared thread pool.
pub fn set_shared_multithread_executor(&mut self, shared_thread_pool: Arc<Executor>) -> crate::Result<()> {
pub fn set_shared_multithread_executor(
&mut self,
shared_thread_pool: Arc<Executor>,
) -> crate::Result<()> {
self.executor = shared_thread_pool.clone();
Ok(())
}
Expand Down

0 comments on commit 0e16ed9

Please sign in to comment.