Skip to content

Commit

Permalink
Add coverage for NeuralSearch class (#898)
Browse files Browse the repository at this point in the history
* Add coverage for NeuralSearch class

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Sep 19, 2024
1 parent f58d989 commit 481a347
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 18 deletions.
31 changes: 24 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
# Ignore Gradle project-specific cache directory
.gradle
# intellij files
.idea/
*.iml
*.ipr
*.iws
*.log
build-idea/
out/

# Ignore Gradle build output directory
build
.idea
.DS_Store
.gitattributes
# eclipse files
.classpath
.project
.settings

# gradle stuff
.gradle/
build/
bin/

# vscode stuff
.vscode/

# osx stuff
.DS_Store

# git stuff
.gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@
*/
package org.opensearch.neuralsearch.plugin;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.env.Environment;
import org.opensearch.indices.IndicesService;
import org.opensearch.ingest.IngestService;
Expand All @@ -21,22 +29,72 @@
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.plugins.SearchPlugin.SearchExtSpec;
import org.opensearch.search.pipeline.Processor.Factory;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchPipelineService;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;

public class NeuralSearchTests extends OpenSearchQueryTestCase {

private NeuralSearch plugin;

@Mock
private SearchPipelineService searchPipelineService;
private SearchPipelinePlugin.Parameters searchParameters;
@Mock
private IngestService ingestService;
private Processor.Parameters ingestParameters;
@Mock
private ClusterService clusterService;
@Mock
private ThreadPool threadPool;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);

plugin = new NeuralSearch();

when(searchPipelineService.getClusterService()).thenReturn(clusterService);
searchParameters = new SearchPipelinePlugin.Parameters(null, null, null, null, null, null, searchPipelineService, null, null, null);
ingestParameters = new Processor.Parameters(null, null, null, null, null, null, ingestService, null, null, null);
when(threadPool.executor(anyString())).thenReturn(OpenSearchExecutors.newDirectExecutorService());
}

public void testCreateComponents() {
// clientAccessor can not be null, and this is the only way to access it from this test
plugin.getProcessors(ingestParameters);
Collection<Object> components = plugin.createComponents(
null,
clusterService,
threadPool,
null,
null,
null,
null,
null,
null,
null,
null
);

assertEquals(1, components.size());
}

public void testQuerySpecs() {
NeuralSearch plugin = new NeuralSearch();
List<SearchPlugin.QuerySpec<?>> querySpecs = plugin.getQueries();

assertNotNull(querySpecs);
Expand All @@ -46,7 +104,6 @@ public void testQuerySpecs() {
}

public void testQueryPhaseSearcher() {
NeuralSearch plugin = new NeuralSearch();
Optional<QueryPhaseSearcher> queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled);
Expand All @@ -62,7 +119,6 @@ public void testQueryPhaseSearcher() {
}

public void testProcessors() {
NeuralSearch plugin = new NeuralSearch();
Settings settings = Settings.builder().build();
Environment environment = mock(Environment.class);
when(environment.settings()).thenReturn(settings);
Expand All @@ -84,10 +140,8 @@ public void testProcessors() {
}

public void testSearchPhaseResultsProcessors() {
NeuralSearch plugin = new NeuralSearch();
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor>> searchPhaseResultsProcessors = plugin
.getSearchPhaseResultsProcessors(parameters);
.getSearchPhaseResultsProcessors(searchParameters);
assertNotNull(searchPhaseResultsProcessors);
assertEquals(1, searchPhaseResultsProcessors.size());
assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor"));
Expand All @@ -97,19 +151,34 @@ public void testSearchPhaseResultsProcessors() {
assertTrue(scoringProcessor instanceof NormalizationProcessorFactory);
}

public void testGetSettings() {
List<Setting<?>> settings = plugin.getSettings();

assertEquals(2, settings.size());
}

public void testRequestProcessors() {
NeuralSearch plugin = new NeuralSearch();
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRequestProcessor>> processors = plugin.getRequestProcessors(
parameters
searchParameters
);
assertNotNull(processors);
assertNotNull(processors.get(NeuralQueryEnricherProcessor.TYPE));
assertNotNull(processors.get(NeuralSparseTwoPhaseProcessor.TYPE));
}

public void testResponseProcessors() {
Map<String, Factory<SearchResponseProcessor>> processors = plugin.getResponseProcessors(searchParameters);
assertNotNull(processors);
assertNotNull(processors.get(RerankProcessor.TYPE));
}

public void testSearchExts() {
List<SearchExtSpec<?>> searchExts = plugin.getSearchExts();

assertEquals(1, searchExts.size());
}

public void testExecutionBuilders() {
NeuralSearch plugin = new NeuralSearch();
Settings settings = Settings.builder().build();
Environment environment = mock(Environment.class);
when(environment.settings()).thenReturn(settings);
Expand All @@ -120,5 +189,4 @@ public void testExecutionBuilders() {
assertEquals("Unexpected number of executor builders are registered", 1, executorBuilders.size());
assertTrue(executorBuilders.get(0) instanceof FixedExecutorBuilder);
}

}

0 comments on commit 481a347

Please sign in to comment.