diff --git a/query-grammar/src/occur.rs b/query-grammar/src/occur.rs index 6621b67bfa..cc39ca458e 100644 --- a/query-grammar/src/occur.rs +++ b/query-grammar/src/occur.rs @@ -6,12 +6,12 @@ use std::fmt::Write; #[derive(Debug, Clone, Hash, Copy, Eq, PartialEq)] pub enum Occur { /// For a given document to be considered for scoring, - /// at least one of the terms with the Should or the Must + /// at least one of the queries with the Should or the Must /// Occur constraint must be within the document. Should, - /// Document without the term are excluded from the search. + /// Document without the queries are excluded from the search. Must, - /// Document that contain the term are excluded from the + /// Document that contain the query are excluded from the /// search. MustNot, } diff --git a/src/query/boolean_query/boolean_query.rs b/src/query/boolean_query/boolean_query.rs index ac225be20f..5fa925bb14 100644 --- a/src/query/boolean_query/boolean_query.rs +++ b/src/query/boolean_query/boolean_query.rs @@ -1,5 +1,5 @@ use super::boolean_weight::BooleanWeight; -use crate::query::{EnableScoring, Occur, Query, SumWithCoordsCombiner, TermQuery, Weight}; +use crate::query::{EnableScoring, Occur, Query, SumCombiner, TermQuery, Weight}; use crate::schema::{IndexRecordOption, Term}; /// The boolean query returns a set of documents @@ -169,7 +169,7 @@ impl Query for BooleanQuery { sub_weights, self.minimum_number_should_match, enable_scoring.is_scoring_enabled(), - Box::new(SumWithCoordsCombiner::default), + Box::new(SumCombiner::default), ))) } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 3719a47385..b384c275bf 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -12,11 +12,10 @@ mod tests { use super::*; use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::collector::TopDocs; - use crate::query::score_combiner::SumWithCoordsCombiner; use crate::query::term_query::TermScorer; use crate::query::{ EnableScoring, Intersection, Occur, Query, QueryParser, RequiredOptionalScorer, Scorer, - TermQuery, + SumCombiner, TermQuery, }; use crate::schema::*; use crate::{assert_nearly_equals, DocAddress, DocId, Index, IndexWriter, Score}; @@ -90,11 +89,8 @@ mod tests { let query = query_parser.parse_query("+a b")?; let weight = query.weight(EnableScoring::enabled_from_searcher(&searcher))?; let scorer = weight.scorer(searcher.segment_reader(0u32), 1.0)?; - assert!(scorer.is::, - Box, - SumWithCoordsCombiner, - >>()); + assert!(scorer + .is::, Box, SumCombiner>>()); } { let query = query_parser.parse_query("+a b")?; diff --git a/src/query/mod.rs b/src/query/mod.rs index 1736a2fe4f..5e99354ffe 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -57,9 +57,7 @@ pub use self::query_parser::{QueryParser, QueryParserError}; pub use self::range_query::*; pub use self::regex_query::RegexQuery; pub use self::reqopt_scorer::RequiredOptionalScorer; -pub use self::score_combiner::{ - DisjunctionMaxCombiner, ScoreCombiner, SumCombiner, SumWithCoordsCombiner, -}; +pub use self::score_combiner::{DisjunctionMaxCombiner, ScoreCombiner, SumCombiner}; pub use self::scorer::Scorer; pub use self::set_query::TermSetQuery; pub use self::term_query::TermQuery; diff --git a/src/query/query_parser/logical_ast.rs b/src/query/query_parser/logical_ast.rs index a9400881b5..b0929f26a9 100644 --- a/src/query/query_parser/logical_ast.rs +++ b/src/query/query_parser/logical_ast.rs @@ -37,6 +37,34 @@ impl LogicalAst { LogicalAst::Boost(Box::new(self), boost) } } + + pub fn simplify(self) -> LogicalAst { + match self { + LogicalAst::Clause(clauses) => { + let mut new_clauses: Vec<(Occur, LogicalAst)> = Vec::new(); + + for (occur, sub_ast) in clauses { + let simplified_sub_ast = sub_ast.simplify(); + + // If clauses below have the same `Occur`, we can pull them up + match simplified_sub_ast { + LogicalAst::Clause(sub_clauses) + if (occur == Occur::Should || occur == Occur::Must) + && sub_clauses.iter().all(|(o, _)| *o == occur) => + { + for sub_clause in sub_clauses { + new_clauses.push(sub_clause); + } + } + _ => new_clauses.push((occur, simplified_sub_ast)), + } + } + + LogicalAst::Clause(new_clauses) + } + LogicalAst::Leaf(_) | LogicalAst::Boost(_, _) => self, + } + } } fn occur_letter(occur: Occur) -> &'static str { diff --git a/src/query/query_parser/query_parser.rs b/src/query/query_parser/query_parser.rs index ac6d297f47..45a68f70e9 100644 --- a/src/query/query_parser/query_parser.rs +++ b/src/query/query_parser/query_parser.rs @@ -377,7 +377,7 @@ impl QueryParser { if !err.is_empty() { return Err(err.swap_remove(0)); } - Ok(ast) + Ok(ast.simplify()) } /// Parse the user query into an AST. @@ -1437,7 +1437,7 @@ mod test { ); test_parse_query_to_logical_ast_helper( "(+title:a +title:b) title:c", - r#"(+(+Term(field=0, type=Str, "a") +Term(field=0, type=Str, "b")) +Term(field=0, type=Str, "c"))"#, + r#"(+Term(field=0, type=Str, "a") +Term(field=0, type=Str, "b") +Term(field=0, type=Str, "c"))"#, true, ); } @@ -1473,7 +1473,7 @@ mod test { pub fn test_parse_query_to_ast_two_terms() { test_parse_query_to_logical_ast_helper( "title:a b", - r#"(Term(field=0, type=Str, "a") (Term(field=0, type=Str, "b") Term(field=1, type=Str, "b")))"#, + r#"(Term(field=0, type=Str, "a") Term(field=0, type=Str, "b") Term(field=1, type=Str, "b"))"#, false, ); test_parse_query_to_logical_ast_helper( @@ -1705,6 +1705,21 @@ mod test { ); } + #[test] + pub fn test_parse_query_negative() { + test_parse_query_to_logical_ast_helper( + "title:b -title:a", + r#"(+Term(field=0, type=Str, "b") -Term(field=0, type=Str, "a"))"#, + true, + ); + + test_parse_query_to_logical_ast_helper( + "title:b -(-title:a -title:c)", + r#"(+Term(field=0, type=Str, "b") -(-Term(field=0, type=Str, "a") -Term(field=0, type=Str, "c")))"#, + true, + ); + } + #[test] pub fn test_query_parser_hyphen() { test_parse_query_to_logical_ast_helper( diff --git a/src/query/score_combiner.rs b/src/query/score_combiner.rs index 449badbadf..a49f8b104b 100644 --- a/src/query/score_combiner.rs +++ b/src/query/score_combiner.rs @@ -54,30 +54,6 @@ impl ScoreCombiner for SumCombiner { } } -/// Sums the score of different scorers and keeps the count -/// of scorers which matched. -#[derive(Default, Clone, Copy)] -pub struct SumWithCoordsCombiner { - num_fields: usize, - score: Score, -} - -impl ScoreCombiner for SumWithCoordsCombiner { - fn update(&mut self, scorer: &mut TScorer) { - self.score += scorer.score(); - self.num_fields += 1; - } - - fn clear(&mut self) { - self.score = 0.0; - self.num_fields = 0; - } - - fn score(&self) -> Score { - self.score - } -} - /// Take max score of different scorers /// and optionally sum it with other matches multiplied by `tie_breaker` #[derive(Default, Clone, Copy)]