Skip to content

Commit

Permalink
Enable '.' for nested field in text embedding processor (opensearch-p…
Browse files Browse the repository at this point in the history
…roject#811)

* Added nested structure for text embed processor mapping

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored and vibrantvarun committed Jul 9, 2024
1 parent 5358580 commit dfe3963
Show file tree
Hide file tree
Showing 9 changed files with 467 additions and 61 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x)
### Features
### Enhancements
* Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811))
### Bug Fixes
### Infrastructure
- Add BWC for batch ingestion ([#769](https://github.com/opensearch-project/neural-search/pull/769))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
Expand All @@ -21,6 +22,8 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -120,7 +123,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
if (inferenceList.size() == 0) {
handler.accept(ingestDocument, null);
Expand Down Expand Up @@ -228,7 +231,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
List<String> inferenceList = null;
try {
validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument());
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
Expand Down Expand Up @@ -276,15 +279,17 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List<String>
}

@VisibleForTesting
Map<String, Object> buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestDocument) {
Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
Map<String, Object> mapWithProcessorKeys = new LinkedHashMap<>();
for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
String originalKey = fieldMapEntry.getKey();
Object targetKey = fieldMapEntry.getValue();
Pair<String, Object> processedNestedKey = processNestedKey(fieldMapEntry);
String originalKey = processedNestedKey.getKey();
Object targetKey = processedNestedKey.getValue();

if (targetKey instanceof Map) {
Map<String, Object> treeRes = new LinkedHashMap<>();
buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes);
buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
} else {
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
Expand All @@ -293,20 +298,19 @@ Map<String, Object> buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestD
return mapWithProcessorKeys;
}

private void buildMapWithProcessorKeyAndOriginalValueForMapType(
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes
) {
if (processorKey == null || sourceAndMetadataMap == null) return;
@VisibleForTesting
void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> sourceAndMetadataMap, Map<String, Object> treeRes) {
if (Objects.isNull(processorKey) || Objects.isNull(sourceAndMetadataMap)) {
return;
}
if (processorKey instanceof Map) {
Map<String, Object> next = new LinkedHashMap<>();
if (sourceAndMetadataMap.get(parentKey) instanceof Map) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
Pair<String, Object> processedNestedKey = processNestedKey(nestedFieldMapEntry);
buildNestedMap(
processedNestedKey.getKey(),
processedNestedKey.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next
);
Expand All @@ -317,21 +321,46 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType(
List<Object> listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
map.put(nestedFieldMapEntry.getKey(), listOfStrings);
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
map,
next
);
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
}
}
treeRes.put(parentKey, next);
treeRes.merge(parentKey, next, (v1, v2) -> {
if (v1 instanceof Collection && v2 instanceof Collection) {
((Collection) v1).addAll((Collection) v2);
return v1;
} else if (v1 instanceof Map && v2 instanceof Map) {
((Map) v1).putAll((Map) v2);
return v1;
} else {
return v2;
}
});
} else {
String key = String.valueOf(processorKey);
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
}
}

/**
* Process the nested key, such as "a.b.c" to "a", "b.c"
* @param nestedFieldMapEntry
* @return A pair of the original key and the target key
*/
@VisibleForTesting
protected Pair<String, Object> processNestedKey(final Map.Entry<String, Object> nestedFieldMapEntry) {
String originalKey = nestedFieldMapEntry.getKey();
Object targetKey = nestedFieldMapEntry.getValue();
int nestedDotIndex = originalKey.indexOf('.');
if (nestedDotIndex != -1) {
Map<String, Object> newTargetKey = new LinkedHashMap<>();
newTargetKey.put(originalKey.substring(nestedDotIndex + 1), targetKey);
targetKey = newTargetKey;

originalKey = originalKey.substring(0, nestedDotIndex);
}
return new ImmutablePair<>(originalKey, targetKey);
}

private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.processor;

import com.google.common.collect.ImmutableList;
import org.apache.commons.lang.math.RandomUtils;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
Expand Down Expand Up @@ -58,4 +59,17 @@ protected List<List<Float>> createMockVectorResult() {
modelTensorList.add(number7);
return modelTensorList;
}

protected List<List<Float>> createRandomOneDimensionalMockVector(int numOfVectors, int vectorDimension, float min, float max) {
List<List<Float>> result = new ArrayList<>();
for (int i = 0; i < numOfVectors; i++) {
List<Float> numbers = new ArrayList<>();
for (int j = 0; j < vectorDimension; j++) {
Float nextFloat = RandomUtils.nextFloat() * (max - min) + min;
numbers.add(nextFloat);
}
result.add(numbers);
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,31 @@
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.apache.hc.core5.http.message.BasicHeader;
import org.apache.lucene.search.join.ScoreMode;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;

import com.google.common.collect.ImmutableList;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {

private static final String INDEX_NAME = "text_embedding_index";

private static final String PIPELINE_NAME = "pipeline-hybrid";
protected static final String QUERY_TEXT = "hello";
protected static final String LEVEL_1_FIELD = "nested_passages";
protected static final String LEVEL_2_FIELD = "level_2";
protected static final String LEVEL_3_FIELD_TEXT = "level_3_text";
protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding";
private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI()));
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
Expand Down Expand Up @@ -77,6 +87,66 @@ public void testTextEmbeddingProcessor_batch() throws Exception {
}
}

public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createTextEmbeddingIndex();
ingestDocument(INGEST_DOC3, "3");

Map<String, Object> sourceMap = (Map<String, Object>) getDocById(INDEX_NAME, "3").get("_source");
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
Map<String, Object> nestedPassages = (Map<String, Object>) sourceMap.get(LEVEL_1_FIELD);
assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassages.get(LEVEL_2_FIELD);
assertEquals(QUERY_TEXT, level2.get(LEVEL_3_FIELD_TEXT));
assertTrue(level2.containsKey(LEVEL_3_FIELD_EMBEDDING));
List<Double> embeddings = (List<Double>) level2.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}

NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_EMBEDDING,
QUERY_TEXT,
"",
modelId,
10,
null,
null,
null,
null,
null
);
QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD,
neuralQueryBuilderQuery,
ScoreMode.Total
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 1);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
assertNotNull(hits);

assertEquals(1.0, hits.get("max_score"));
List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(1, listOfHits.size());
Map<String, Object> hitsInner = listOfHits.get(0);
assertEquals("3", hitsInner.get("_id"));
assertEquals(1.0, hitsInner.get("_score"));
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndUploadModel(requestBody);
Expand Down
Loading

0 comments on commit dfe3963

Please sign in to comment.