diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index ce6fcda788c15..abd89872f013f 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -159,7 +159,7 @@ final class DefaultSearchContext extends SearchContext { DefaultSearchContext(long id, ShardSearchRequest request, SearchShardTarget shardTarget, Engine.Searcher engineSearcher, ClusterService clusterService, IndexService indexService, IndexShard indexShard, BigArrays bigArrays, LongSupplier relativeTimeSupplier, TimeValue timeout, - FetchPhase fetchPhase, Version minNodeVersion) { + FetchPhase fetchPhase, Version minNodeVersion) throws IOException { this.id = id; this.request = request; this.fetchPhase = fetchPhase; diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index e4cf2d764d13d..58ba78e47a925 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -398,7 +398,7 @@ && canRewriteToMatchNone(rewritten.source()) }, listener::onFailure)); } - private void onMatchNoDocs(SearchRewriteContext rewriteContext, ActionListener listener) { + private void onMatchNoDocs(SearchRewriteContext rewriteContext, ActionListener listener) throws IOException { // creates a lightweight search context that we use to inform context listeners // before closing SearchContext searchContext = createSearchContext(rewriteContext, defaultSearchTimeout); @@ -609,7 +609,7 @@ private SearchContext findContext(long id, TransportRequest request) throws Sear } } - final SearchContext createAndPutContext(SearchRewriteContext rewriteContext) { + final SearchContext createAndPutContext(SearchRewriteContext rewriteContext) throws IOException { SearchContext context = createContext(rewriteContext); onNewContext(context); boolean success = false; @@ -644,7 +644,7 @@ private void onNewContext(SearchContext context) { } } - final SearchContext createContext(SearchRewriteContext rewriteContext) { + final SearchContext createContext(SearchRewriteContext rewriteContext) throws IOException { final DefaultSearchContext context = createSearchContext(rewriteContext, defaultSearchTimeout); try { if (rewriteContext.request != null && openScrollContexts.get() >= maxOpenScrollContext) { @@ -695,7 +695,7 @@ public DefaultSearchContext createSearchContext(ShardSearchRequest request, Time return createSearchContext(rewriteContext.wrapSearcher(), timeout); } - private DefaultSearchContext createSearchContext(SearchRewriteContext rewriteContext, TimeValue timeout) { + private DefaultSearchContext createSearchContext(SearchRewriteContext rewriteContext, TimeValue timeout) throws IOException { boolean success = false; try { final ShardSearchRequest request = rewriteContext.request; diff --git a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java index 85a0010fd58f9..d5f8c76a823d0 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java @@ -62,7 +62,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.Set; /** @@ -77,13 +79,26 @@ public class ContextIndexSearcher extends IndexSearcher { private AggregatedDfs aggregatedDfs; private QueryProfiler profiler; - private Runnable checkCancelled; + private MutableQueryTimeout cancellable; - public ContextIndexSearcher(IndexReader reader, Similarity similarity, QueryCache queryCache, QueryCachingPolicy queryCachingPolicy) { - super(reader); + public ContextIndexSearcher(IndexReader reader, Similarity similarity, + QueryCache queryCache, QueryCachingPolicy queryCachingPolicy) throws IOException { + this(reader, similarity, queryCache, queryCachingPolicy, new MutableQueryTimeout()); + } + + // TODO: Make the 2nd constructor private so that the IndexReader is always wrapped. + // Some issues must be fixed: + // - regarding tests deriving from AggregatorTestCase and more specifically the use of searchAndReduce and + // the ShardSearcher sub-searchers. + // - tests that use a MultiReader + public ContextIndexSearcher(IndexReader reader, Similarity similarity, + QueryCache queryCache, QueryCachingPolicy queryCachingPolicy, + MutableQueryTimeout cancellable) throws IOException { + super(cancellable != null ? new ExitableDirectoryReader((DirectoryReader) reader, cancellable) : reader); setSimilarity(similarity); setQueryCache(queryCache); setQueryCachingPolicy(queryCachingPolicy); + this.cancellable = cancellable != null ? cancellable : new MutableQueryTimeout(); } public void setProfiler(QueryProfiler profiler) { @@ -91,11 +106,19 @@ public void setProfiler(QueryProfiler profiler) { } /** - * Set a {@link Runnable} that will be run on a regular basis while - * collecting documents. + * Add a {@link Runnable} that will be run on a regular basis while accessing documents in the + * DirectoryReader but also while collecting them and check for query cancellation or timeout. + */ + public Runnable addQueryCancellation(Runnable action) { + return this.cancellable.add(action); + } + + /** + * Remove a {@link Runnable} that checks for query cancellation or timeout + * which is called while accessing documents in the DirectoryReader but also while collecting them. */ - public void setCheckCancelled(Runnable checkCancelled) { - this.checkCancelled = checkCancelled; + public void removeQueryCancellation(Runnable action) { + this.cancellable.remove(action); } public void setAggregatedDfs(AggregatedDfs aggregatedDfs) { @@ -139,12 +162,6 @@ public Weight createWeight(Query query, ScoreMode scoreMode, float boost) throws } } - private void checkCancelled() { - if (checkCancelled != null) { - checkCancelled.run(); - } - } - public void search(List leaves, Weight weight, CollectorManager manager, QuerySearchResult result, DocValueFormat[] formats, TotalHits totalHits) throws IOException { final List collectors = new ArrayList<>(leaves.size()); @@ -179,7 +196,7 @@ protected void search(List leaves, Weight weight, Collector c * the provided ctx. */ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collector) throws IOException { - checkCancelled(); + cancellable.checkCancelled(); weight = wrapWeight(weight); final LeafCollector leafCollector; try { @@ -207,7 +224,7 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto if (scorer != null) { try { intersectScorerAndBitSet(scorer, liveDocsBitSet, leafCollector, - checkCancelled == null ? () -> { } : checkCancelled); + this.cancellable.isEnabled() ? cancellable::checkCancelled: () -> {}); } catch (CollectionTerminatedException e) { // collection was terminated prematurely // continue with the following leaf @@ -217,7 +234,7 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto } private Weight wrapWeight(Weight weight) { - if (checkCancelled != null) { + if (cancellable.isEnabled()) { return new Weight(weight.getQuery()) { @Override public void extractTerms(Set terms) { @@ -243,7 +260,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { BulkScorer in = weight.bulkScorer(context); if (in != null) { - return new CancellableBulkScorer(in, checkCancelled); + return new CancellableBulkScorer(in, cancellable::checkCancelled); } else { return null; } @@ -319,4 +336,33 @@ public DirectoryReader getDirectoryReader() { assert reader instanceof DirectoryReader : "expected an instance of DirectoryReader, got " + reader.getClass(); return (DirectoryReader) reader; } + + private static class MutableQueryTimeout implements ExitableDirectoryReader.QueryCancellation { + + private final Set runnables = new HashSet<>(); + + private Runnable add(Runnable action) { + Objects.requireNonNull(action, "cancellation runnable should not be null"); + if (runnables.add(action) == false) { + throw new IllegalArgumentException("Cancellation runnable already added"); + } + return action; + } + + private void remove(Runnable action) { + runnables.remove(action); + } + + @Override + public void checkCancelled() { + for (Runnable timeout : runnables) { + timeout.run(); + } + } + + @Override + public boolean isEnabled() { + return runnables.isEmpty() == false; + } + } } diff --git a/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java new file mode 100644 index 0000000000000..b66532bbd093e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java @@ -0,0 +1,289 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.internal; + +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FilterDirectoryReader; +import org.apache.lucene.index.FilterLeafReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.suggest.document.CompletionTerms; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.automaton.CompiledAutomaton; + +import java.io.IOException; + +/** + * Wraps an {@link IndexReader} with a {@link QueryCancellation} + * which checks for cancelled or timed-out query. + */ +class ExitableDirectoryReader extends FilterDirectoryReader { + + /** + * Used to check if query cancellation is actually enabled + * and if so use it to check if the query is cancelled or timed-out. + */ + interface QueryCancellation { + + /** + * Used to prevent unnecessary checks for cancellation + * @return true if query cancellation is enabled + */ + boolean isEnabled(); + + /** + * Call to check if the query is cancelled or timed-out. + * If so a {@link RuntimeException} is thrown + */ + void checkCancelled(); + } + + ExitableDirectoryReader(DirectoryReader in, QueryCancellation queryCancellation) throws IOException { + super(in, new SubReaderWrapper() { + @Override + public LeafReader wrap(LeafReader reader) { + return new ExitableLeafReader(reader, queryCancellation); + } + }); + } + + @Override + protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) { + throw new UnsupportedOperationException("doWrapDirectoryReader() should never be invoked"); + } + + @Override + public CacheHelper getReaderCacheHelper() { + return in.getReaderCacheHelper(); + } + /** + * Wraps a {@link FilterLeafReader} with a {@link QueryCancellation}. + */ + static class ExitableLeafReader extends FilterLeafReader { + + private final QueryCancellation queryCancellation; + + private ExitableLeafReader(LeafReader leafReader, QueryCancellation queryCancellation) { + super(leafReader); + this.queryCancellation = queryCancellation; + } + + @Override + public PointValues getPointValues(String field) throws IOException { + final PointValues pointValues = in.getPointValues(field); + if (pointValues == null) { + return null; + } + return queryCancellation.isEnabled() ? new ExitablePointValues(pointValues, queryCancellation) : pointValues; + } + + @Override + public Terms terms(String field) throws IOException { + Terms terms = in.terms(field); + if (terms == null) { + return null; + } + // If we have a suggest CompletionQuery then the CompletionWeight#bulkScorer() will check that + // the terms are instanceof CompletionTerms (not generic FilterTerms) and will throw an exception + // if that's not the case. + return (queryCancellation.isEnabled() && terms instanceof CompletionTerms == false) ? + new ExitableTerms(terms, queryCancellation) : terms; + } + + @Override + public CacheHelper getCoreCacheHelper() { + return in.getCoreCacheHelper(); + } + + @Override + public CacheHelper getReaderCacheHelper() { + return in.getReaderCacheHelper(); + } + } + + /** + * Wrapper class for {@link FilterLeafReader.FilterTerms} that check for query cancellation or timeout. + */ + static class ExitableTerms extends FilterLeafReader.FilterTerms { + + private final QueryCancellation queryCancellation; + + private ExitableTerms(Terms terms, QueryCancellation queryCancellation) { + super(terms); + this.queryCancellation = queryCancellation; + } + + @Override + public TermsEnum intersect(CompiledAutomaton compiled, BytesRef startTerm) throws IOException { + return new ExitableTermsEnum(in.intersect(compiled, startTerm), queryCancellation); + } + + @Override + public TermsEnum iterator() throws IOException { + return new ExitableTermsEnum(in.iterator(), queryCancellation); + } + } + + /** + * Wrapper class for {@link FilterLeafReader.FilterTermsEnum} that is used by {@link ExitableTerms} for + * implementing an exitable enumeration of terms. + */ + private static class ExitableTermsEnum extends FilterLeafReader.FilterTermsEnum { + + private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = (1 << 4) - 1; // 15 + + private int calls; + private final QueryCancellation queryCancellation; + + private ExitableTermsEnum(TermsEnum termsEnum, QueryCancellation queryCancellation) { + super(termsEnum); + this.queryCancellation = queryCancellation; + this.queryCancellation.checkCancelled(); + } + + private void checkAndThrowWithSampling() { + if ((calls++ & MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) { + queryCancellation.checkCancelled(); + } + } + + @Override + public BytesRef next() throws IOException { + checkAndThrowWithSampling(); + return in.next(); + } + } + + /** + * Wrapper class for {@link PointValues} that checks for query cancellation or timeout. + */ + static class ExitablePointValues extends PointValues { + + private final PointValues in; + private final QueryCancellation queryCancellation; + + private ExitablePointValues(PointValues in, QueryCancellation queryCancellation) { + this.in = in; + this.queryCancellation = queryCancellation; + this.queryCancellation.checkCancelled(); + } + + @Override + public void intersect(IntersectVisitor visitor) throws IOException { + queryCancellation.checkCancelled(); + in.intersect(new ExitableIntersectVisitor(visitor, queryCancellation)); + } + + @Override + public long estimatePointCount(IntersectVisitor visitor) { + queryCancellation.checkCancelled(); + return in.estimatePointCount(visitor); + } + + @Override + public byte[] getMinPackedValue() throws IOException { + queryCancellation.checkCancelled(); + return in.getMinPackedValue(); + } + + @Override + public byte[] getMaxPackedValue() throws IOException { + queryCancellation.checkCancelled(); + return in.getMaxPackedValue(); + } + + @Override + public int getNumDimensions() throws IOException { + queryCancellation.checkCancelled(); + return in.getNumDimensions(); + } + + @Override + public int getNumIndexDimensions() throws IOException { + queryCancellation.checkCancelled(); + return in.getNumIndexDimensions(); + } + + @Override + public int getBytesPerDimension() throws IOException { + queryCancellation.checkCancelled(); + return in.getBytesPerDimension(); + } + + @Override + public long size() { + queryCancellation.checkCancelled(); + return in.size(); + } + + @Override + public int getDocCount() { + queryCancellation.checkCancelled(); + return in.getDocCount(); + } + } + + private static class ExitableIntersectVisitor implements PointValues.IntersectVisitor { + + private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = (1 << 4) - 1; // 15 + + private final PointValues.IntersectVisitor in; + private final QueryCancellation queryCancellation; + private int calls; + + private ExitableIntersectVisitor(PointValues.IntersectVisitor in, QueryCancellation queryCancellation) { + this.in = in; + this.queryCancellation = queryCancellation; + } + + private void checkAndThrowWithSampling() { + if ((calls++ & MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) { + queryCancellation.checkCancelled(); + } + } + + @Override + public void visit(int docID) throws IOException { + checkAndThrowWithSampling(); + in.visit(docID); + } + + @Override + public void visit(int docID, byte[] packedValue) throws IOException { + checkAndThrowWithSampling(); + in.visit(docID, packedValue); + } + + @Override + public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + queryCancellation.checkCancelled(); + return in.compare(minPackedValue, maxPackedValue); + } + + @Override + public void grow(int count) { + queryCancellation.checkCancelled(); + in.grow(count); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 7bafc310c18bd..7f91d778105dd 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -45,17 +45,17 @@ import org.apache.lucene.search.TopFieldCollector; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; import org.apache.lucene.util.FutureArrays; import org.elasticsearch.action.search.SearchShardTask; -import org.apache.lucene.search.Weight; import org.elasticsearch.common.Booleans; import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.QueueResizingEsThreadPoolExecutor; import org.elasticsearch.index.IndexSortConfig; -import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.DateFieldMapper.DateFieldType; +import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchPhase; import org.elasticsearch.search.SearchService; @@ -254,60 +254,54 @@ static boolean executeInternal(SearchContext searchContext) throws QueryPhaseExe final long startTime = searchContext.getRelativeTimeInMillis(); final long timeout = searchContext.timeout().millis(); final long maxTime = startTime + timeout; - timeoutRunnable = () -> { + timeoutRunnable = searcher.addQueryCancellation(() -> { final long time = searchContext.getRelativeTimeInMillis(); if (time > maxTime) { throw new TimeExceededException(); } - }; + }); } else { timeoutRunnable = null; } - final Runnable cancellationRunnable; if (searchContext.lowLevelCancellation()) { SearchShardTask task = searchContext.getTask(); - cancellationRunnable = () -> { if (task.isCancelled()) throw new TaskCancelledException("cancelled"); }; - } else { - cancellationRunnable = null; - } - - final Runnable checkCancelled; - if (timeoutRunnable != null && cancellationRunnable != null) { - checkCancelled = () -> { - timeoutRunnable.run(); - cancellationRunnable.run(); - }; - } else if (timeoutRunnable != null) { - checkCancelled = timeoutRunnable; - } else if (cancellationRunnable != null) { - checkCancelled = cancellationRunnable; - } else { - checkCancelled = null; + searcher.addQueryCancellation(() -> { + if (task.isCancelled()) { + throw new TaskCancelledException("cancelled"); + } + }); } - searcher.setCheckCancelled(checkCancelled); - boolean shouldRescore; - // if we are optimizing sort and there are no other collectors - if (sortAndFormatsForRewrittenNumericSort != null && collectors.size() == 0 && searchContext.getProfilers() == null) { - shouldRescore = searchWithCollectorManager(searchContext, searcher, query, leafSorter, timeoutSet); - } else { - shouldRescore = searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, timeoutSet); - } + try { + boolean shouldRescore; + // if we are optimizing sort and there are no other collectors + if (sortAndFormatsForRewrittenNumericSort!=null && collectors.size()==0 && searchContext.getProfilers()==null) { + shouldRescore = searchWithCollectorManager(searchContext, searcher, query, leafSorter, timeoutSet); + } else { + shouldRescore = searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, timeoutSet); + } - // if we rewrote numeric long or date sort, restore fieldDocs based on the original sort - if (sortAndFormatsForRewrittenNumericSort != null) { - searchContext.sort(sortAndFormatsForRewrittenNumericSort); // restore SortAndFormats - restoreTopFieldDocs(queryResult, sortAndFormatsForRewrittenNumericSort); - } + // if we rewrote numeric long or date sort, restore fieldDocs based on the original sort + if (sortAndFormatsForRewrittenNumericSort!=null) { + searchContext.sort(sortAndFormatsForRewrittenNumericSort); // restore SortAndFormats + restoreTopFieldDocs(queryResult, sortAndFormatsForRewrittenNumericSort); + } - ExecutorService executor = searchContext.indexShard().getThreadPool().executor(ThreadPool.Names.SEARCH); - if (executor instanceof QueueResizingEsThreadPoolExecutor) { - QueueResizingEsThreadPoolExecutor rExecutor = (QueueResizingEsThreadPoolExecutor) executor; - queryResult.nodeQueueSize(rExecutor.getCurrentQueueSize()); - queryResult.serviceTimeEWMA((long) rExecutor.getTaskExecutionEWMA()); + ExecutorService executor = searchContext.indexShard().getThreadPool().executor(ThreadPool.Names.SEARCH); + if (executor instanceof QueueResizingEsThreadPoolExecutor) { + QueueResizingEsThreadPoolExecutor rExecutor = (QueueResizingEsThreadPoolExecutor) executor; + queryResult.nodeQueueSize(rExecutor.getCurrentQueueSize()); + queryResult.serviceTimeEWMA((long) rExecutor.getTaskExecutionEWMA()); + } + return shouldRescore; + } finally { + // Search phase has finished, no longer need to check for timeout + // otherwise aggregation phase might get cancelled. + if (timeoutRunnable!=null) { + searcher.removeQueryCancellation(timeoutRunnable); + } } - return shouldRescore; } catch (Exception e) { throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute main query", e); } diff --git a/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java b/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java index cdbe140b0f83c..34b678135c3a3 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java @@ -20,16 +20,22 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.StringField; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.PointValues; import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.store.Directory; -import org.elasticsearch.core.internal.io.IOUtils; import org.apache.lucene.util.TestUtil; +import org.apache.lucene.util.automaton.CompiledAutomaton; +import org.apache.lucene.util.automaton.RegExp; +import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.search.internal.ContextIndexSearcher; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.ESTestCase; @@ -43,8 +49,11 @@ public class SearchCancellationTests extends ESTestCase { - static Directory dir; - static IndexReader reader; + private static final String STRING_FIELD_NAME = "foo"; + private static final String POINT_FIELD_NAME = "point"; + + private static Directory dir; + private static IndexReader reader; @BeforeClass public static void setup() throws IOException { @@ -61,9 +70,14 @@ public static void setup() throws IOException { } private static void indexRandomDocuments(RandomIndexWriter w, int numDocs) throws IOException { - for (int i = 0; i < numDocs; ++i) { + for (int i = 1; i <= numDocs; ++i) { Document doc = new Document(); - doc.add(new StringField("foo", "bar", Field.Store.NO)); + StringBuilder sb = new StringBuilder(); + for (int j = 0; j < i; j++) { + sb.append('a'); + } + doc.add(new StringField(STRING_FIELD_NAME, sb.toString(), Field.Store.NO)); + doc.add(new IntPoint(POINT_FIELD_NAME, i, i + 1)); w.addDocument(doc); } } @@ -75,21 +89,97 @@ public static void cleanup() throws IOException { reader = null; } + public void testAddingCancellationActions() throws IOException { + ContextIndexSearcher searcher = new ContextIndexSearcher(reader, + IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()); + NullPointerException npe = expectThrows(NullPointerException.class, () -> searcher.addQueryCancellation(null)); + assertEquals("cancellation runnable should not be null", npe.getMessage()); + + Runnable r = () -> {}; + searcher.addQueryCancellation(r); + IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> searcher.addQueryCancellation(r)); + assertEquals("Cancellation runnable already added", iae.getMessage()); + } + public void testCancellableCollector() throws IOException { - TotalHitCountCollector collector = new TotalHitCountCollector(); - AtomicBoolean cancelled = new AtomicBoolean(); + TotalHitCountCollector collector1 = new TotalHitCountCollector(); + Runnable cancellation = () -> { throw new TaskCancelledException("cancelled"); }; ContextIndexSearcher searcher = new ContextIndexSearcher(reader, IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()); - searcher.setCheckCancelled(() -> { + + searcher.search(new MatchAllDocsQuery(), collector1); + assertThat(collector1.getTotalHits(), equalTo(reader.numDocs())); + + searcher.addQueryCancellation(cancellation); + expectThrows(TaskCancelledException.class, + () -> searcher.search(new MatchAllDocsQuery(), collector1)); + + searcher.removeQueryCancellation(cancellation); + TotalHitCountCollector collector2 = new TotalHitCountCollector(); + searcher.search(new MatchAllDocsQuery(), collector2); + assertThat(collector2.getTotalHits(), equalTo(reader.numDocs())); + } + + public void testCancellableDirectoryReader() throws IOException { + AtomicBoolean cancelled = new AtomicBoolean(true); + Runnable cancellation = () -> { if (cancelled.get()) { throw new TaskCancelledException("cancelled"); - } - }); - searcher.search(new MatchAllDocsQuery(), collector); - assertThat(collector.getTotalHits(), equalTo(reader.numDocs())); - cancelled.set(true); + }}; + ContextIndexSearcher searcher = new ContextIndexSearcher(reader, + IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()); + searcher.addQueryCancellation(cancellation); + CompiledAutomaton automaton = new CompiledAutomaton(new RegExp("a.*").toAutomaton()); + + expectThrows(TaskCancelledException.class, + () -> searcher.getIndexReader().leaves().get(0).reader().terms(STRING_FIELD_NAME).iterator()); + expectThrows(TaskCancelledException.class, + () -> searcher.getIndexReader().leaves().get(0).reader().terms(STRING_FIELD_NAME).intersect(automaton, null)); expectThrows(TaskCancelledException.class, - () -> searcher.search(new MatchAllDocsQuery(), collector)); + () -> searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME)); + expectThrows(TaskCancelledException.class, + () -> searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME)); + + cancelled.set(false); // Avoid exception during construction of the wrapper objects + Terms terms = searcher.getIndexReader().leaves().get(0).reader().terms(STRING_FIELD_NAME); + TermsEnum termsIterator = terms.iterator(); + TermsEnum termsIntersect = terms.intersect(automaton, null); + PointValues pointValues1 = searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME); + cancelled.set(true); + expectThrows(TaskCancelledException.class, termsIterator::next); + expectThrows(TaskCancelledException.class, termsIntersect::next); + expectThrows(TaskCancelledException.class, pointValues1::getDocCount); + expectThrows(TaskCancelledException.class, pointValues1::getNumIndexDimensions); + expectThrows(TaskCancelledException.class, () -> pointValues1.intersect(new PointValuesIntersectVisitor())); + + cancelled.set(false); // Avoid exception during construction of the wrapper objects + // Re-initialize objects so that we reset the `calls` counter used to avoid cancellation check + // on every iteration and assure that cancellation would normally happen if we hadn't removed the + // cancellation runnable. + termsIterator = terms.iterator(); + termsIntersect = terms.intersect(automaton, null); + PointValues pointValues2 = searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME); + cancelled.set(true); + searcher.removeQueryCancellation(cancellation); + termsIterator.next(); + termsIntersect.next(); + pointValues2.getDocCount(); + pointValues2.getNumIndexDimensions(); + pointValues2.intersect(new PointValuesIntersectVisitor()); } + private static class PointValuesIntersectVisitor implements PointValues.IntersectVisitor { + @Override + public void visit(int docID) { + } + + @Override + public void visit(int docID, byte[] packedValue) { + } + + @Override + public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return PointValues.Relation.CELL_CROSSES_QUERY; + } + } } diff --git a/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java b/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java index 186436e4bad3f..61a75a0716fd2 100644 --- a/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java +++ b/server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java @@ -22,6 +22,7 @@ import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.StringField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.FilterDirectoryReader; @@ -76,6 +77,9 @@ import java.util.Set; import static org.elasticsearch.search.internal.ContextIndexSearcher.intersectScorerAndBitSet; +import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitableLeafReader; +import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitablePointValues; +import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitableTerms; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -191,6 +195,8 @@ public void doTestContextIndexSearcher(boolean sparse, boolean deletions) throws doc.add(fooField); StringField deleteField = new StringField("delete", "no", Field.Store.NO); doc.add(deleteField); + IntPoint pointField = new IntPoint("point", 1, 2); + doc.add(pointField); w.addDocument(doc); if (deletions) { // add a document that matches foo:bar but will be deleted @@ -235,7 +241,19 @@ public void onRemoval(ShardId shardId, Accountable accountable) { ContextIndexSearcher searcher = new ContextIndexSearcher(filteredReader, IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()); - searcher.setCheckCancelled(() -> {}); + + // Assert wrapping + assertEquals(ExitableDirectoryReader.class, searcher.getIndexReader().getClass()); + for (LeafReaderContext lrc : searcher.getIndexReader().leaves()) { + assertEquals(ExitableLeafReader.class, lrc.reader().getClass()); + assertNotEquals(ExitableTerms.class, lrc.reader().terms("foo").getClass()); + assertNotEquals(ExitablePointValues.class, lrc.reader().getPointValues("point").getClass()); + } + searcher.addQueryCancellation(() -> {}); + for (LeafReaderContext lrc : searcher.getIndexReader().leaves()) { + assertEquals(ExitableTerms.class, lrc.reader().terms("foo").getClass()); + assertEquals(ExitablePointValues.class, lrc.reader().getPointValues("point").getClass()); + } // Searching a non-existing term will trigger a null scorer assertEquals(0, searcher.count(new TermQuery(new Term("non_existing_field", "non_existing_value")))); diff --git a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index 95ed82aa855b3..4b70153261d3f 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -837,12 +837,12 @@ public void testMinScore() throws Exception { } - private static ContextIndexSearcher newContextSearcher(IndexReader reader) { + private static ContextIndexSearcher newContextSearcher(IndexReader reader) throws IOException { return new ContextIndexSearcher(reader, IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()); } - private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexReader reader, int size) { + private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexReader reader, int size) throws IOException { return new ContextIndexSearcher(reader, IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()) { @@ -855,7 +855,7 @@ public void search(List leaves, Weight weight, Collector coll } // used to check that numeric long or date sort optimization was run - private static ContextIndexSearcher newOptimizedContextSearcher(IndexReader reader, int queryType) { + private static ContextIndexSearcher newOptimizedContextSearcher(IndexReader reader, int queryType) throws IOException { return new ContextIndexSearcher(reader, IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()) { diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 73a6f900fc279..97992d04205f5 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -180,7 +180,7 @@ protected SearchContext createSearchContext(IndexSearcher indexSearcher, IndexSettings indexSettings, Query query, MultiBucketConsumer bucketConsumer, - MappedFieldType... fieldTypes) { + MappedFieldType... fieldTypes) throws IOException { return createSearchContext(indexSearcher, indexSettings, query, bucketConsumer, new NoneCircuitBreakerService(), fieldTypes); } @@ -189,7 +189,7 @@ protected SearchContext createSearchContext(IndexSearcher indexSearcher, Query query, MultiBucketConsumer bucketConsumer, CircuitBreakerService circuitBreakerService, - MappedFieldType... fieldTypes) { + MappedFieldType... fieldTypes) throws IOException { QueryCache queryCache = new DisabledQueryCache(indexSettings); QueryCachingPolicy queryCachingPolicy = new QueryCachingPolicy() { @Override @@ -203,7 +203,7 @@ public boolean shouldCache(Query query) { } }; ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(indexSearcher.getIndexReader(), - indexSearcher.getSimilarity(), queryCache, queryCachingPolicy); + indexSearcher.getSimilarity(), queryCache, queryCachingPolicy, null); SearchContext searchContext = mock(SearchContext.class); when(searchContext.numberOfShards()).thenReturn(1);