diff --git a/.buildkite/scripts/gradle-cache-validation.sh b/.buildkite/scripts/gradle-cache-validation.sh
new file mode 100755
index 0000000000000..fbb957bc3b26b
--- /dev/null
+++ b/.buildkite/scripts/gradle-cache-validation.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+set -euo pipefail
+
+VALIDATION_SCRIPTS_VERSION=2.5.1
+GRADLE_ENTERPRISE_ACCESS_KEY=$(vault kv get -field=value secret/ci/elastic-elasticsearch/gradle-enterprise-api-key)
+export GRADLE_ENTERPRISE_ACCESS_KEY
+
+curl -s -L -O https://github.com/gradle/gradle-enterprise-build-validation-scripts/releases/download/v$VALIDATION_SCRIPTS_VERSION/gradle-enterprise-gradle-build-validation-$VALIDATION_SCRIPTS_VERSION.zip && unzip -q -o gradle-enterprise-gradle-build-validation-$VALIDATION_SCRIPTS_VERSION.zip
+
+# Create a temporary file
+tmpOutputFile=$(mktemp)
+trap "rm $tmpOutputFile" EXIT
+
+# Temporarily allow failure so the experiment's exit code can be captured and reported below
+set +e
+gradle-enterprise-gradle-build-validation/03-validate-local-build-caching-different-locations.sh -r https://github.com/elastic/elasticsearch.git -b $BUILDKITE_BRANCH --gradle-enterprise-server https://gradle-enterprise.elastic.co -t precommit --fail-if-not-fully-cacheable | tee $tmpOutputFile
+
+# Capture the return value
+retval=$?
+set -e
+
+# Now read the content from the temporary file into a variable
+perfOutput=$(cat $tmpOutputFile | sed -n '/Performance Characteristics/,/See https:\/\/gradle.com\/bvs\/main\/Gradle.md#performance-characteristics for details./p' | sed '$d' | sed 's/\x1b\[[0-9;]*m//g')
+investigationOutput=$(cat $tmpOutputFile | sed -n '/Investigation Quick Links/,$p' | sed 's/\x1b\[[0-9;]*m//g')
+
+# Initialize HTML output variable
+summaryHtml="<h4>Performance Characteristics</h4>"
+summaryHtml+="<ul>"
+
+# Turn each "label: value" line of the performance output into a list item
+while IFS= read -r line; do
+  if [[ $line == *":"* ]]; then
+    IFS=':' read -r label value <<< "$line"
+    summaryHtml+="<li><strong>$(echo "$label" | xargs):</strong> $(echo "$value" | xargs)</li>"
+  fi
+done <<< "$perfOutput"
+
+summaryHtml+="</ul>"
+
+# generate html for links
+summaryHtml+="<h4>Investigation Links</h4>"
+summaryHtml+="<ul>"
+
+# Turn each link in the investigation output into a list item
+while IFS= read -r line; do
+  if [[ $line =~ http.* ]]; then
+    summaryHtml+="<li><a href=\"$line\">$line</a></li>"
+  fi
+done <<< "$investigationOutput"
+
+summaryHtml+="</ul>"
+
+cat << EOF | buildkite-agent annotate --context "ctx-validation-summary" --style "info"
+$summaryHtml
+EOF
+
+# Check if the command was successful
+if [ $retval -eq 0 ]; then
+  echo "Experiment completed successfully"
+elif [ $retval -eq 1 ]; then
+  echo "An invalid input was provided while attempting to run the experiment"
+elif [ $retval -eq 2 ]; then
+  echo "One of the builds that is part of the experiment failed"
+elif [ $retval -eq 3 ]; then
+  echo "The build was not fully cacheable for the given task graph"
+else
+  echo "An unclassified, fatal error happened while running the experiment"
+fi
+
+exit $retval
diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/DependencyLicensesTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/DependencyLicensesTask.java
index 0099a4616f829..07817fdaed1fe 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/DependencyLicensesTask.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/DependencyLicensesTask.java
@@ -23,11 +23,14 @@
 import org.gradle.api.provider.Property;
 import org.gradle.api.provider.Provider;
 import org.gradle.api.specs.Spec;
+import org.gradle.api.tasks.CacheableTask;
 import org.gradle.api.tasks.Input;
 import org.gradle.api.tasks.InputDirectory;
 import org.gradle.api.tasks.InputFiles;
 import org.gradle.api.tasks.Optional;
 import org.gradle.api.tasks.OutputDirectory;
+import org.gradle.api.tasks.PathSensitive;
+import org.gradle.api.tasks.PathSensitivity;
 import org.gradle.api.tasks.TaskAction;
 
 import java.io.File;
@@ -89,6 +92,7 @@
  * for the dependency. This artifact will be redistributed by us with the release to
  * comply with the license terms.
  */
+@CacheableTask
 public abstract class DependencyLicensesTask extends DefaultTask {
 
     private final Pattern regex = Pattern.compile("-v?\\d+.*");
@@ -149,6 +153,7 @@ public DependencyLicensesTask(ObjectFactory objects, ProjectLayout projectLayout
     }
 
     @InputFiles
+    @PathSensitive(PathSensitivity.NAME_ONLY)
     public FileCollection getDependencies() {
         return dependencies;
     }
@@ -159,6 +164,7 @@ public void setDependencies(FileCollection dependencies) {
 
     @Optional
     @InputDirectory
+    @PathSensitive(PathSensitivity.RELATIVE)
     public File getLicensesDir() {
         File asFile = licensesDir.get().getAsFile();
         if (asFile.exists()) {
diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/SplitPackagesAuditTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/SplitPackagesAuditTask.java
index ec279589a6bed..f75adbe640297 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/SplitPackagesAuditTask.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/precommit/SplitPackagesAuditTask.java
@@ -20,6 +20,7 @@
 import org.gradle.api.provider.MapProperty;
 import org.gradle.api.provider.Property;
 import org.gradle.api.provider.SetProperty;
+import org.gradle.api.tasks.CacheableTask;
 import org.gradle.api.tasks.CompileClasspath;
 import org.gradle.api.tasks.Input;
 import org.gradle.api.tasks.InputFiles;
@@ -56,6 +57,7 @@
 /**
  * Checks for split packages with dependencies. These are not allowed in a future modularized world.
 */
+@CacheableTask
 public class SplitPackagesAuditTask extends DefaultTask {
 
     private static final Logger LOGGER = Logging.getLogger(SplitPackagesAuditTask.class);
diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/AbstractVersionsTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/AbstractVersionsTask.java
index 0ab3a9b917d65..ad39faad1bc85 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/AbstractVersionsTask.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/AbstractVersionsTask.java
@@ -8,19 +8,119 @@
 
 package org.elasticsearch.gradle.internal.release;
 
+import com.github.javaparser.GeneratedJavaParserConstants;
+import com.github.javaparser.ast.CompilationUnit;
+import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
+import com.github.javaparser.ast.body.FieldDeclaration;
+import com.github.javaparser.ast.expr.IntegerLiteralExpr;
+import com.github.javaparser.ast.observer.ObservableProperty;
+import com.github.javaparser.printer.ConcreteSyntaxModel;
+import com.github.javaparser.printer.concretesyntaxmodel.CsmElement;
+import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter;
+
 import org.gradle.api.DefaultTask;
+import org.gradle.api.logging.Logger;
+import org.gradle.api.logging.Logging;
 import org.gradle.initialization.layout.BuildLayout;
 
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.nio.file.Files;
 import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.List;
+import java.util.Map;
+import java.util.OptionalInt;
+import java.util.stream.Collectors;
+
+import static com.github.javaparser.ast.observer.ObservableProperty.TYPE_PARAMETERS;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmConditional.Condition.FLAG;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.block;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.child;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.comma;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.comment;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.conditional;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.list;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.newline;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.none;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.sequence;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.space;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.string;
+import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.token;
 
 public abstract class AbstractVersionsTask extends DefaultTask {
 
+    static {
+        replaceDefaultJavaParserClassCsm();
+    }
+
+    /*
+     * The default JavaParser CSM which it uses to format any new declarations added to a class
+     * inserts two newlines after each declaration. Our version classes only have one newline.
+     * In order to get javaparser lexical printer to use our format, we have to completely replace
+     * the statically declared CSM pattern using hacky reflection
+     * to access the static map where these are stored, and insert a replacement that is identical
+     * apart from only one newline at the end of each member declaration, rather than two.
+     */
+    private static void replaceDefaultJavaParserClassCsm() {
+        try {
+            Field classCsms = ConcreteSyntaxModel.class.getDeclaredField("concreteSyntaxModelByClass");
+            classCsms.setAccessible(true);
+            @SuppressWarnings({ "unchecked", "rawtypes" })
+            Map<Class, CsmElement> csms = (Map) classCsms.get(null);
+
+            // copied from the static initializer in ConcreteSyntaxModel
+            csms.put(
+                ClassOrInterfaceDeclaration.class,
+                sequence(
+                    comment(),
+                    list(ObservableProperty.ANNOTATIONS, newline(), none(), newline()),
+                    list(ObservableProperty.MODIFIERS, space(), none(), space()),
+                    conditional(
+                        ObservableProperty.INTERFACE,
+                        FLAG,
+                        token(GeneratedJavaParserConstants.INTERFACE),
+                        token(GeneratedJavaParserConstants.CLASS)
+                    ),
+                    space(),
+                    child(ObservableProperty.NAME),
+                    list(
+                        TYPE_PARAMETERS,
+                        sequence(comma(), space()),
+                        string(GeneratedJavaParserConstants.LT),
+                        string(GeneratedJavaParserConstants.GT)
+                    ),
+                    list(
+                        ObservableProperty.EXTENDED_TYPES,
+                        sequence(string(GeneratedJavaParserConstants.COMMA), space()),
+                        sequence(space(), token(GeneratedJavaParserConstants.EXTENDS), space()),
+                        none()
+                    ),
+                    list(
+                        ObservableProperty.IMPLEMENTED_TYPES,
+                        sequence(string(GeneratedJavaParserConstants.COMMA), space()),
+                        sequence(space(), token(GeneratedJavaParserConstants.IMPLEMENTS), space()),
+                        none()
+                    ),
+                    space(),
+                    block(sequence(newline(), list(ObservableProperty.MEMBERS, sequence(newline()/*, newline()*/), newline(), newline())))
+                )
+            );
+        } catch (ReflectiveOperationException e) {
+            throw new AssertionError(e);
+        }
+    }
+
+    private static final Logger LOGGER = Logging.getLogger(AbstractVersionsTask.class);
+
     static final String TRANSPORT_VERSION_TYPE = "TransportVersion";
     static final String INDEX_VERSION_TYPE = "IndexVersion";
 
     static final String SERVER_MODULE_PATH = "server/src/main/java/";
-    static final String TRANSPORT_VERSION_FILE_PATH = SERVER_MODULE_PATH + "org/elasticsearch/TransportVersions.java";
-    static final String INDEX_VERSION_FILE_PATH = SERVER_MODULE_PATH + "org/elasticsearch/index/IndexVersions.java";
+
+    static final String VERSION_FILE_PATH = SERVER_MODULE_PATH + "org/elasticsearch/Version.java";
+    static final String TRANSPORT_VERSIONS_FILE_PATH = SERVER_MODULE_PATH + "org/elasticsearch/TransportVersions.java";
+    static final String INDEX_VERSIONS_FILE_PATH = SERVER_MODULE_PATH + "org/elasticsearch/index/IndexVersions.java";
 
     static final String SERVER_RESOURCES_PATH = "server/src/main/resources/";
     static final String TRANSPORT_VERSIONS_RECORD = SERVER_RESOURCES_PATH + "org/elasticsearch/TransportVersions.csv";
@@ -32,4 +132,34 @@ protected AbstractVersionsTask(BuildLayout layout) {
         rootDir = layout.getRootDirectory().toPath();
     }
 
+    static Map<String, Integer> splitVersionIds(List<String> version) {
+        return version.stream().map(l -> {
+            var split = l.split(":");
+            if (split.length != 2) throw new IllegalArgumentException("Invalid tag format [" + l + "]");
+            return split;
+        }).collect(Collectors.toMap(l -> l[0], l -> Integer.parseInt(l[1])));
+    }
+
+    static OptionalInt findSingleIntegerExpr(FieldDeclaration field) {
+        var ints = field.findAll(IntegerLiteralExpr.class);
+        switch (ints.size()) {
+            case 0 -> {
+                return OptionalInt.empty();
+            }
+            case 1 -> {
+                return OptionalInt.of(ints.get(0).asNumber().intValue());
+            }
+            default -> {
+                LOGGER.warn("Multiple integers found in version field declaration [{}]", field); // and ignore it
+                return OptionalInt.empty();
+            }
+        }
+    }
+
+    static void writeOutNewContents(Path file, CompilationUnit unit) throws IOException {
+        if (unit.containsData(LexicalPreservingPrinter.NODE_TEXT_DATA) == false) {
+            throw new IllegalArgumentException("CompilationUnit has no lexical information for output");
+        }
+        Files.writeString(file, LexicalPreservingPrinter.print(unit), StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING);
+    }
+}
diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ExtractCurrentVersionsTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ExtractCurrentVersionsTask.java
index 3530d7ef9e807..53dd55041f6bd 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ExtractCurrentVersionsTask.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ExtractCurrentVersionsTask.java
@@ -11,7 +11,6 @@
 import com.github.javaparser.StaticJavaParser;
 import com.github.javaparser.ast.CompilationUnit;
 import com.github.javaparser.ast.body.FieldDeclaration;
-import com.github.javaparser.ast.expr.IntegerLiteralExpr;
 
 import org.gradle.api.logging.Logger;
 import org.gradle.api.logging.Logging;
@@ -53,11 +52,11 @@ public void executeTask() throws IOException {
         LOGGER.lifecycle("Extracting latest version information");
 
         List<String> output = new ArrayList<>();
 
-        int transportVersion = readLatestVersion(rootDir.resolve(TRANSPORT_VERSION_FILE_PATH));
+        int transportVersion = readLatestVersion(rootDir.resolve(TRANSPORT_VERSIONS_FILE_PATH));
         LOGGER.lifecycle("Transport version: {}", transportVersion);
         output.add(TRANSPORT_VERSION_TYPE + ":" + transportVersion);
 
-        int indexVersion = readLatestVersion(rootDir.resolve(INDEX_VERSION_FILE_PATH));
+        int indexVersion = readLatestVersion(rootDir.resolve(INDEX_VERSIONS_FILE_PATH));
         LOGGER.lifecycle("Index version: {}", indexVersion);
         output.add(INDEX_VERSION_TYPE + ":" + indexVersion);
 
@@ -74,21 +73,13 @@ Integer highestVersionId() {
 
         @Override
         public void accept(FieldDeclaration fieldDeclaration) {
-            var ints = fieldDeclaration.findAll(IntegerLiteralExpr.class);
-            switch (ints.size()) {
-                case 0 -> {
-                    // No ints in the field declaration, ignore
+            findSingleIntegerExpr(fieldDeclaration).ifPresent(id -> {
+                if (highestVersionId != null && highestVersionId > id) {
+                    LOGGER.warn("Version ids [{}, {}] out of order", highestVersionId, id);
+                } else {
+                    highestVersionId = id;
                 }
-                case 1 -> {
-                    int id = ints.get(0).asNumber().intValue();
-                    if (highestVersionId != null && highestVersionId > id) {
-                        LOGGER.warn("Version ids [{}, {}] out of order", highestVersionId, id);
-                    } else {
-                        highestVersionId = id;
-                    }
-                }
-                default -> LOGGER.warn("Multiple integers found in version field declaration [{}]", fieldDeclaration); // and ignore it
-            }
+            });
         }
     }
diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java
index 8001b82797557..08abb02ea831e 100644
--- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java
+++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/ReleaseToolsPlugin.java
@@ -52,6 +52,7 @@ public void apply(Project project) {
project.getTasks().register("extractCurrentVersions", ExtractCurrentVersionsTask.class); project.getTasks().register("tagVersions", TagVersionsTask.class); + project.getTasks().register("setCompatibleVersions", SetCompatibleVersionsTask.class); final FileTree yamlFiles = projectDirectory.dir("docs/changelog") .getAsFileTree() diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/SetCompatibleVersionsTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/SetCompatibleVersionsTask.java new file mode 100644 index 0000000000000..15e0a0cc345d5 --- /dev/null +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/SetCompatibleVersionsTask.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.gradle.internal.release; + +import com.github.javaparser.StaticJavaParser; +import com.github.javaparser.ast.CompilationUnit; +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration; +import com.github.javaparser.ast.expr.NameExpr; +import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter; + +import org.gradle.api.tasks.TaskAction; +import org.gradle.api.tasks.options.Option; +import org.gradle.initialization.layout.BuildLayout; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.inject.Inject; + +public class SetCompatibleVersionsTask extends AbstractVersionsTask { + + private Map versionIds = Map.of(); + + @Inject + public SetCompatibleVersionsTask(BuildLayout layout) { + super(layout); + } + + @Option(option = "version-id", description = "Version id used for the release. 
Of the form :.") + public void versionIds(List version) { + this.versionIds = splitVersionIds(version); + } + + @TaskAction + public void executeTask() throws IOException { + if (versionIds.isEmpty()) { + throw new IllegalArgumentException("No version ids specified"); + } + Integer transportVersion = versionIds.get(TRANSPORT_VERSION_TYPE); + if (transportVersion == null) { + throw new IllegalArgumentException("TransportVersion id not specified"); + } + + Path versionJava = rootDir.resolve(TRANSPORT_VERSIONS_FILE_PATH); + CompilationUnit file = LexicalPreservingPrinter.setup(StaticJavaParser.parse(versionJava)); + + Optional modifiedFile; + + modifiedFile = setMinimumCcsTransportVersion(file, transportVersion); + + if (modifiedFile.isPresent()) { + writeOutNewContents(versionJava, modifiedFile.get()); + } + } + + static Optional setMinimumCcsTransportVersion(CompilationUnit unit, int transportVersion) { + ClassOrInterfaceDeclaration transportVersions = unit.getClassByName("TransportVersions").get(); + + String tvConstantName = transportVersions.getFields().stream().filter(f -> { + var i = findSingleIntegerExpr(f); + return i.isPresent() && i.getAsInt() == transportVersion; + }) + .map(f -> f.getVariable(0).getNameAsString()) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Could not find constant for id " + transportVersion)); + + transportVersions.getFieldByName("MINIMUM_CCS_VERSION") + .orElseThrow(() -> new IllegalStateException("Could not find MINIMUM_CCS_VERSION constant")) + .getVariable(0) + .setInitializer(new NameExpr(tvConstantName)); + + return Optional.of(unit); + } +} diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/TagVersionsTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/TagVersionsTask.java index fa11746543e82..a7f67f87b602e 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/TagVersionsTask.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/TagVersionsTask.java @@ -47,11 +47,7 @@ public void release(String version) { @Option(option = "tag-version", description = "Version id to tag. 
Of the form :.") public void tagVersions(List version) { - this.tagVersions = version.stream().map(l -> { - var split = l.split(":"); - if (split.length != 2) throw new IllegalArgumentException("Invalid tag format [" + l + "]"); - return split; - }).collect(Collectors.toMap(l -> l[0], l -> Integer.parseInt(l[1]))); + this.tagVersions = splitVersionIds(version); } @TaskAction diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/UpdateVersionsTask.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/UpdateVersionsTask.java index 9996ffe613545..b19e5c0beacf8 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/UpdateVersionsTask.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/release/UpdateVersionsTask.java @@ -8,7 +8,6 @@ package org.elasticsearch.gradle.internal.release; -import com.github.javaparser.GeneratedJavaParserConstants; import com.github.javaparser.StaticJavaParser; import com.github.javaparser.ast.CompilationUnit; import com.github.javaparser.ast.NodeList; @@ -16,14 +15,10 @@ import com.github.javaparser.ast.body.FieldDeclaration; import com.github.javaparser.ast.body.VariableDeclarator; import com.github.javaparser.ast.expr.NameExpr; -import com.github.javaparser.ast.observer.ObservableProperty; -import com.github.javaparser.printer.ConcreteSyntaxModel; -import com.github.javaparser.printer.concretesyntaxmodel.CsmElement; import com.github.javaparser.printer.lexicalpreservation.LexicalPreservingPrinter; import com.google.common.annotations.VisibleForTesting; import org.elasticsearch.gradle.Version; -import org.gradle.api.DefaultTask; import org.gradle.api.logging.Logger; import org.gradle.api.logging.Logging; import org.gradle.api.tasks.TaskAction; @@ -31,10 +26,7 @@ import org.gradle.initialization.layout.BuildLayout; import java.io.IOException; -import java.lang.reflect.Field; -import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.util.Map; import java.util.NavigableMap; import java.util.Objects; @@ -47,93 +39,12 @@ import javax.annotation.Nullable; import javax.inject.Inject; -import static com.github.javaparser.ast.observer.ObservableProperty.TYPE_PARAMETERS; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmConditional.Condition.FLAG; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.block; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.child; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.comma; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.comment; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.conditional; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.list; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.newline; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.none; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.sequence; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.space; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.string; -import static com.github.javaparser.printer.concretesyntaxmodel.CsmElement.token; - -public class UpdateVersionsTask extends DefaultTask { - - static { - replaceDefaultJavaParserClassCsm(); - } - - /* - * The default JavaParser 
CSM which it uses to format any new declarations added to a class - * inserts two newlines after each declaration. Our version classes only have one newline. - * In order to get javaparser lexical printer to use our format, we have to completely replace - * the statically declared CSM pattern using hacky reflection - * to access the static map where these are stored, and insert a replacement that is identical - * apart from only one newline at the end of each member declaration, rather than two. - */ - private static void replaceDefaultJavaParserClassCsm() { - try { - Field classCsms = ConcreteSyntaxModel.class.getDeclaredField("concreteSyntaxModelByClass"); - classCsms.setAccessible(true); - @SuppressWarnings({ "unchecked", "rawtypes" }) - Map csms = (Map) classCsms.get(null); - - // copied from the static initializer in ConcreteSyntaxModel - csms.put( - ClassOrInterfaceDeclaration.class, - sequence( - comment(), - list(ObservableProperty.ANNOTATIONS, newline(), none(), newline()), - list(ObservableProperty.MODIFIERS, space(), none(), space()), - conditional( - ObservableProperty.INTERFACE, - FLAG, - token(GeneratedJavaParserConstants.INTERFACE), - token(GeneratedJavaParserConstants.CLASS) - ), - space(), - child(ObservableProperty.NAME), - list( - TYPE_PARAMETERS, - sequence(comma(), space()), - string(GeneratedJavaParserConstants.LT), - string(GeneratedJavaParserConstants.GT) - ), - list( - ObservableProperty.EXTENDED_TYPES, - sequence(string(GeneratedJavaParserConstants.COMMA), space()), - sequence(space(), token(GeneratedJavaParserConstants.EXTENDS), space()), - none() - ), - list( - ObservableProperty.IMPLEMENTED_TYPES, - sequence(string(GeneratedJavaParserConstants.COMMA), space()), - sequence(space(), token(GeneratedJavaParserConstants.IMPLEMENTS), space()), - none() - ), - space(), - block(sequence(newline(), list(ObservableProperty.MEMBERS, sequence(newline()/*, newline()*/), newline(), newline()))) - ) - ); - } catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - } +public class UpdateVersionsTask extends AbstractVersionsTask { private static final Logger LOGGER = Logging.getLogger(UpdateVersionsTask.class); - static final String SERVER_MODULE_PATH = "server/src/main/java/"; - static final String VERSION_FILE_PATH = SERVER_MODULE_PATH + "org/elasticsearch/Version.java"; - static final Pattern VERSION_FIELD = Pattern.compile("V_(\\d+)_(\\d+)_(\\d+)(?:_(\\w+))?"); - final Path rootDir; - @Nullable private Version addVersion; private boolean setCurrent; @@ -142,7 +53,7 @@ private static void replaceDefaultJavaParserClassCsm() { @Inject public UpdateVersionsTask(BuildLayout layout) { - rootDir = layout.getRootDirectory().toPath(); + super(layout); } @Option(option = "add-version", description = "Specifies the version to add") @@ -287,11 +198,4 @@ static Optional removeVersionConstant(CompilationUnit versionJa return Optional.of(versionJava); } - - static void writeOutNewContents(Path file, CompilationUnit unit) throws IOException { - if (unit.containsData(LexicalPreservingPrinter.NODE_TEXT_DATA) == false) { - throw new IllegalArgumentException("CompilationUnit has no lexical information for output"); - } - Files.writeString(file, LexicalPreservingPrinter.print(unit), StandardOpenOption.WRITE, StandardOpenOption.TRUNCATE_EXISTING); - } } diff --git a/build-tools-internal/src/test/java/org/elasticsearch/gradle/internal/release/SetCompatibleVersionsTaskTests.java 
b/build-tools-internal/src/test/java/org/elasticsearch/gradle/internal/release/SetCompatibleVersionsTaskTests.java new file mode 100644 index 0000000000000..eecb953a44eb6 --- /dev/null +++ b/build-tools-internal/src/test/java/org/elasticsearch/gradle/internal/release/SetCompatibleVersionsTaskTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.gradle.internal.release; + +import com.github.javaparser.StaticJavaParser; +import com.github.javaparser.ast.CompilationUnit; + +import org.junit.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasToString; + +public class SetCompatibleVersionsTaskTests { + + @Test + public void updateMinCcsVersion() { + final String transportVersionsJava = """ + public class TransportVersions { + public static final TransportVersion V1 = def(100); + public static final TransportVersion V2 = def(200); + public static final TransportVersion V3 = def(300); + + public static final TransportVersion MINIMUM_CCS_VERSION = V2; + }"""; + final String updatedJava = """ + public class TransportVersions { + + public static final TransportVersion V1 = def(100); + + public static final TransportVersion V2 = def(200); + + public static final TransportVersion V3 = def(300); + + public static final TransportVersion MINIMUM_CCS_VERSION = V3; + } + """; + + CompilationUnit unit = StaticJavaParser.parse(transportVersionsJava); + + SetCompatibleVersionsTask.setMinimumCcsTransportVersion(unit, 300); + + assertThat(unit, hasToString(updatedJava)); + } +} diff --git a/docs/changelog/106252.yaml b/docs/changelog/106252.yaml new file mode 100644 index 0000000000000..5e3f084632b9d --- /dev/null +++ b/docs/changelog/106252.yaml @@ -0,0 +1,6 @@ +pr: 106252 +summary: Add min/max range of the `event.ingested` field to cluster state for searchable + snapshots +area: Search +type: enhancement +issues: [] diff --git a/docs/changelog/107415.yaml b/docs/changelog/107415.yaml new file mode 100644 index 0000000000000..8877d0426c60d --- /dev/null +++ b/docs/changelog/107415.yaml @@ -0,0 +1,6 @@ +pr: 107415 +summary: Fix `DecayFunctions'` `toString` +area: Search +type: bug +issues: + - 100870 diff --git a/docs/changelog/108606.yaml b/docs/changelog/108606.yaml new file mode 100644 index 0000000000000..04780bff58800 --- /dev/null +++ b/docs/changelog/108606.yaml @@ -0,0 +1,14 @@ +pr: 108606 +summary: "Extend ISO8601 datetime parser to specify forbidden fields, allowing it to be used\ + \ on more formats" +area: Infra/Core +type: enhancement +issues: [] +highlight: + title: New custom parser for more ISO-8601 date formats + body: |- + Following on from #106486, this extends the custom ISO-8601 datetime parser to cover the `strict_year`, + `strict_year_month`, `strict_date_time`, `strict_date_time_no_millis`, `strict_date_hour_minute_second`, + `strict_date_hour_minute_second_millis`, and `strict_date_hour_minute_second_fraction` date formats. + As before, the parser will use the existing java.time parser if there are parsing issues, and the + `es.datetime.java_time_parsers=true` JVM property will force the use of the old parsers regardless. 
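For orientation, the release tasks added above might be invoked roughly as follows. This is a hedged sketch, not part of the diff: the task names come from the ReleaseToolsPlugin registrations and the flags from the @Option declarations shown earlier, but the version id values are hypothetical placeholders.

# Print the highest TransportVersion/IndexVersion ids found in the source tree.
./gradlew extractCurrentVersions

# Re-point MINIMUM_CCS_VERSION at the constant whose id matches the given value.
# "TransportVersion:8700000" is an illustrative id, not one from this PR.
./gradlew setCompatibleVersions --version-id TransportVersion:8700000
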
diff --git a/docs/changelog/109807.yaml b/docs/changelog/109807.yaml
new file mode 100644
index 0000000000000..5cf8a2c896c4e
--- /dev/null
+++ b/docs/changelog/109807.yaml
@@ -0,0 +1,6 @@
+pr: 109807
+summary: "ESQL: Fix LOOKUP attribute shadowing"
+area: ES|QL
+type: bug
+issues:
+ - 109392
diff --git a/docs/changelog/109893.yaml b/docs/changelog/109893.yaml
new file mode 100644
index 0000000000000..df6d6e51236c8
--- /dev/null
+++ b/docs/changelog/109893.yaml
@@ -0,0 +1,5 @@
+pr: 109893
+summary: Add Anthropic messages integration to Inference API
+area: Machine Learning
+type: enhancement
+issues: [ ]
diff --git a/docs/changelog/110016.yaml b/docs/changelog/110016.yaml
new file mode 100644
index 0000000000000..28ad55aa796c8
--- /dev/null
+++ b/docs/changelog/110016.yaml
@@ -0,0 +1,5 @@
+pr: 110016
+summary: Opt in keyword field into fallback synthetic source if needed
+area: Mapping
+type: enhancement
+issues: []
diff --git a/docs/changelog/110059.yaml b/docs/changelog/110059.yaml
new file mode 100644
index 0000000000000..ba160c091cdc2
--- /dev/null
+++ b/docs/changelog/110059.yaml
@@ -0,0 +1,32 @@
+pr: 110059
+summary: Adds new `bit` `element_type` for `dense_vectors`
+area: Vector Search
+type: feature
+issues: []
+highlight:
+  title: Adds new `bit` `element_type` for `dense_vectors`
+  body: |-
+    This adds `bit` vector support by adding `element_type: bit` for
+    vectors. This new element type works for indexed and non-indexed
+    vectors. Additionally, it works with `hnsw` and `flat` index types. No
+    quantization based codec works with this element type, this is
+    consistent with `byte` vectors.
+
+    `bit` vectors accept up to `32768` dimensions in size and expect vectors
+    that are being indexed to be encoded either as a hexadecimal string or a
+    `byte[]` array where each element of the `byte` array represents `8`
+    bits of the vector.
+
+    `bit` vectors support script usage and regular query usage. When
+    indexed, all comparisons done are `xor` and `popcount` summations (aka,
+    hamming distance), and the scores are transformed and normalized given
+    the vector dimensions.
+
+    For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
+    `sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.
+
+    Note, the dimensions expected by this element_type must always be
+    divisible by `8`, and the `byte[]` vectors provided for indexing must
+    have size `dim/8`, where each byte element represents `8` bits of
+    the vector.
+notable: true
diff --git a/docs/changelog/110066.yaml b/docs/changelog/110066.yaml
new file mode 100644
index 0000000000000..920c6304b63ae
--- /dev/null
+++ b/docs/changelog/110066.yaml
@@ -0,0 +1,6 @@
+pr: 110066
+summary: Support flattened fields and multi-fields as dimensions in downsampling
+area: Downsampling
+type: bug
+issues:
+ - 99297
diff --git a/docs/changelog/110102.yaml b/docs/changelog/110102.yaml
new file mode 100644
index 0000000000000..d1b9b53e2dfc5
--- /dev/null
+++ b/docs/changelog/110102.yaml
@@ -0,0 +1,6 @@
+pr: 110102
+summary: Optimize ST_DISTANCE filtering with Lucene circle intersection query
+area: ES|QL
+type: enhancement
+issues:
+ - 109972
diff --git a/docs/changelog/110103.yaml b/docs/changelog/110103.yaml
new file mode 100644
index 0000000000000..9f613ec2b446e
--- /dev/null
+++ b/docs/changelog/110103.yaml
@@ -0,0 +1,5 @@
+pr: 110103
+summary: Fix automatic tracking of collapse with `docvalue_fields`
+area: Search
+type: bug
+issues: []
diff --git a/docs/changelog/110112.yaml b/docs/changelog/110112.yaml
new file mode 100644
index 0000000000000..eca5fd9af15ce
--- /dev/null
+++ b/docs/changelog/110112.yaml
@@ -0,0 +1,5 @@
+pr: 110112
+summary: Increase response size limit for batched requests
+area: Machine Learning
+type: bug
+issues: []
diff --git a/docs/changelog/110146.yaml b/docs/changelog/110146.yaml
new file mode 100644
index 0000000000000..61ba35cec319b
--- /dev/null
+++ b/docs/changelog/110146.yaml
@@ -0,0 +1,5 @@
+pr: 110146
+summary: Fix trailing slash in `ml.get_categories` specification
+area: Machine Learning
+type: bug
+issues: []
diff --git a/docs/changelog/110160.yaml b/docs/changelog/110160.yaml
new file mode 100644
index 0000000000000..0c38c23c69067
--- /dev/null
+++ b/docs/changelog/110160.yaml
@@ -0,0 +1,5 @@
+pr: 110160
+summary: Opt in number fields into fallback synthetic source when doc values a…
+area: Mapping
+type: enhancement
+issues: []
diff --git a/docs/changelog/110176.yaml b/docs/changelog/110176.yaml
new file mode 100644
index 0000000000000..ae1d7d10d6dc4
--- /dev/null
+++ b/docs/changelog/110176.yaml
@@ -0,0 +1,5 @@
+pr: 110176
+summary: Fix trailing slash in two rollup specifications
+area: Rollup
+type: bug
+issues: []
diff --git a/docs/changelog/110177.yaml b/docs/changelog/110177.yaml
new file mode 100644
index 0000000000000..0ac5328d88df4
--- /dev/null
+++ b/docs/changelog/110177.yaml
@@ -0,0 +1,5 @@
+pr: 110177
+summary: Fix trailing slash in `security.put_privileges` specification
+area: Authorization
+type: bug
+issues: []
diff --git a/docs/changelog/110186.yaml b/docs/changelog/110186.yaml
new file mode 100644
index 0000000000000..23eaab118e2ab
--- /dev/null
+++ b/docs/changelog/110186.yaml
@@ -0,0 +1,6 @@
+pr: 110186
+summary: Don't sample calls to `ReduceContext#consumeBucketsAndMaybeBreak` in `InternalDateHistogram`
+  and `InternalHistogram` during reduction
+area: Aggregations
+type: bug
+issues: []
diff --git a/docs/reference/connector/apis/set-connector-sync-job-stats-api.asciidoc b/docs/reference/connector/apis/set-connector-sync-job-stats-api.asciidoc
index 4dd9cc6e67ab2..1427269d22b86 100644
--- a/docs/reference/connector/apis/set-connector-sync-job-stats-api.asciidoc
+++ b/docs/reference/connector/apis/set-connector-sync-job-stats-api.asciidoc
@@ -53,6 +53,9 @@ This API is mainly used by the connector service for updating sync job informati
 `last_seen`::
 (Optional, instant) The timestamp to set the connector sync job's `last_seen` property.
 
+`metadata`::
+(Optional, object) The connector-specific metadata.
+
 [[set-connector-sync-job-stats-api-response-codes]]
 ==== {api-response-codes-title}
 
diff --git a/docs/reference/data-streams/downsampling.asciidoc b/docs/reference/data-streams/downsampling.asciidoc
index b005e83e8c95d..0b08b0972f9a1 100644
--- a/docs/reference/data-streams/downsampling.asciidoc
+++ b/docs/reference/data-streams/downsampling.asciidoc
@@ -18,9 +18,9 @@
 Metrics solutions collect large amounts of time series data that grow over
 time. As that data ages, it becomes less relevant to the current state of the
 system. The downsampling process rolls up documents within a fixed time
 interval into a single summary document. Each summary document includes statistical
-representations of the original data: the `min`, `max`, `sum`, `value_count`,
-and `average` for each metric. Data stream <> are stored unchanged.
+representations of the original data: the `min`, `max`, `sum` and `value_count`
+for each metric. Data stream <>
+are stored unchanged.
 
 Downsampling, in effect, lets you trade data resolution and precision for
 storage size. You can include it in an <>.
 
-[[dimension-limits]]
-.Dimension limits
-****
-In a TSDS, {es} uses dimensions to
-generate the document `_id` and <> values. The resulting `_id` is
-always a short encoded hash. To prevent the `_tsid` value from being overly
-large, {es} limits the number of dimensions for an index using the
-<>
-index setting. While you can increase this limit, the resulting document `_tsid`
-value can't exceed 32KB. Additionally the field name of a dimension cannot be
-longer than 512 bytes and the each dimension value can't exceed 1kb.
-****
-
 [discrete]
 [[time-series-metric]]
 ==== Metrics
@@ -290,11 +277,6 @@ created the initial backing index has:
 
 Only data that falls inside that range can be indexed.
 
-In our <>,
-`index.look_ahead_time` is set to three hours, so only documents with a
-`@timestamp` value that is within three hours previous or subsequent to the
-present time are accepted for indexing.
-
 You can use the <> to check the
 accepted time range for writing to any TSDS.
 
diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc
index 2f09e743faa7b..f2f0b3ae8bb23 100644
--- a/docs/reference/mapping/types/dense-vector.asciidoc
+++ b/docs/reference/mapping/types/dense-vector.asciidoc
@@ -183,11 +183,23 @@
 `element_type`::
 (Optional, string)
 The data type used to encode vectors. The supported data types are
-`float` (default) and `byte`. `float` indexes a 4-byte floating-point
-value per dimension. `byte` indexes a 1-byte integer value per dimension.
-Using `byte` can result in a substantially smaller index size with the
-trade off of lower precision. Vectors using `byte` require dimensions with
-integer values between -128 to 127, inclusive for both indexing and searching.
+`float` (default), `byte`, and `bit`.
+
+.Valid values for `element_type`
+[%collapsible%open]
+====
+`float`:::
+indexes a 4-byte floating-point
+value per dimension. This is the default value.
+
+`byte`:::
+indexes a 1-byte integer value per dimension.
+
+`bit`:::
+indexes a single bit per dimension. Useful for very high-dimensional vectors or models that specifically support bit vectors.
+NOTE: when using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits.
+
+====
 
 `dims`::
 (Optional, integer)
@@ -205,7 +217,11 @@ API>>. Defaults to `true`.
 The vector similarity metric to use in kNN search. Documents are ranked by
 their vector field's similarity to the query vector. The `_score` of each
 document will be derived from the similarity, in a way that ensures scores are
-positive and that a larger score corresponds to a higher ranking. Defaults to `cosine`.
+positive and that a larger score corresponds to a higher ranking.
+Defaults to `l2_norm` when `element_type: bit`, otherwise defaults to `cosine`.
+
+NOTE: `bit` vectors only support `l2_norm` as their similarity metric.
+
 +
 ^*^ This parameter can only be specified when `index` is `true`.
 +
@@ -217,6 +233,9 @@
 Computes similarity based on the L^2^ distance (also known as Euclidean
 distance) between the vectors. The document `_score` is computed as
 `1 / (1 + l2_norm(query, vector)^2)`.
 
+For `bit` vectors, instead of using `l2_norm`, the `hamming` distance between the vectors is used. The `_score`
+transformation is `(numBits - hamming(a, b)) / numBits`.
+
 `dot_product`:::
 Computes the dot product of two unit vectors. This option provides an optimized way
 to perform cosine similarity. The constraints and computed score are defined
@@ -320,3 +339,112 @@
 any issues, but features in technical preview are not subject to the support
 SLA of official GA features.
 
 `dense_vector` fields support <> .
+
+[[dense-vector-index-bit]]
+==== Indexing & Searching bit vectors
+
+When using `element_type: bit`, this will treat all vectors as bit vectors. Bit vectors utilize only a single
+bit per dimension and are internally encoded as bytes. This can be useful for very high-dimensional vectors or models.
+
+When using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits. Additionally,
+with `bit` vectors, the typical vector similarity values are effectively all scored the same, e.g. with `hamming` distance.
+
+Let's compare two `byte[]` arrays, each representing 40 individual bits.
+
+`[-127, 0, 1, 42, 127]` in bits `1000000100000000000000010010101001111111`
+`[127, -127, 0, 1, 42]` in bits `0111111110000001000000000000000100101010`
+
+When comparing these two bit vectors, we first take the {wikipedia}/Hamming_distance[`hamming` distance].
+
+`xor` result:
+```
+1000000100000000000000010010101001111111
+^
+0111111110000001000000000000000100101010
+=
+1111111010000001000000010010101101010101
+```
+
+Then, we gather the count of `1` bits in the `xor` result: `18`. To scale for scoring, we subtract from the total number
+of bits and divide by the total number of bits: `(40 - 18) / 40 = 0.55`. This would be the `_score` between these two
+vectors.
+
+Here is an example of indexing and searching bit vectors:
+
+[source,console]
+--------------------------------------------------
+PUT my-bit-vectors
+{
+  "mappings": {
+    "properties": {
+      "my_vector": {
+        "type": "dense_vector",
+        "dims": 40, <1>
+        "element_type": "bit"
+      }
+    }
+  }
+}
+--------------------------------------------------
+<1> The number of dimensions that represents the number of bits
+
+[source,console]
+--------------------------------------------------
+POST /my-bit-vectors/_bulk?refresh
+{"index": {"_id" : "1"}}
+{"my_vector": [127, -127, 0, 1, 42]} <1>
+{"index": {"_id" : "2"}}
+{"my_vector": "8100012a7f"} <2>
+--------------------------------------------------
+// TEST[continued]
+<1> 5 bytes representing the 40 bit dimensioned vector
+<2> A hexadecimal string representing the 40 bit dimensioned vector
+
+Then, when searching, you can use the `knn` query to search for similar bit vectors:
+
+[source,console]
+--------------------------------------------------
+POST /my-bit-vectors/_search?filter_path=hits.hits
+{
+  "query": {
+    "knn": {
+      "query_vector": [127, -127, 0, 1, 42],
+      "field": "my_vector"
+    }
+  }
+}
+--------------------------------------------------
+// TEST[continued]
+
+[source,console-result]
+----
+{
+  "hits": {
+    "hits": [
+      {
+        "_index": "my-bit-vectors",
+        "_id": "1",
+        "_score": 1.0,
+        "_source": {
+          "my_vector": [
+            127,
+            -127,
+            0,
+            1,
+            42
+          ]
+        }
+      },
+      {
+        "_index": "my-bit-vectors",
+        "_id": "2",
+        "_score": 0.55,
+        "_source": {
+          "my_vector": "8100012a7f"
+        }
+      }
+    ]
+  }
+}
----
+
diff --git a/docs/reference/query-rules/apis/delete-query-rule.asciidoc b/docs/reference/query-rules/apis/delete-query-rule.asciidoc
new file mode 100644
index 0000000000000..01b73033aa361
--- /dev/null
+++ b/docs/reference/query-rules/apis/delete-query-rule.asciidoc
@@ -0,0 +1,74 @@
+[role="xpack"]
+[[delete-query-rule]]
+=== Delete query rule
+
+++++
+Delete query rule
+++++
+
+Removes an individual query rule within an existing query ruleset.
+This is a destructive action that is only recoverable by re-adding the same rule via the <<put-query-rule>> API.
+
+[[delete-query-rule-request]]
+==== {api-request-title}
+
+`DELETE _query_rules/<ruleset_id>/_rule/<rule_id>`
+
+[[delete-query-rule-prereq]]
+==== {api-prereq-title}
+
+Requires the `manage_search_query_rules` privilege.
+
+[[delete-query_rule-path-params]]
+==== {api-path-parms-title}
+
+`<ruleset_id>`::
+(Required, string)
+
+`<rule_id>`::
+(Required, string)
+
+[[delete-query-rule-response-codes]]
+==== {api-response-codes-title}
+
+`400`::
+Missing `ruleset_id`, `rule_id`, or both.
+
+`404` (Missing resources)::
+No query ruleset matching `ruleset_id` could be found, or else no rule matching `rule_id` was found in that ruleset.
+
+[[delete-query-rule-example]]
+==== {api-examples-title}
+
+The following example deletes the query rule with ID `my-rule1` from the query ruleset named `my-ruleset`:
+
+////
+[source,console]
+----
+PUT _query_rules/my-ruleset
+{
+  "rules": [
+    {
+      "rule_id": "my-rule1",
+      "type": "pinned",
+      "criteria": [
+        {
+          "type": "exact",
+          "metadata": "query_string",
+          "values": [ "marvel" ]
+        }
+      ],
+      "actions": {
+        "ids": ["id1"]
+      }
+    }
+  ]
+}
----
+// TESTSETUP
+////
+
+[source,console]
+----
+DELETE _query_rules/my-ruleset/_rule/my-rule1
+----
diff --git a/docs/reference/query-rules/apis/get-query-rule.asciidoc b/docs/reference/query-rules/apis/get-query-rule.asciidoc
new file mode 100644
index 0000000000000..56713965d7bdc
--- /dev/null
+++ b/docs/reference/query-rules/apis/get-query-rule.asciidoc
@@ -0,0 +1,130 @@
+[role="xpack"]
+[[get-query-rule]]
+=== Get query rule
+
+++++
+Get query rule
+++++
+
+Retrieves information about an individual query rule within a query ruleset.
+
+[[get-query-rule-request]]
+==== {api-request-title}
+
+`GET _query_rules/<ruleset_id>/_rule/<rule_id>`
+
+[[get-query-rule-prereq]]
+==== {api-prereq-title}
+
+Requires the `manage_search_query_rules` privilege.
+
+[[get-query-rule-path-params]]
+==== {api-path-parms-title}
+
+`<ruleset_id>`::
+(Required, string)
+
+`<rule_id>`::
+(Required, string)
+
+[[get-query-rule-response-codes]]
+==== {api-response-codes-title}
+
+`400`::
+Missing `ruleset_id` or `rule_id`, or both.
+
+`404` (Missing resources)::
+Either no query ruleset matching `ruleset_id` could be found, or no rule matching `rule_id` could be found within that ruleset.
+
+[[get-query-rule-example]]
+==== {api-examples-title}
+
+The following example gets the query rule with ID `my-rule1` from the ruleset named `my-ruleset`:
+
+////
+
+[source,console]
+--------------------------------------------------
+PUT _query_rules/my-ruleset
+{
+  "rules": [
+    {
+      "rule_id": "my-rule1",
+      "type": "pinned",
+      "criteria": [
+        {
+          "type": "contains",
+          "metadata": "query_string",
+          "values": [ "pugs", "puggles" ]
+        }
+      ],
+      "actions": {
+        "ids": [
+          "id1",
+          "id2"
+        ]
+      }
+    },
+    {
+      "rule_id": "my-rule2",
+      "type": "pinned",
+      "criteria": [
+        {
+          "type": "fuzzy",
+          "metadata": "query_string",
+          "values": [ "rescue dogs" ]
+        }
+      ],
+      "actions": {
+        "docs": [
+          {
+            "_index": "index1",
+            "_id": "id3"
+          },
+          {
+            "_index": "index2",
+            "_id": "id4"
+          }
+        ]
+      }
+    }
+  ]
+}
+--------------------------------------------------
+// TESTSETUP
+
+[source,console]
+--------------------------------------------------
+DELETE _query_rules/my-ruleset
+--------------------------------------------------
+// TEARDOWN
+
+////
+
+[source,console]
+----
+GET _query_rules/my-ruleset/_rule/my-rule1
+----
+
+A sample response:
+
+[source,console-result]
+----
+{
+  "rule_id": "my-rule1",
+  "type": "pinned",
+  "criteria": [
+    {
+      "type": "contains",
+      "metadata": "query_string",
+      "values": [ "pugs", "puggles" ]
+    }
+  ],
+  "actions": {
+    "ids": [
+      "id1",
+      "id2"
+    ]
+  }
+}
+----
diff --git a/docs/reference/query-rules/apis/index.asciidoc b/docs/reference/query-rules/apis/index.asciidoc
index e72d56d2f4834..f7303647f8515 100644
--- a/docs/reference/query-rules/apis/index.asciidoc
+++ b/docs/reference/query-rules/apis/index.asciidoc
@@ -1,6 +1,8 @@
 [[query-rules-apis]]
 == Query rules APIs
 
+preview::[]
+
 ++++
 Query rules APIs
 ++++
@@ -20,8 +22,15 @@ Use the following APIs to manage query rulesets:
 
 * <<get-query-ruleset>>
 * <<list-query-rulesets>>
 * <<delete-query-ruleset>>
+* <<put-query-rule>>
+* <<get-query-rule>>
+* <<delete-query-rule>>
 
 include::put-query-ruleset.asciidoc[]
 include::get-query-ruleset.asciidoc[]
 include::list-query-rulesets.asciidoc[]
 include::delete-query-ruleset.asciidoc[]
+include::put-query-rule.asciidoc[]
+include::get-query-rule.asciidoc[]
+include::delete-query-rule.asciidoc[]
+
diff --git a/docs/reference/query-rules/apis/put-query-rule.asciidoc b/docs/reference/query-rules/apis/put-query-rule.asciidoc
new file mode 100644
index 0000000000000..2b9a6ba892b84
--- /dev/null
+++ b/docs/reference/query-rules/apis/put-query-rule.asciidoc
@@ -0,0 +1,144 @@
+[role="xpack"]
+[[put-query-rule]]
+=== Create or update query rule
+
+++++
+Create or update query rule
+++++
+
+Creates or updates an individual query rule within a query ruleset.
+
+[[put-query-rule-request]]
+==== {api-request-title}
+
+`PUT _query_rules/<ruleset_id>/_rule/<rule_id>`
+
+[[put-query-rule-prereqs]]
+==== {api-prereq-title}
+
+Requires the `manage_search_query_rules` privilege.
+
+[role="child_attributes"]
+[[put-query-rule-request-body]]
+==== {api-request-body-title}
+
+(Required, object) Contains parameters for a query rule:
+
+`type`::
+(Required, string) The type of rule.
+At this time only `pinned` query rule types are allowed.
+
+`criteria`::
+(Required, array of objects) The criteria that must be met for the rule to be applied.
+If multiple criteria are specified for a rule, all criteria must be met for the rule to be applied.
+
+Criteria must have the following information:
+
+- `type` (Required, string) The type of criteria.
+The following criteria types are supported:
++
+--
+- `exact`
+Only exact matches meet the criteria defined by the rule.
+Applicable for string or numerical values.
+- `fuzzy`
+Exact matches or matches within the allowed {wikipedia}/Levenshtein_distance[Levenshtein Edit Distance] meet the criteria defined by the rule.
+Only applicable for string values.
+- `prefix`
+Matches that start with this value meet the criteria defined by the rule.
+Only applicable for string values.
+- `suffix`
+Matches that end with this value meet the criteria defined by the rule.
+Only applicable for string values.
+- `contains`
+Matches that contain this value anywhere in the field meet the criteria defined by the rule.
+Only applicable for string values.
+- `lt`
+Matches with a value less than this value meet the criteria defined by the rule.
+Only applicable for numerical values.
+- `lte`
+Matches with a value less than or equal to this value meet the criteria defined by the rule.
+Only applicable for numerical values.
+- `gt`
+Matches with a value greater than this value meet the criteria defined by the rule.
+Only applicable for numerical values.
+- `gte`
+Matches with a value greater than or equal to this value meet the criteria defined by the rule.
+Only applicable for numerical values.
+- `always`
+Matches all queries, regardless of input.
+--
+- `metadata` (Optional, string) The metadata field to match against.
+This metadata will be used to match against `match_criteria` sent in the <>.
+Required for all criteria types except `always`.
+- `values` (Optional, array of strings) The values to match against the metadata field.
+Only one value must match for the criteria to be met.
+Required for all criteria types except `always`.
+
+`actions`::
+(Required, object) The actions to take when the rule is matched.
+The format of this action depends on the rule type.
+
+Actions depend on the rule type.
+For `pinned` rules, actions follow the format specified by the <<query-dsl-pinned-query,pinned query>>.
+The following actions are allowed:
+
+- `ids` (Optional, array of strings) The unique <<mapping-id-field,document IDs>> of the documents to pin.
+Only one of `ids` or `docs` may be specified, and at least one must be specified.
+- `docs` (Optional, array of objects) The documents to pin.
+Only one of `ids` or `docs` may be specified, and at least one must be specified.
+You can specify the following attributes for each document:
++
+--
+- `_index` (Required, string) The index of the document to pin.
+- `_id` (Required, string) The unique <<mapping-id-field,document ID>>.
+--
+
+IMPORTANT: Due to limitations within <<query-dsl-pinned-query,pinned queries>>, you can only pin documents using `ids` or `docs`, but cannot use both in a single rule.
+It is advised to use one or the other in query rulesets, to avoid errors.
+Additionally, pinned queries have a maximum limit of 100 pinned hits.
+If multiple matching rules pin more than 100 documents, only the first 100 documents are pinned in the order they are specified in the ruleset.
+
+[[put-query-rule-example]]
+==== {api-examples-title}
+
+The following example creates a new query rule with the ID `my-rule1` in a query ruleset called `my-ruleset`.
+
+`my-rule1` will pin documents with IDs `id1` and `id2` when `user_query` contains `pugs` _or_ `puggles` **and** `user_country` exactly matches `us`.
+
+[source,console]
+----
+PUT _query_rules/my-ruleset/_rule/my-rule1
+{
+  "type": "pinned",
+  "criteria": [
+    {
+      "type": "contains",
+      "metadata": "user_query",
+      "values": [ "pugs", "puggles" ]
+    },
+    {
+      "type": "exact",
+      "metadata": "user_country",
+      "values": [ "us" ]
+    }
+  ],
+  "actions": {
+    "ids": [
+      "id1",
+      "id2"
+    ]
+  }
+}
+----
+// TESTSETUP
+
+//////////////////////////
+
+[source,console]
+--------------------------------------------------
+DELETE _query_rules/my-ruleset
+--------------------------------------------------
+// TEARDOWN
+
+//////////////////////////
diff --git a/docs/reference/search/search-your-data/cohere-es.asciidoc b/docs/reference/search/search-your-data/cohere-es.asciidoc
index f12f23ad2c5dc..3029cfd9f098c 100644
--- a/docs/reference/search/search-your-data/cohere-es.asciidoc
+++ b/docs/reference/search/search-your-data/cohere-es.asciidoc
@@ -25,14 +25,15 @@ set.
 Refer to https://docs.cohere.com/docs/elasticsearch-and-cohere[Cohere's tutorial]
 for an example using a different data set.
 
+You can also review the https://colab.research.google.com/github/elastic/elasticsearch-labs/blob/main/notebooks/integrations/cohere/cohere-elasticsearch.ipynb[Colab notebook version of this tutorial].
+
 [discrete]
 [[cohere-es-req]]
 ==== Requirements
 
-* A https://cohere.com/[Cohere account],
-* an https://www.elastic.co/guide/en/cloud/current/ec-getting-started.html[Elastic Cloud]
-account,
+* A paid https://cohere.com/[Cohere account] is required to use the {infer-cap} API with
+the Cohere service, as the Cohere free trial API usage is limited,
+* an https://www.elastic.co/guide/en/cloud/current/ec-getting-started.html[Elastic Cloud] account,
 * Python 3.7 or higher.
 
@@ -329,17 +330,12 @@ they were sent to the {infer} endpoint.
 
 [[cohere-es-rag]]
 ==== Retrieval Augmented Generation (RAG) with Cohere and {es}
 
-RAG is a method for generating text using additional information fetched from an
-external data source. With the ranked results, you can build a RAG system on the
-top of what you previously created by using
-https://docs.cohere.com/docs/chat-api[Cohere's Chat API].
+https://docs.cohere.com/docs/retrieval-augmented-generation-rag[RAG] is a method for generating text using additional information fetched from an external data source.
+With the ranked results, you can build a RAG system on top of what you previously created by using https://docs.cohere.com/docs/chat-api[Cohere's Chat API].
 
-Pass in the retrieved documents and the query to receive a grounded response
-using Cohere's newest generative model
-https://docs.cohere.com/docs/command-r-plus[Command R+].
+Pass in the retrieved documents and the query to receive a grounded response using Cohere's newest generative model https://docs.cohere.com/docs/command-r-plus[Command R+].
 
-Then pass in the query and the documents to the Chat API, and print out the
-response.
+Then pass in the query and the documents to the Chat API, and print out the response.
 
 [source,py]
 --------------------------------------------------
diff --git a/docs/reference/search/search-your-data/knn-search.asciidoc b/docs/reference/search/search-your-data/knn-search.asciidoc
index 0e61b44eda413..70cf9eec121d7 100644
--- a/docs/reference/search/search-your-data/knn-search.asciidoc
+++ b/docs/reference/search/search-your-data/knn-search.asciidoc
@@ -410,6 +410,24 @@
 post-filtering approach, where the filter is applied **after** the approximate
 kNN search completes. Post-filtering has the downside that it sometimes returns
 fewer than k results, even when there are enough matching documents.
 
+[discrete]
+[[approximate-knn-search-and-filtering]]
+==== Approximate kNN search and filtering
+
+Unlike conventional query filtering, where more restrictive filters typically lead to faster queries,
+applying filters in an approximate kNN search with an HNSW index can decrease performance.
+This is because searching the HNSW graph requires additional exploration to obtain the `num_candidates`
+that meet the filter criteria.
+
+To avoid significant performance drawbacks, Lucene implements the following strategies per segment:
+
+* If the filtered document count is less than or equal to `num_candidates`, the search bypasses the HNSW graph and
+uses a brute force search on the filtered documents.
+
+* While exploring the HNSW graph, if the number of nodes explored exceeds the number of documents that satisfy the filter,
+the search will stop exploring the graph and switch to a brute force search over the filtered documents.
+
+
 [discrete]
 ==== Combine approximate kNN with other features
 
diff --git a/docs/reference/search/search-your-data/search-application-api.asciidoc b/docs/reference/search/search-your-data/search-application-api.asciidoc
index 6312751d37bca..7c9308e78ebea 100644
--- a/docs/reference/search/search-your-data/search-application-api.asciidoc
+++ b/docs/reference/search/search-your-data/search-application-api.asciidoc
@@ -295,6 +295,12 @@ This may be helpful when experimenting with specific search queries that you wan
 If your search application's name is `my_search_application`, your alias will be `my_search_application`.
 You can search this using the <>.
 
+[discrete]
+[[search-application-cross-cluster-search]]
+===== Cross cluster search
+
+Search applications do not currently support {ccs} because it is not possible to add a remote cluster's index or index pattern to an index alias.
+
 [NOTE]
 ====
 You should use the Search Applications management APIs to update your application and _not_ directly use {es} APIs such as the alias API.
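As an aside to the "Approximate kNN search and filtering" section added above, a filtered kNN search that exercises those per-segment strategies might look like the sketch below. The index name, field names, and filter value are hypothetical, not taken from this diff; the request shape follows the documented top-level `knn` search option.

[source,console]
----
POST /my-image-index/_search
{
  "knn": {
    "field": "image-vector", <1>
    "query_vector": [54, 10, -2],
    "k": 5,
    "num_candidates": 50, <2>
    "filter": {
      "term": {
        "file-type": "png" <3>
      }
    }
  }
}
----
<1> A hypothetical `dense_vector` field.
<2> If a segment has no more than this many documents matching the filter, Lucene skips the HNSW graph and brute-forces the filtered documents.
<3> The filter is applied during the approximate search (pre-filtering), so the query can still return `k` matching documents.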
 [discrete]
 ==== Combine approximate kNN with other features
 
diff --git a/docs/reference/search/search-your-data/search-application-api.asciidoc b/docs/reference/search/search-your-data/search-application-api.asciidoc
index 6312751d37bca..7c9308e78ebea 100644
--- a/docs/reference/search/search-your-data/search-application-api.asciidoc
+++ b/docs/reference/search/search-your-data/search-application-api.asciidoc
@@ -295,6 +295,12 @@ This may be helpful when experimenting with specific search queries that you wan
 If your search application's name is `my_search_application`, your alias will be `my_search_application`.
 You can search this using the <>.
 
+[discrete]
+[[search-application-cross-cluster-search]]
+===== Cross cluster search
+
+Search applications do not currently support {ccs} because it is not possible to add a remote cluster's index or index pattern to an index alias.
+
 [NOTE]
 ====
 You should use the Search Applications management APIs to update your application and _not_ directly use {es} APIs such as the alias API.
diff --git a/docs/reference/settings/inference-settings.asciidoc b/docs/reference/settings/inference-settings.asciidoc
index fa0905cf0ef73..3476058a17b21 100644
--- a/docs/reference/settings/inference-settings.asciidoc
+++ b/docs/reference/settings/inference-settings.asciidoc
@@ -34,7 +34,7 @@ message can be logged again. Defaults to one hour (`1h`).
 `xpack.inference.http.max_response_size`::
 (<>) Specifies the maximum size in bytes an HTTP response is allowed to have,
-defaults to `10mb`, the maximum configurable value is `50mb`.
+defaults to `50mb`, the maximum configurable value is `100mb`.
 
 `xpack.inference.http.max_total_connections`::
 (<>) Specifies the maximum number of connections the internal connection pool can
diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc
index 4e627ef18ec6c..2a80290cf9d3b 100644
--- a/docs/reference/vectors/vector-functions.asciidoc
+++ b/docs/reference/vectors/vector-functions.asciidoc
@@ -1,4 +1,3 @@
-[role="xpack"]
 [[vector-functions]]
 ===== Functions for vector fields
 
@@ -17,6 +16,8 @@ This is the list of available vector functions and vector access methods:
 6. <].vectorValue`>> – returns a vector's value as an array of floats
 7. <].magnitude`>> – returns a vector's magnitude
 
+NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
+
 NOTE: The recommended way to access dense vectors is through the `cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions.
 Please note however, that you should call these functions only once per script. For example,
 
@@ -193,7 +194,7 @@ we added `1` in the denominator.
 ====== Hamming distance
 
 The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
-document vectors. It is only available for byte vectors.
+document vectors. It is only available for byte and bit vectors.
 
 [source,console]
 --------------------------------------------------
@@ -278,10 +279,14 @@ You can access vector values directly through the following functions:
 
 - `doc[].vectorValue` – returns a vector's value as an array of floats
 
+NOTE: For `bit` vectors, this function returns a `float[]` in which each element represents 8 bits.
+
 - `doc[].magnitude` – returns a vector's magnitude as a float
 (for vectors created prior to version 7.5 the magnitude is not stored.
 So this function calculates it anew every time it is called).
 
+NOTE: For `bit` vectors, the magnitude is the square root of the number of `1` bits.
+
 For example, the script below implements a cosine similarity using these two functions:
 
@@ -319,3 +324,14 @@ GET my-index-000001/_search
 }
 }
 --------------------------------------------------
+[[vector-functions-bit-vectors]]
+====== Bit vectors and vector functions
+
+When using `bit` vectors, not all the vector functions are available. The supported functions are:
+
+* <> – calculates Hamming distance, the number of bits that differ between the two vectors
+* <> – calculates L^1^ distance; for `bit` vectors this is simply the `hamming` distance
+* <> – calculates L^2^ distance; for `bit` vectors this is the square root of the `hamming` distance
+
+Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
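+
+As a short illustration, the following sketch ranks documents by `hamming` distance on a `bit` vector field. The index name `my-bit-index` and field name `my_vector` are illustrative only; the query vector supplies 5 bytes (40 bits):
+
+[source,console]
+----
+GET my-bit-index/_search
+{
+  "query": {
+    "script_score": {
+      "query": { "match_all": {} },
+      "script": {
+        "source": "hamming(params.query_vector, 'my_vector')",
+        "params": {
+          "query_vector": [8, 5, -15, 1, -7]
+        }
+      }
+    }
+  }
+}
+----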
+ diff --git a/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java b/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java index 07a363ed727c7..dfe0a0642ccc3 100644 --- a/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java +++ b/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java @@ -239,8 +239,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I @Override public InternalAggregation buildEmptyAggregation() { - List buckets = new ArrayList<>(0); - return new InternalAdjacencyMatrix(name, buckets, metadata()); + return new InternalAdjacencyMatrix(name, List.of(), metadata()); } final long bucketOrd(long owningBucketOrdinal, int filterOrd) { diff --git a/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/timeseries/TimeSeriesAggregator.java b/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/timeseries/TimeSeriesAggregator.java index 53142f6cdf601..f238419687cfc 100644 --- a/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/timeseries/TimeSeriesAggregator.java +++ b/modules/aggregations/src/main/java/org/elasticsearch/aggregations/bucket/timeseries/TimeSeriesAggregator.java @@ -105,7 +105,7 @@ public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws I @Override public InternalAggregation buildEmptyAggregation() { - return new InternalTimeSeries(name, new ArrayList<>(), false, metadata()); + return new InternalTimeSeries(name, List.of(), false, metadata()); } @Override diff --git a/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/200_rollover_failure_store.yml b/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/200_rollover_failure_store.yml index dcbb0d2e465db..0742435f045fb 100644 --- a/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/200_rollover_failure_store.yml +++ b/modules/data-streams/src/yamlRestTest/resources/rest-api-spec/test/data_stream/200_rollover_failure_store.yml @@ -416,7 +416,7 @@ teardown: "Rolling over a failure store on a data stream without the failure store enabled should work": - do: allowed_warnings: - - "index template [my-other-template] has index patterns [data-*] matching patterns from existing older templates [global] with patterns (global => [*]); this template [my-template] will take precedence during new index creation" + - "index template [my-other-template] has index patterns [other-data-*] matching patterns from existing older templates [global] with patterns (global => [*]); this template [my-other-template] will take precedence during new index creation" indices.put_index_template: name: my-other-template body: diff --git a/modules/kibana/src/internalClusterTest/java/org/elasticsearch/kibana/KibanaThreadPoolIT.java b/modules/kibana/src/internalClusterTest/java/org/elasticsearch/kibana/KibanaThreadPoolIT.java index 48e2b14e63fc7..b4cb4404525f4 100644 --- a/modules/kibana/src/internalClusterTest/java/org/elasticsearch/kibana/KibanaThreadPoolIT.java +++ b/modules/kibana/src/internalClusterTest/java/org/elasticsearch/kibana/KibanaThreadPoolIT.java @@ -11,8 +11,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.bulk.BulkResponse; -import 
org.elasticsearch.action.search.SearchPhaseExecutionException; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.settings.Settings; @@ -37,7 +35,6 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.startsWith; /** @@ -108,7 +105,6 @@ public void testKibanaThreadPoolByPassesBlockedThreadPools() throws Exception { }); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107625") public void testBlockedThreadPoolsRejectUserRequests() throws Exception { assertAcked(client().admin().indices().prepareCreate(USER_INDEX)); @@ -126,15 +122,16 @@ private void assertThreadPoolsBlocked() { assertThat(e1.getMessage(), startsWith("rejected execution of TimedRunnable")); var e2 = expectThrows(EsRejectedExecutionException.class, () -> client().prepareGet(USER_INDEX, "id").get()); assertThat(e2.getMessage(), startsWith("rejected execution of ActionRunnable")); - var e3 = expectThrows( - SearchPhaseExecutionException.class, - () -> client().prepareSearch(USER_INDEX) - .setQuery(QueryBuilders.matchAllQuery()) - // Request times out if max concurrent shard requests is set to 1 - .setMaxConcurrentShardRequests(usually() ? SearchRequest.DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS : randomIntBetween(2, 10)) - .get() - ); - assertThat(e3.getMessage(), containsString("all shards failed")); + // intentionally commented out this test until https://github.com/elastic/elasticsearch/issues/97916 is fixed + // var e3 = expectThrows( + // SearchPhaseExecutionException.class, + // () -> client().prepareSearch(USER_INDEX) + // .setQuery(QueryBuilders.matchAllQuery()) + // // Request times out if max concurrent shard requests is set to 1 + // .setMaxConcurrentShardRequests(usually() ? 
SearchRequest.DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS : randomIntBetween(2, 10)) + // .get() + // ); + // assertThat(e3.getMessage(), containsString("all shards failed")); } protected void runWithBlockedThreadPools(Runnable runnable) throws Exception { diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/140_dense_vector_basic.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/140_dense_vector_basic.yml index e49dc20e73406..25088f51e2b59 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/140_dense_vector_basic.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/140_dense_vector_basic.yml @@ -229,6 +229,7 @@ setup: Content-Type: application/json catch: bad_request search: + allow_partial_search_results: false body: query: script_score: @@ -243,6 +244,7 @@ setup: Content-Type: application/json catch: bad_request search: + allow_partial_search_results: false body: query: script_score: diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml new file mode 100644 index 0000000000000..3eb686bda2174 --- /dev/null +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml @@ -0,0 +1,392 @@ +setup: + - requires: + cluster_features: ["mapper.vectors.bit_vectors"] + reason: "support for bit vectors added in 8.15" + test_runner_features: headers + + - do: + indices.create: + index: test-index + body: + mappings: + properties: + vector: + type: dense_vector + index: false + element_type: bit + dims: 40 + indexed_vector: + type: dense_vector + element_type: bit + dims: 40 + index: true + similarity: l2_norm + + - do: + index: + index: test-index + id: "1" + body: + vector: [8, 5, -15, 1, -7] + indexed_vector: [8, 5, -15, 1, -7] + + - do: + index: + index: test-index + id: "2" + body: + vector: [-1, 115, -3, 4, -128] + indexed_vector: [-1, 115, -3, 4, -128] + + - do: + index: + index: test-index + id: "3" + body: + vector: [2, 18, -5, 0, -124] + indexed_vector: [2, 18, -5, 0, -124] + + - do: + indices.refresh: {} + +--- +"Test vector magnitude equality": + - skip: + features: close_to + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['vector'].magnitude" + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}} + + - match: {hits.hits.2._id: "3"} + - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "doc['indexed_vector'].magnitude" + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}} + + - match: {hits.hits.2._id: "3"} + - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}} + +--- +"Dot Product is not supported": + - do: + catch: bad_request + headers: + 
Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProduct(params.query_vector, 'vector')" + params: + query_vector: [0, 111, -13, 14, -124] + - do: + catch: bad_request + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "dotProduct(params.query_vector, 'vector')" + params: + query_vector: "006ff30e84" +
+--- +"Cosine Similarity is not supported": + - do: + catch: bad_request + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, 'vector')" + params: + query_vector: [0, 111, -13, 14, -124] + - do: + catch: bad_request + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, 'vector')" + params: + query_vector: "006ff30e84" + + - do: + catch: bad_request + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, 'indexed_vector')" + params: + query_vector: [0, 111, -13, 14, -124]
+--- +"L1 norm": + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "l1norm(params.query_vector, 'vector')" + params: + query_vector: [0, 111, -13, 14, -124] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0._score: 17.0} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 16.0} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 11.0} +
+--- +"L1 norm hexadecimal": + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "l1norm(params.query_vector, 'vector')" + params: + query_vector: "006ff30e84" + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0._score: 17.0} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 16.0} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 11.0}
+--- +"L2 norm": + - requires: + test_runner_features: close_to + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "l2norm(params.query_vector, 'vector')" + params: + query_vector: [0, 111, -13, 14, -124] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 4, error: 0.001}} + + - match: {hits.hits.2._id: "3"} + - close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
+--- +"L2 norm hexadecimal": + - requires: + test_runner_features: close_to + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "l2norm(params.query_vector, 'vector')" + params: + query_vector: "006ff30e84" + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 4, error: 0.001}} + + - match: {hits.hits.2._id: "3"} + - close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
+--- +"Hamming distance": + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "hamming(params.query_vector, 'vector')" + params: + query_vector: [0, 111, -13, 14, -124] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0._score: 17.0} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 16.0} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 11.0} + + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "hamming(params.query_vector, 'indexed_vector')" + params: + query_vector: [0, 111, -13, 14, -124] + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0._score: 17.0} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 16.0} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 11.0}
+--- +"Hamming distance hexadecimal": + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "hamming(params.query_vector, 'vector')" + params: + query_vector: "006ff30e84" + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0._score: 17.0} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 16.0} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 11.0} + + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "hamming(params.query_vector, 'indexed_vector')" + params: + query_vector: "006ff30e84" + + - match: {hits.total: 3} + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0._score: 17.0} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1._score: 16.0} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 11.0}
diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java index 5ad1152d65e85..d2f7f6ab61977 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java @@ -72,8 +72,10 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.MockLog; +import org.elasticsearch.test.ReachabilityChecker; import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.LeakTracker; import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.netty4.Netty4Utils; import org.elasticsearch.xcontent.ToXContentObject; @@ -315,14 +317,20 @@ private static Releasable
withResourceTracker() { assertNull(refs); + final ReachabilityChecker reachabilityChecker = new ReachabilityChecker(); final var latch = new CountDownLatch(1); - refs = AbstractRefCounted.of(latch::countDown); + refs = LeakTracker.wrap(reachabilityChecker.register(AbstractRefCounted.of(latch::countDown))); return () -> { refs.decRef(); + boolean success = false; try { safeAwait(latch); + success = true; } finally { refs = null; + if (success == false) { + reachabilityChecker.ensureUnreachable(); + } } }; } @@ -635,10 +643,10 @@ public void close() { @Override public void accept(RestChannel channel) { - localRefs.mustIncRef(); client.execute(TYPE, new Request(), new RestActionListener<>(channel) { @Override protected void processResponse(Response response) { + localRefs.mustIncRef(); channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getResponseBodyPart(), () -> { // cancellation notification only happens while processing a continuation, not while computing // the next one; prompt cancellation requires use of something like RestCancellableNodeClient diff --git a/muted-tests.yml b/muted-tests.yml index 225bb7ac038eb..748f6f463f345 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -50,8 +50,6 @@ tests: - class: "org.elasticsearch.xpack.rollup.job.RollupIndexerStateTests" issue: "https://github.com/elastic/elasticsearch/issues/109627" method: "testMultipleJobTriggering" -- class: "org.elasticsearch.index.store.FsDirectoryFactoryTests" - issue: "https://github.com/elastic/elasticsearch/issues/109681" - class: "org.elasticsearch.xpack.test.rest.XPackRestIT" issue: "https://github.com/elastic/elasticsearch/issues/109687" method: "test {p0=sql/translate/Translate SQL}" @@ -61,31 +59,29 @@ tests: - class: org.elasticsearch.action.search.SearchProgressActionListenerIT method: testSearchProgressWithHits issue: https://github.com/elastic/elasticsearch/issues/109830 -- class: "org.elasticsearch.xpack.shutdown.NodeShutdownReadinessIT" - issue: "https://github.com/elastic/elasticsearch/issues/109838" - method: "testShutdownReadinessService" - class: "org.elasticsearch.xpack.security.ScrollHelperIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/109905" method: "testFetchAllEntities" -- class: "org.elasticsearch.xpack.ml.integration.AutodetectMemoryLimitIT" - issue: "https://github.com/elastic/elasticsearch/issues/109904" - class: "org.elasticsearch.xpack.esql.action.AsyncEsqlQueryActionIT" issue: "https://github.com/elastic/elasticsearch/issues/109944" method: "testBasicAsyncExecution" - class: "org.elasticsearch.xpack.security.authz.store.NativePrivilegeStoreCacheTests" issue: "https://github.com/elastic/elasticsearch/issues/110015" -- class: "org.elasticsearch.painless.LangPainlessClientYamlTestSuiteIT" - issue: "https://github.com/elastic/elasticsearch/issues/110032" - method: "test {yaml=painless/140_dense_vector_basic/Test hamming distance fails on float}" - class: "org.elasticsearch.action.admin.indices.rollover.RolloverIT" issue: "https://github.com/elastic/elasticsearch/issues/110034" method: "testRolloverWithClosedWriteIndex" -- class: org.elasticsearch.datastreams.DataStreamsClientYamlTestSuiteIT - method: test {p0=data_stream/200_rollover_failure_store/Rolling over a failure store on a data stream without the failure store enabled should work} - issue: https://github.com/elastic/elasticsearch/issues/110051 - class: org.elasticsearch.xpack.transform.transforms.TransformIndexerTests method: testMaxPageSearchSizeIsResetToConfiguredValue issue: 
https://github.com/elastic/elasticsearch/issues/109844 +- class: org.elasticsearch.index.store.FsDirectoryFactoryTests + method: testStoreDirectory + issue: https://github.com/elastic/elasticsearch/issues/110210 +- class: org.elasticsearch.index.store.FsDirectoryFactoryTests + method: testPreload + issue: https://github.com/elastic/elasticsearch/issues/110211 +- class: org.elasticsearch.synonyms.SynonymsManagementAPIServiceIT + method: testUpdateRuleWithMaxSynonyms + issue: https://github.com/elastic/elasticsearch/issues/110212 # Examples: # diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.get_categories.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.get_categories.json index 6dfa2e64dd293..69f8dd74e3d55 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.get_categories.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.get_categories.json @@ -30,7 +30,7 @@ } }, { - "path":"/_ml/anomaly_detectors/{job_id}/results/categories/", + "path":"/_ml/anomaly_detectors/{job_id}/results/categories", "methods":[ "GET", "POST" diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_jobs.json b/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_jobs.json index 46ac1c4d304d1..e373c9f08bfd5 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_jobs.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_jobs.json @@ -24,7 +24,7 @@ } }, { - "path":"/_rollup/job/", + "path":"/_rollup/job", "methods":[ "GET" ] diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_rollup_caps.json b/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_rollup_caps.json index 7dcc83ee0cd47..a72187f9ca926 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_rollup_caps.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/rollup.get_rollup_caps.json @@ -24,7 +24,7 @@ } }, { - "path":"/_rollup/data/", + "path":"/_rollup/data", "methods":[ "GET" ] diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/security.put_privileges.json b/rest-api-spec/src/main/resources/rest-api-spec/api/security.put_privileges.json index da63002b49485..8c920e10f285b 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/security.put_privileges.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/security.put_privileges.json @@ -13,7 +13,7 @@ "url":{ "paths":[ { - "path":"/_security/privilege/", + "path":"/_security/privilege", "methods":[ "PUT", "POST" diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml index 9fc82eb125def..dcd1f93e35da8 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/indices.create/20_synthetic_source.yml @@ -13,7 +13,7 @@ invalid: mode: synthetic properties: kwd: - type: keyword + type: boolean doc_values: false diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/101_knn_nested_search_bits.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/101_knn_nested_search_bits.yml new file mode 100644 index 0000000000000..a3d920d903ae8 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/101_knn_nested_search_bits.yml @@ -0,0 +1,301 @@ 
+setup: + - requires: + cluster_features: "mapper.vectors.bit_vectors" + test_runner_features: close_to + reason: 'bit vectors added in 8.15' + - do: + indices.create: + index: test + body: + settings: + index: + number_of_shards: 2 + mappings: + properties: + name: + type: keyword + nested: + type: nested + properties: + paragraph_id: + type: keyword + vector: + type: dense_vector + dims: 40 + index: true + element_type: bit + similarity: l2_norm + + - do: + index: + index: test + id: "1" + body: + name: cow.jpg + nested: + - paragraph_id: 0 + vector: [100, 20, -34, 15, -100] + - paragraph_id: 1 + vector: [40, 30, -3, 1, -20] + + - do: + index: + index: test + id: "2" + body: + name: moose.jpg + nested: + - paragraph_id: 0 + vector: [-1, 100, -13, 14, -127] + - paragraph_id: 2 + vector: [0, 100, 0, 15, -127] + - paragraph_id: 3 + vector: [0, 1, 0, 2, -15] + + - do: + index: + index: test + id: "3" + body: + name: rabbit.jpg + nested: + - paragraph_id: 0 + vector: [1, 111, -13, 14, -1] + + - do: + indices.refresh: {} + +--- +"nested kNN search only": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 14, -127] + k: 2 + num_candidates: 3 + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0.fields.name.0: "moose.jpg"} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1.fields.name.0: "cow.jpg"} + + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 14, -127] + k: 2 + num_candidates: 3 + inner_hits: {size: 1, "fields": ["nested.paragraph_id"], _source: false} + + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0.fields.name.0: "moose.jpg"} + - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1.fields.name.0: "cow.jpg"} + - match: {hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"} + +--- +"nested kNN search filtered": + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 14, -127] + k: 2 + num_candidates: 3 + filter: {term: {name: "rabbit.jpg"}} + + - match: {hits.total.value: 1} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 14, -127] + k: 3 + num_candidates: 3 + filter: {term: {name: "rabbit.jpg"}} + inner_hits: {size: 1, fields: ["nested.paragraph_id"], _source: false} + + - match: {hits.total.value: 1} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"} +--- +"nested kNN search inner_hits size > 1": + - do: + index: + index: test + id: "4" + body: + name: moose.jpg + nested: + - paragraph_id: 0 + vector: [-1, 90, -10, 14, -127] + - paragraph_id: 2 + vector: [ 0, 100.0, 0, 14, -127 ] + - paragraph_id: 3 + vector: [ 0, 1.0, 0, 2, -15 ] + + - do: + index: + index: test + id: "5" + body: + name: moose.jpg + nested: + - paragraph_id: 0 + vector: [ -1, 100, -13, 14, -127 ] + - paragraph_id: 2 + vector: [ 0, 100, 0, 15, -127 ] + - paragraph_id: 3 + vector: [ 0, 1, 0, 2, -15 ] + + - do: + index: + index: test + id: "6" + body: + name: moose.jpg + nested: + - paragraph_id: 0 + vector: [ -1, 100, -13, 15, -127 ] + - paragraph_id: 2 + vector: [ 0, 100, 0, 15, -127 ] 
+ - paragraph_id: 3 + vector: [ 0, 1, 0, 2, -15 ] + - do: + indices.refresh: { } + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 15, -127] + k: 3 + num_candidates: 5 + inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false} + + - match: {hits.total.value: 3} + - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 } + - length: { hits.hits.1.inner_hits.nested.hits.hits: 2 } + - length: { hits.hits.2.inner_hits.nested.hits.hits: 2 } + + - match: { hits.hits.0.fields.name.0: "moose.jpg" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } +
+ - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 15, -127] + k: 5 + num_candidates: 5 + inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false} + + - match: {hits.total.value: 5} + # All these initial matches are "moose.jpg", which has 3 nested vectors, but two are closest + - match: {hits.hits.0.fields.name.0: "moose.jpg"} + - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + - match: {hits.hits.1.fields.name.0: "moose.jpg"} + - length: { hits.hits.1.inner_hits.nested.hits.hits: 2 } + - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.1.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + - match: {hits.hits.2.fields.name.0: "moose.jpg"} + - length: { hits.hits.2.inner_hits.nested.hits.hits: 2 } + - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.2.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + - match: {hits.hits.3.fields.name.0: "moose.jpg"} + - length: { hits.hits.3.inner_hits.nested.hits.hits: 2 } + - match: { hits.hits.3.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.3.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" } + # cow.jpg has two passage vectors, so both are returned as inner hits + - match: {hits.hits.4.fields.name.0: "cow.jpg"} + - length: { hits.hits.4.inner_hits.nested.hits.hits: 2 } +
+ - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [ -1, 90, -10, 15, -127 ] + k: 3 + num_candidates: 3 + filter: {term: {name: "cow.jpg"}} + inner_hits: {size: 3, fields: ["nested.paragraph_id"], _source: false} + + - match: {hits.total.value: 1} + - match: { hits.hits.0._id: "1" } + - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 } + - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" } + - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "1" }
+--- +"nested kNN search inner_hits & boosting": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 15, -127] + k: 3 + num_candidates: 5 + inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false} + + - close_to: { hits.hits.0._score: {value: 0.8, error: 0.00001} } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 0.8, error: 0.00001} } + - close_to: { hits.hits.1._score: {value: 0.625, error: 0.00001} } + - close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 0.625, error: 0.00001} } + - close_to: { hits.hits.2._score: {value: 0.5, error: 0.00001} } + - close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 0.5, error: 0.00001} } +
+ - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nested.vector + query_vector: [-1, 90, -10, 15, -127] + k: 3 + num_candidates: 5 + boost: 2 + inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false} + - close_to: { hits.hits.0._score: {value: 1.6, error: 0.00001} } + - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 1.6, error: 0.00001} } + - close_to: { hits.hits.1._score: {value: 1.25, error: 0.00001} } + - close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 1.25, error: 0.00001} } + - close_to: { hits.hits.2._score: {value: 1, error: 0.00001} } + - close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 1.0, error: 0.00001} }
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml index 74fbe221c0fe7..f989e17e6ec30 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml @@ -116,8 +116,9 @@ setup: --- "Knn search with hex string for byte field - dimensions mismatch" : # [64, 10, -30, 10] - is encoded as '400ae20a' + # the error message has been adjusted in later versions - do: - catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/ + catch: /dimensions? \[4\] than the (document|index) vectors \[3\]/ search: index: knn_hex_vector_index body:
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml index e01f3ec18b8c3..cd94275234661 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml @@ -116,8 +116,9 @@ setup: --- "Knn query with hex string for byte field - dimensions mismatch" : # [64, 10, -30, 10] - is encoded as '400ae20a' + # the error message has been adjusted in later versions - do: - catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/ + catch: /dimensions? \[4\] than the (document|index) vectors \[3\]/ search: index: knn_hex_vector_index body:
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index 24437e3db1379..cb5aae482507a 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -204,7 +204,7 @@ setup: num_candidates: 3 k: 3 field: vector - similarity: 10.3 + similarity: 17 query_vector: [-0.5, 90.0, -10, 14.8] - length: {hits.hits: 1} @@ -222,7 +222,7 @@ setup:
num_candidates: 3 k: 3 field: vector - similarity: 11 + similarity: 17 query_vector: [-0.5, 90.0, -10, 14.8] filter: {"term": {"name": "moose.jpg"}} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml new file mode 100644 index 0000000000000..ed469ffd7ff16 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit.yml @@ -0,0 +1,356 @@ +setup: + - requires: + cluster_features: "mapper.vectors.bit_vectors" + reason: 'mapper.vectors.bit_vectors' + + - do: + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: bit + dims: 40 + index: true + similarity: l2_norm + + - do: + index: + index: test + id: "1" + body: + name: cow.jpg + vector: [2, -1, 1, 4, -3] + + - do: + index: + index: test + id: "2" + body: + name: moose.jpg + vector: [127.0, -128.0, 0.0, 1.0, -1.0] + + - do: + index: + index: test + id: "3" + body: + name: rabbit.jpg + vector: [5, 4.0, 3, 2.0, 127] + + - do: + indices.refresh: {} + +--- +"kNN search only": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 2 + num_candidates: 3 + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0.fields.name.0: "moose.jpg"} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1.fields.name.0: "cow.jpg"} + +--- +"kNN search plus query": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 2 + num_candidates: 3 + query: + term: + name: rabbit.jpg + + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.1.fields.name.0: "moose.jpg"} + +--- +"kNN search with filter": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [5.0, 4, 3.0, 2, 127.0] + k: 2 + num_candidates: 3 + filter: + term: + name: "rabbit.jpg" + + - match: {hits.total.value: 1} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [2, -1, 1, 4, -3] + k: 2 + num_candidates: 3 + filter: + - term: + name: "rabbit.jpg" + - term: + _id: 2 + + - match: {hits.total.value: 0} + +--- +"Vector similarity search only": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 0.98 + query_vector: [5, 4.0, 3, 2.0, 127] + + - length: {hits.hits: 1} + + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} +--- +"Vector similarity with filter only": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 0.98 + query_vector: [5, 4.0, 3, 2.0, 127] + filter: {"term": {"name": "rabbit.jpg"}} + + - length: {hits.hits: 1} + + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 0.98 + query_vector: [5, 4.0, 3, 2.0, 127] + filter: {"term": {"name": "cow.jpg"}} + + - length: {hits.hits: 0} +--- +"dim mismatch": + - do: + catch: bad_request + search: + index: 
test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [1, 2, 3, 4, 5, 6] + k: 2 + num_candidates: 3
+--- +"disallow quantized vector types": + - do: + catch: bad_request + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int8_flat + + - do: + catch: bad_request + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int4_flat + + - do: + catch: bad_request + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int8_hnsw + + - do: + catch: bad_request + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int4_hnsw
+--- +"disallow vector index type change to quantized type": + - do: + catch: bad_request + indices.put_mapping: + index: test + body: + properties: + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int4_hnsw + - do: + catch: bad_request + indices.put_mapping: + index: test + body: + properties: + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int8_hnsw
+--- +"Defaults to l2_norm with bit vectors": + - do: + indices.create: + index: default_to_l2_norm_bit + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bit + dims: 40 + index: true + + - do: + indices.get_mapping: + index: default_to_l2_norm_bit + + - match: { default_to_l2_norm_bit.mappings.properties.vector.similarity: l2_norm } +
+--- +"Only allow l2_norm with bit vectors": + - do: + catch: bad_request + indices.create: + index: dot_product_fails_for_bits + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bit + dims: 40 + index: true + similarity: dot_product + + - do: + catch: bad_request + indices.create: + index: cosine_product_fails_for_bits + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bit + dims: 40 + index: true + similarity: cosine + + - do: + catch: bad_request + indices.create: + index: max_inner_product_fails_for_bits + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bit + dims: 40 + index: true + similarity: max_inner_product
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml new file mode 100644 index 0000000000000..ec7bde4de8435 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_bit_flat.yml @@ -0,0 +1,223 @@ +setup: + - requires: + cluster_features: "mapper.vectors.bit_vectors" + reason: 'mapper.vectors.bit_vectors' + + - do: + indices.create: + index: test + body: + mappings: + properties: + name: + type: keyword + vector: + type: dense_vector + element_type: bit + dims: 40 + index: true + similarity: l2_norm + index_options: + type: flat + + - do: + index: + index: test + id: "1" + body: +
name: cow.jpg + vector: [2, -1, 1, 4, -3] + + - do: + index: + index: test + id: "2" + body: + name: moose.jpg + vector: [127.0, -128.0, 0.0, 1.0, -1.0] + + - do: + index: + index: test + id: "3" + body: + name: rabbit.jpg + vector: [5, 4.0, 3, 2.0, 127] + + - do: + indices.refresh: {} + +--- +"kNN search only": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127, 127, -128, -128, 127] + k: 2 + num_candidates: 3 + + - match: {hits.hits.0._id: "2"} + - match: {hits.hits.0.fields.name.0: "moose.jpg"} + + - match: {hits.hits.1._id: "1"} + - match: {hits.hits.1.fields.name.0: "cow.jpg"} + +--- +"kNN search plus query": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [127.0, -128.0, 0.0, 1.0, -1.0] + k: 2 + num_candidates: 3 + query: + term: + name: rabbit.jpg + + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.1.fields.name.0: "moose.jpg"} + +--- +"kNN search with filter": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [5.0, 4, 3.0, 2, 127.0] + k: 2 + num_candidates: 3 + filter: + term: + name: "rabbit.jpg" + + - match: {hits.total.value: 1} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [2, -1, 1, 4, -3] + k: 2 + num_candidates: 3 + filter: + - term: + name: "rabbit.jpg" + - term: + _id: 2 + + - match: {hits.total.value: 0} + +--- +"Vector similarity search only": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 0.98 + query_vector: [5, 4.0, 3, 2.0, 127] + + - length: {hits.hits: 1} + + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} +--- +"Vector similarity with filter only": + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 0.98 + query_vector: [5, 4.0, 3, 2.0, 127] + filter: {"term": {"name": "rabbit.jpg"}} + + - length: {hits.hits: 1} + + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.0.fields.name.0: "rabbit.jpg"} + + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + num_candidates: 3 + k: 3 + field: vector + similarity: 0.98 + query_vector: [5, 4.0, 3, 2.0, 127] + filter: {"term": {"name": "cow.jpg"}} + + - length: {hits.hits: 0} +--- +"dim mismatch": + - do: + catch: bad_request + search: + index: test + body: + fields: [ "name" ] + knn: + field: vector + query_vector: [1, 2, 3, 4, 5, 6] + k: 2 + num_candidates: 3 +--- +"disallow vector index type change to quantized type": + - do: + catch: bad_request + indices.put_mapping: + index: test + body: + properties: + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int4_hnsw + - do: + catch: bad_request + indices.put_mapping: + index: test + body: + properties: + vector: + type: dense_vector + element_type: bit + dims: 32 + index: true + similarity: l2_norm + index_options: + type: int8_hnsw diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java index b07a861a5a5ef..ee66d57bec5cc 100644 --- 
a/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/ClusterStateDiffIT.java @@ -8,6 +8,7 @@ package org.elasticsearch.cluster; +import org.elasticsearch.TransportVersion; import org.elasticsearch.Version; import org.elasticsearch.cluster.block.ClusterBlock; import org.elasticsearch.cluster.block.ClusterBlocks; @@ -48,6 +49,7 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.snapshots.Snapshot; import org.elasticsearch.snapshots.SnapshotId; @@ -568,6 +570,7 @@ public IndexMetadata randomCreate(String name) { settingsBuilder.put(randomSettings(Settings.EMPTY)).put(IndexMetadata.SETTING_VERSION_CREATED, randomVersion(random())); builder.settings(settingsBuilder); builder.numberOfShards(randomIntBetween(1, 10)).numberOfReplicas(randomInt(10)); + builder.eventIngestedRange(IndexLongFieldRange.UNKNOWN, TransportVersion.current()); int aliasCount = randomInt(10); for (int i = 0; i < aliasCount; i++) { builder.putAlias(randomAlias()); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java index b563d6849c777..204d7131c44d2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java @@ -1041,7 +1041,6 @@ public void testHistoryRetention() throws Exception { assertThat(recoveryState.getTranslog().recoveredOperations(), greaterThan(0)); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105122") public void testDoNotInfinitelyWaitForMapping() { internalCluster().ensureAtLeastNumDataNodes(3); createIndex( diff --git a/server/src/internalClusterTest/java/org/elasticsearch/reservedstate/service/FileSettingsServiceIT.java b/server/src/internalClusterTest/java/org/elasticsearch/reservedstate/service/FileSettingsServiceIT.java index 6e89c1447edb6..2fe808d813ccc 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/reservedstate/service/FileSettingsServiceIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/reservedstate/service/FileSettingsServiceIT.java @@ -42,15 +42,16 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, autoManageMasterNodes = false) public class FileSettingsServiceIT extends ESIntegTestCase { - private static AtomicLong versionCounter = new AtomicLong(1); + private static final AtomicLong versionCounter = new AtomicLong(1); - private static String testJSON = """ + private static final String testJSON = """ { "metadata": { "version": "%s", @@ -63,7 +64,7 @@ public class FileSettingsServiceIT extends ESIntegTestCase { } }"""; - private static String testJSON43mb = """ + private static final String testJSON43mb = """ { "metadata": { "version": "%s", @@ -76,7 +77,7 @@ public class FileSettingsServiceIT extends ESIntegTestCase { } }"""; 
- private static String testCleanupJSON = """ + private static final String testCleanupJSON = """ { "metadata": { "version": "%s", @@ -87,7 +88,7 @@ public class FileSettingsServiceIT extends ESIntegTestCase { } }"""; - private static String testErrorJSON = """ + private static final String testErrorJSON = """ { "metadata": { "version": "%s", @@ -165,8 +166,7 @@ public void clusterChanged(ClusterChangedEvent event) { private void assertClusterStateSaveOK(CountDownLatch savedClusterState, AtomicLong metadataVersion, String expectedBytesPerSec) throws Exception { - boolean awaitSuccessful = savedClusterState.await(20, TimeUnit.SECONDS); - assertTrue(awaitSuccessful); + assertTrue(savedClusterState.await(20, TimeUnit.SECONDS)); final ClusterStateResponse clusterStateResponse = clusterAdmin().state( new ClusterStateRequest().waitForMetadataVersion(metadataVersion.get()) @@ -180,11 +180,13 @@ private void assertClusterStateSaveOK(CountDownLatch savedClusterState, AtomicLo ClusterUpdateSettingsRequest req = new ClusterUpdateSettingsRequest().persistentSettings( Settings.builder().put(INDICES_RECOVERY_MAX_BYTES_PER_SEC_SETTING.getKey(), "1234kb") ); - assertEquals( - "java.lang.IllegalArgumentException: Failed to process request " - + "[org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest/unset] " - + "with errors: [[indices.recovery.max_bytes_per_sec] set as read-only by [file_settings]]", - expectThrows(ExecutionException.class, () -> clusterAdmin().updateSettings(req).get()).getMessage() + assertThat( + expectThrows(ExecutionException.class, () -> clusterAdmin().updateSettings(req).get()).getMessage(), + is( + "java.lang.IllegalArgumentException: Failed to process request " + + "[org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest/unset] " + + "with errors: [[indices.recovery.max_bytes_per_sec] set as read-only by [file_settings]]" + ) ); } @@ -256,16 +258,15 @@ public void testReservedStatePersistsOnRestart() throws Exception { internalCluster().restartNode(masterNode); final ClusterStateResponse clusterStateResponse = clusterAdmin().state(new ClusterStateRequest()).actionGet(); - assertEquals( - 1, + assertThat( clusterStateResponse.getState() .metadata() .reservedStateMetadata() .get(FileSettingsService.NAMESPACE) .handlers() .get(ReservedClusterSettingsAction.NAME) - .keys() - .size() + .keys(), + hasSize(1) ); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/CollapseSearchResultsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/CollapseSearchResultsIT.java index a12a26d69c5ff..f5fdd752a6f57 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/CollapseSearchResultsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/CollapseSearchResultsIT.java @@ -39,4 +39,26 @@ public void testCollapse() { } ); } + + public void testCollapseWithDocValueFields() { + final String indexName = "test_collapse"; + createIndex(indexName); + final String collapseField = "collapse_field"; + final String otherField = "other_field"; + assertAcked(indicesAdmin().preparePutMapping(indexName).setSource(collapseField, "type=keyword", otherField, "type=keyword")); + index(indexName, "id_1_0", Map.of(collapseField, "value1", otherField, "other_value1")); + index(indexName, "id_1_1", Map.of(collapseField, "value1", otherField, "other_value2")); + index(indexName, "id_2_0", Map.of(collapseField, "value2", otherField, "other_value3")); + refresh(indexName); + + 
assertNoFailuresAndResponse( + prepareSearch(indexName).setQuery(new MatchAllQueryBuilder()) + .addDocValueField(otherField) + .setCollapse(new CollapseBuilder(collapseField).setInnerHits(new InnerHitBuilder("ih").setSize(2))), + searchResponse -> { + assertEquals(collapseField, searchResponse.getHits().getCollapseField()); + assertEquals(Set.of(new BytesRef("value1"), new BytesRef("value2")), Set.of(searchResponse.getHits().getCollapseValues())); + } + ); + } } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/FunctionScorePluginIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/FunctionScorePluginIT.java index 396af7e8501cf..d42a84677a8f7 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/FunctionScorePluginIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/FunctionScorePluginIT.java @@ -146,7 +146,6 @@ private static class LinearMultScoreFunction implements DecayFunction { @Override public double evaluate(double value, double scale) { - return value; } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java index 9bcddd5c58d66..b8b6dcb25b557 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/SnapshotStressTestsIT.java @@ -216,11 +216,13 @@ private static class TrackedCluster { static final Logger logger = LogManager.getLogger(TrackedCluster.class); static final String CLIENT = "client"; + static final String NODE_RESTARTER = "node_restarter"; private final ThreadPool threadPool = new TestThreadPool( "TrackedCluster", // a single thread for "client" activities, to limit the number of activities all starting at once - new ScalingExecutorBuilder(CLIENT, 1, 1, TimeValue.ZERO, true, CLIENT) + new ScalingExecutorBuilder(CLIENT, 1, 1, TimeValue.ZERO, true, CLIENT), + new ScalingExecutorBuilder(NODE_RESTARTER, 1, 5, TimeValue.ZERO, true, NODE_RESTARTER) ); private final Executor clientExecutor = threadPool.executor(CLIENT); @@ -1163,7 +1165,7 @@ private void startNodeRestarter() { final String nodeName = trackedNode.nodeName; final Releasable releaseAll = localReleasables.transfer(); - threadPool.generic().execute(mustSucceed(() -> { + threadPool.executor(NODE_RESTARTER).execute(mustSucceed(() -> { logger.info("--> restarting [{}]", nodeName); cluster.restartNode(nodeName); logger.info("--> finished restarting [{}]", nodeName); diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index db7e3d40518ba..e2810a6f5bf16 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -449,7 +449,10 @@ with org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat, org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat, - org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat, + org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; + provides org.apache.lucene.codecs.Codec with Elasticsearch814Codec; provides org.apache.logging.log4j.core.util.ContextDataProvider with org.elasticsearch.common.logging.DynamicContextDataProvider; diff --git 
a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 977f292912746..28a75443b5842 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -199,6 +199,10 @@ static TransportVersion def(int id) { public static final TransportVersion SNAPSHOT_REQUEST_TIMEOUTS = def(8_690_00_0); public static final TransportVersion INDEX_METADATA_MAPPINGS_UPDATED_VERSION = def(8_691_00_0); public static final TransportVersion ML_INFERENCE_ELAND_SETTINGS_ADDED = def(8_692_00_0); + public static final TransportVersion ML_ANTHROPIC_INTEGRATION_ADDED = def(8_693_00_0); + public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_EMBEDDINGS_ADDED = def(8_694_00_0); + public static final TransportVersion EVENT_INGESTED_RANGE_IN_CLUSTER_STATE = def(8_695_00_0); + public static final TransportVersion ESQL_ADD_AGGREGATE_TYPE = def(8_696_00_0); public static final TransportVersion MULTI_PROJECT = def(8_999_00_0); // THIS IS A HACK FOR NOW (!) /* @@ -264,7 +268,7 @@ static TransportVersion def(int id) { * Reference to the minimum transport version that can be used with CCS. * This should be the transport version used by the previous minor release. */ - public static final TransportVersion MINIMUM_CCS_VERSION = V_8_13_0; + public static final TransportVersion MINIMUM_CCS_VERSION = SHUTDOWN_REQUEST_TIMEOUTS_FIX_8_14; static final NavigableMap VERSION_IDS = getAllVersionIds(TransportVersions.class); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index 55c754545cbbe..82c498c64e1c9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.lucene.grouping.TopFieldGroups; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchHit; @@ -301,11 +302,13 @@ private static Sort checkSameSortTypes(Collection results, SortField[] } private static SortField.Type getType(SortField sortField) { - if (sortField instanceof SortedNumericSortField) { - return ((SortedNumericSortField) sortField).getNumericType(); - } - if (sortField instanceof SortedSetSortField) { + if (sortField instanceof SortedNumericSortField sf) { + return sf.getNumericType(); + } else if (sortField instanceof SortedSetSortField) { return SortField.Type.STRING; + } else if (sortField.getComparatorSource() instanceof IndexFieldData.XFieldComparatorSource cmp) { + // This can occur if the sort field wasn't rewritten by Lucene#rewriteMergeSortField because all search shards are local. 
+ return cmp.reducedType(); } else { return sortField.getType(); } diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index f8d30786aca34..c2d1cdae85cd9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -292,24 +292,43 @@ public long buildTookInMillis() { @Override protected void doExecute(Task task, SearchRequest searchRequest, ActionListener listener) { - ActionListener loggingAndMetrics = listener.delegateFailureAndWrap((l, searchResponse) -> { - searchResponseMetrics.recordTookTime(searchResponse.getTookInMillis()); - if (searchResponse.getShardFailures() != null && searchResponse.getShardFailures().length > 0) { - // Deduplicate failures by exception message and index - ShardOperationFailedException[] groupedFailures = ExceptionsHelper.groupBy(searchResponse.getShardFailures()); - for (ShardOperationFailedException f : groupedFailures) { - boolean causeHas500Status = false; - if (f.getCause() != null) { - causeHas500Status = ExceptionsHelper.status(f.getCause()).getStatus() >= 500; - } - if ((f.status().getStatus() >= 500 || causeHas500Status) - && ExceptionsHelper.isNodeOrShardUnavailableTypeException(f.getCause()) == false) { - logger.warn("TransportSearchAction shard failure (partial results response)", f); + ActionListener loggingAndMetrics = new ActionListener<>() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + searchResponseMetrics.recordTookTime(searchResponse.getTookInMillis()); + SearchResponseMetrics.ResponseCountTotalStatus responseCountTotalStatus = + SearchResponseMetrics.ResponseCountTotalStatus.SUCCESS; + if (searchResponse.getShardFailures() != null && searchResponse.getShardFailures().length > 0) { + // Deduplicate failures by exception message and index + ShardOperationFailedException[] groupedFailures = ExceptionsHelper.groupBy(searchResponse.getShardFailures()); + for (ShardOperationFailedException f : groupedFailures) { + boolean causeHas500Status = false; + if (f.getCause() != null) { + causeHas500Status = ExceptionsHelper.status(f.getCause()).getStatus() >= 500; + } + if ((f.status().getStatus() >= 500 || causeHas500Status) + && ExceptionsHelper.isNodeOrShardUnavailableTypeException(f.getCause()) == false) { + logger.warn("TransportSearchAction shard failure (partial results response)", f); + responseCountTotalStatus = SearchResponseMetrics.ResponseCountTotalStatus.PARTIAL_FAILURE; + } + } } + listener.onResponse(searchResponse); + // increment after the delegated onResponse to ensure we don't + // record both a success and a failure if there is an exception + searchResponseMetrics.incrementResponseCount(responseCountTotalStatus); + } catch (Exception e) { + onFailure(e); } } - l.onResponse(searchResponse); - }); + + @Override + public void onFailure(Exception e) { + searchResponseMetrics.incrementResponseCount(SearchResponseMetrics.ResponseCountTotalStatus.FAILURE); + listener.onFailure(e); + } + }; executeRequest((SearchTask) task, searchRequest, loggingAndMetrics, AsyncSearchActionProvider::new); } diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java index c1c43310b0e11..d60033786abeb 100644 --- 
a/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java @@ -55,20 +55,39 @@ public TransportSearchScrollAction( @Override protected void doExecute(Task task, SearchScrollRequest request, ActionListener listener) { - ActionListener loggingAndMetrics = listener.delegateFailureAndWrap((l, searchResponse) -> { - searchResponseMetrics.recordTookTime(searchResponse.getTookInMillis()); - if (searchResponse.getShardFailures() != null && searchResponse.getShardFailures().length > 0) { - ShardOperationFailedException[] groupedFailures = ExceptionsHelper.groupBy(searchResponse.getShardFailures()); - for (ShardOperationFailedException f : groupedFailures) { - Throwable cause = f.getCause() == null ? f : f.getCause(); - if (ExceptionsHelper.status(cause).getStatus() >= 500 - && ExceptionsHelper.isNodeOrShardUnavailableTypeException(cause) == false) { - logger.warn("TransportSearchScrollAction shard failure (partial results response)", f); + ActionListener loggingAndMetrics = new ActionListener<>() { + @Override + public void onResponse(SearchResponse searchResponse) { + try { + searchResponseMetrics.recordTookTime(searchResponse.getTookInMillis()); + SearchResponseMetrics.ResponseCountTotalStatus responseCountTotalStatus = + SearchResponseMetrics.ResponseCountTotalStatus.SUCCESS; + if (searchResponse.getShardFailures() != null && searchResponse.getShardFailures().length > 0) { + ShardOperationFailedException[] groupedFailures = ExceptionsHelper.groupBy(searchResponse.getShardFailures()); + for (ShardOperationFailedException f : groupedFailures) { + Throwable cause = f.getCause() == null ? f : f.getCause(); + if (ExceptionsHelper.status(cause).getStatus() >= 500 + && ExceptionsHelper.isNodeOrShardUnavailableTypeException(cause) == false) { + logger.warn("TransportSearchScrollAction shard failure (partial results response)", f); + responseCountTotalStatus = SearchResponseMetrics.ResponseCountTotalStatus.PARTIAL_FAILURE; + } + } } + listener.onResponse(searchResponse); + // increment after the delegated onResponse to ensure we don't + // record both a success and a failure if there is an exception + searchResponseMetrics.incrementResponseCount(responseCountTotalStatus); + } catch (Exception e) { + onFailure(e); } } - l.onResponse(searchResponse); - }); + + @Override + public void onFailure(Exception e) { + searchResponseMetrics.incrementResponseCount(SearchResponseMetrics.ResponseCountTotalStatus.FAILURE); + listener.onFailure(e); + } + }; try { ParsedScrollId scrollId = parseScrollId(request.scrollId()); Runnable action = switch (scrollId.getType()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/ClusterState.java b/server/src/main/java/org/elasticsearch/cluster/ClusterState.java index aabc213235dc7..988e38f35391e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/ClusterState.java +++ b/server/src/main/java/org/elasticsearch/cluster/ClusterState.java @@ -47,6 +47,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContent; @@ -233,6 +234,27 @@ public ClusterState( this.minVersions = blocks.hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) ? 
new CompatibilityVersions(TransportVersions.MINIMUM_COMPATIBLE, Map.of()) // empty map because cluster state is unknown : CompatibilityVersions.minimumVersions(compatibilityVersions.values()); + + assert compatibilityVersions.isEmpty() + || blocks.hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) + || assertEventIngestedIsUnknownInMixedClusters(metadata, this.minVersions); + } + + private boolean assertEventIngestedIsUnknownInMixedClusters(Metadata metadata, CompatibilityVersions compatibilityVersions) { + if (compatibilityVersions.transportVersion().before(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE) + && metadata != null + && metadata.indices() != null) { + for (IndexMetadata indexMetadata : metadata.indices().values()) { + assert indexMetadata.getEventIngestedRange() == IndexLongFieldRange.UNKNOWN + : "event.ingested range should be UNKNOWN but is " + + indexMetadata.getEventIngestedRange() + + " for index: " + + indexMetadata.getIndex() + + " minTransportVersion: " + + compatibilityVersions.transportVersion(); + } + } + return true; } private static boolean assertConsistentRoutingNodes( diff --git a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java index 51fca588699e2..a01383b3eaa93 100644 --- a/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java +++ b/server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java @@ -12,6 +12,8 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ResultDeduplicator; import org.elasticsearch.action.support.ChannelActionListener; @@ -543,9 +545,10 @@ public void shardStarted( final long primaryTerm, final String message, final ShardLongFieldRange timestampRange, + final ShardLongFieldRange eventIngestedRange, final ActionListener listener ) { - shardStarted(shardRouting, primaryTerm, message, timestampRange, listener, clusterService.state()); + shardStarted(shardRouting, primaryTerm, message, timestampRange, eventIngestedRange, listener, clusterService.state()); } public void shardStarted( @@ -553,11 +556,19 @@ public void shardStarted( final long primaryTerm, final String message, final ShardLongFieldRange timestampRange, + final ShardLongFieldRange eventIngestedRange, final ActionListener listener, final ClusterState currentState ) { remoteShardStateUpdateDeduplicator.executeOnce( - new StartedShardEntry(shardRouting.shardId(), shardRouting.allocationId().getId(), primaryTerm, message, timestampRange), + new StartedShardEntry( + shardRouting.shardId(), + shardRouting.allocationId().getId(), + primaryTerm, + message, + timestampRange, + eventIngestedRange + ), listener, (req, l) -> sendShardAction(SHARD_STARTED_ACTION_NAME, currentState, req, l) ); @@ -585,6 +596,14 @@ public void messageReceived(StartedShardEntry request, TransportChannel channel, } } + /** + * Holder of the pair of time ranges needed in cluster state - one for @timestamp, the other for 'event.ingested'. + * Since 'event.ingested' was added well after @timestamp, it can be UNKNOWN when @timestamp range is present. 
+ * @param timestampRange range for @timestamp + * @param eventIngestedRange range for event.ingested + */ + record ClusterStateTimeRanges(IndexLongFieldRange timestampRange, IndexLongFieldRange eventIngestedRange) {} + public static class ShardStartedClusterStateTaskExecutor implements ClusterStateTaskExecutor { private final AllocationService allocationService; private final RerouteService rerouteService; @@ -599,37 +618,42 @@ public ClusterState execute(BatchExecutionContext batchE List> tasksToBeApplied = new ArrayList<>(); List shardRoutingsToBeApplied = new ArrayList<>(batchExecutionContext.taskContexts().size()); Set seenShardRoutings = new HashSet<>(); // to prevent duplicates - final Map updatedTimestampRanges = new HashMap<>(); + final Map updatedTimestampRanges = new HashMap<>(); final ClusterState initialState = batchExecutionContext.initialState(); for (var taskContext : batchExecutionContext.taskContexts()) { final var task = taskContext.getTask(); - StartedShardEntry entry = task.getEntry(); - final ShardRouting matched = initialState.getRoutingTable().getByAllocationId(entry.shardId, entry.allocationId); + StartedShardEntry startedShardEntry = task.getEntry(); + final ShardRouting matched = initialState.getRoutingTable() + .getByAllocationId(startedShardEntry.shardId, startedShardEntry.allocationId); if (matched == null) { // tasks that correspond to non-existent shards are marked as successful. The reason is that we resend shard started // events on every cluster state publishing that does not contain the shard as started yet. This means that old stale // requests might still be in flight even after the shard has already been started or failed on the master. We just // ignore these requests for now. - logger.debug("{} ignoring shard started task [{}] (shard does not exist anymore)", entry.shardId, entry); + logger.debug( + "{} ignoring shard started task [{}] (shard does not exist anymore)", + startedShardEntry.shardId, + startedShardEntry + ); taskContext.success(task::onSuccess); } else { - if (matched.primary() && entry.primaryTerm > 0) { - final IndexMetadata indexMetadata = initialState.metadata().index(entry.shardId.getIndex()); + if (matched.primary() && startedShardEntry.primaryTerm > 0) { + final IndexMetadata indexMetadata = initialState.metadata().index(startedShardEntry.shardId.getIndex()); assert indexMetadata != null; - final long currentPrimaryTerm = indexMetadata.primaryTerm(entry.shardId.id()); - if (currentPrimaryTerm != entry.primaryTerm) { - assert currentPrimaryTerm > entry.primaryTerm + final long currentPrimaryTerm = indexMetadata.primaryTerm(startedShardEntry.shardId.id()); + if (currentPrimaryTerm != startedShardEntry.primaryTerm) { + assert currentPrimaryTerm > startedShardEntry.primaryTerm : "received a primary term with a higher term than in the " + "current cluster state (received [" - + entry.primaryTerm + + startedShardEntry.primaryTerm + "] but current is [" + currentPrimaryTerm + "])"; logger.debug( "{} ignoring shard started task [{}] (primary term {} does not match current term {})", - entry.shardId, - entry, - entry.primaryTerm, + startedShardEntry.shardId, + startedShardEntry, + startedShardEntry.primaryTerm, currentPrimaryTerm ); taskContext.success(task::onSuccess); @@ -637,12 +661,12 @@ public ClusterState execute(BatchExecutionContext batchE } } if (matched.initializing() == false) { - assert matched.active() : "expected active shard routing for task " + entry + " but found " + matched; + assert matched.active() : "expected active 
shard routing for task " + startedShardEntry + " but found " + matched; // same as above, this might have been a stale in-flight request, so we just ignore. logger.debug( "{} ignoring shard started task [{}] (shard exists but is not initializing: {})", - entry.shardId, - entry, + startedShardEntry.shardId, + startedShardEntry, matched ); taskContext.success(task::onSuccess); @@ -651,32 +675,66 @@ public ClusterState execute(BatchExecutionContext batchE if (seenShardRoutings.contains(matched)) { logger.trace( "{} ignoring shard started task [{}] (already scheduled to start {})", - entry.shardId, - entry, + startedShardEntry.shardId, + startedShardEntry, matched ); tasksToBeApplied.add(taskContext); } else { - logger.debug("{} starting shard {} (shard started task: [{}])", entry.shardId, matched, entry); + logger.debug( + "{} starting shard {} (shard started task: [{}])", + startedShardEntry.shardId, + matched, + startedShardEntry + ); tasksToBeApplied.add(taskContext); shardRoutingsToBeApplied.add(matched); seenShardRoutings.add(matched); - // expand the timestamp range recorded in the index metadata if needed - final Index index = entry.shardId.getIndex(); - IndexLongFieldRange currentTimestampMillisRange = updatedTimestampRanges.get(index); + // expand the timestamp range(s) recorded in the index metadata if needed + final Index index = startedShardEntry.shardId.getIndex(); + ClusterStateTimeRanges clusterStateTimeRanges = updatedTimestampRanges.get(index); + IndexLongFieldRange currentTimestampMillisRange = clusterStateTimeRanges == null + ? null + : clusterStateTimeRanges.timestampRange(); + IndexLongFieldRange currentEventIngestedMillisRange = clusterStateTimeRanges == null + ? null + : clusterStateTimeRanges.eventIngestedRange(); + final IndexMetadata indexMetadata = initialState.metadata().index(index); if (currentTimestampMillisRange == null) { currentTimestampMillisRange = indexMetadata.getTimestampRange(); } - final IndexLongFieldRange newTimestampMillisRange; - newTimestampMillisRange = currentTimestampMillisRange.extendWithShardRange( - entry.shardId.id(), + if (currentEventIngestedMillisRange == null) { + currentEventIngestedMillisRange = indexMetadata.getEventIngestedRange(); + } + + final IndexLongFieldRange newTimestampMillisRange = currentTimestampMillisRange.extendWithShardRange( + startedShardEntry.shardId.id(), indexMetadata.getNumberOfShards(), - entry.timestampRange + startedShardEntry.timestampRange ); - if (newTimestampMillisRange != currentTimestampMillisRange) { - updatedTimestampRanges.put(index, newTimestampMillisRange); + /* + * Only track the 'event.ingested' range if the cluster state min transport version is on/after the version + * where we added 'event.ingested'. If we don't do that, we will have different cluster states on different + * nodes because we can't send this data over the wire to older nodes. 
+ */ + IndexLongFieldRange newEventIngestedMillisRange = IndexLongFieldRange.UNKNOWN; + TransportVersion minTransportVersion = batchExecutionContext.initialState().getMinTransportVersion(); + if (minTransportVersion.onOrAfter(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + newEventIngestedMillisRange = currentEventIngestedMillisRange.extendWithShardRange( + startedShardEntry.shardId.id(), + indexMetadata.getNumberOfShards(), + startedShardEntry.eventIngestedRange + ); + } + + if (newTimestampMillisRange != currentTimestampMillisRange + || newEventIngestedMillisRange != currentEventIngestedMillisRange) { + updatedTimestampRanges.put( + index, + new ClusterStateTimeRanges(newTimestampMillisRange, newEventIngestedMillisRange) + ); } } } @@ -690,10 +748,12 @@ public ClusterState execute(BatchExecutionContext batchE if (updatedTimestampRanges.isEmpty() == false) { final Metadata.Builder metadataBuilder = Metadata.builder(maybeUpdatedState.metadata()); - for (Map.Entry updatedTimestampRangeEntry : updatedTimestampRanges.entrySet()) { + for (Map.Entry updatedTimeRangesEntry : updatedTimestampRanges.entrySet()) { + ClusterStateTimeRanges timeRanges = updatedTimeRangesEntry.getValue(); metadataBuilder.put( - IndexMetadata.builder(metadataBuilder.getSafe(updatedTimestampRangeEntry.getKey())) - .timestampRange(updatedTimestampRangeEntry.getValue()) + IndexMetadata.builder(metadataBuilder.getSafe(updatedTimeRangesEntry.getKey())) + .timestampRange(timeRanges.timestampRange()) + .eventIngestedRange(timeRanges.eventIngestedRange(), maybeUpdatedState.getMinTransportVersion()) ); } maybeUpdatedState = ClusterState.builder(maybeUpdatedState).metadata(metadataBuilder).build(); @@ -725,6 +785,15 @@ private static boolean assertStartedIndicesHaveCompleteTimestampRanges(ClusterSt + clusterState.metadata().index(cursor.getKey()).getTimestampRange() + " for " + cursor.getValue().prettyPrint(); + + assert cursor.getValue().allPrimaryShardsActive() == false + || clusterState.metadata().index(cursor.getKey()).getEventIngestedRange().isComplete() + : "index [" + + cursor.getKey() + + "] should have complete event.ingested range, but got " + + clusterState.metadata().index(cursor.getKey()).getEventIngestedRange() + + " for " + + cursor.getValue().prettyPrint(); } return true; } @@ -748,6 +817,7 @@ public static class StartedShardEntry extends TransportRequest { final long primaryTerm; final String message; final ShardLongFieldRange timestampRange; + final ShardLongFieldRange eventIngestedRange; StartedShardEntry(StreamInput in) throws IOException { super(in); @@ -756,6 +826,11 @@ public static class StartedShardEntry extends TransportRequest { primaryTerm = in.readVLong(); this.message = in.readString(); this.timestampRange = ShardLongFieldRange.readFrom(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + this.eventIngestedRange = ShardLongFieldRange.readFrom(in); + } else { + this.eventIngestedRange = ShardLongFieldRange.UNKNOWN; + } } public StartedShardEntry( @@ -763,13 +838,15 @@ public StartedShardEntry( final String allocationId, final long primaryTerm, final String message, - final ShardLongFieldRange timestampRange + final ShardLongFieldRange timestampRange, + final ShardLongFieldRange eventIngestedRange ) { this.shardId = shardId; this.allocationId = allocationId; this.primaryTerm = primaryTerm; this.message = message; this.timestampRange = timestampRange; + this.eventIngestedRange = eventIngestedRange; } @Override @@ -780,6 +857,9 @@ 
public void writeTo(StreamOutput out) throws IOException { out.writeVLong(primaryTerm); out.writeString(message); timestampRange.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + eventIngestedRange.writeTo(out); + } } @Override @@ -802,12 +882,13 @@ public boolean equals(Object o) { && shardId.equals(that.shardId) && allocationId.equals(that.allocationId) && message.equals(that.message) - && timestampRange.equals(that.timestampRange); + && timestampRange.equals(that.timestampRange) + && eventIngestedRange.equals(that.eventIngestedRange); } @Override public int hashCode() { - return Objects.hash(shardId, allocationId, primaryTerm, message, timestampRange); + return Objects.hash(shardId, allocationId, primaryTerm, message, timestampRange, eventIngestedRange); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java index 7bfb3b9c8ae76..391ae451c8cef 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java @@ -137,6 +137,9 @@ public class IndexMetadata implements Diffable, ToXContentFragmen EnumSet.of(ClusterBlockLevel.WRITE) ); + // 'event.ingested' (part of Elastic Common Schema) range is tracked in cluster state, along with @timestamp + public static final String EVENT_INGESTED_FIELD_NAME = "event.ingested"; + @Nullable public String getDownsamplingInterval() { return settings.get(IndexMetadata.INDEX_DOWNSAMPLE_INTERVAL_KEY); @@ -538,6 +541,7 @@ public Iterator> settings() { static final String KEY_MAPPINGS_UPDATED_VERSION = "mappings_updated_version"; static final String KEY_SYSTEM = "system"; static final String KEY_TIMESTAMP_RANGE = "timestamp_range"; + static final String KEY_EVENT_INGESTED_RANGE = "event_ingested_range"; public static final String KEY_PRIMARY_TERMS = "primary_terms"; public static final String KEY_STATS = "stats"; @@ -603,7 +607,10 @@ public Iterator> settings() { private final boolean isSystem; private final boolean isHidden; + // range for the @timestamp field for the Index private final IndexLongFieldRange timestampRange; + // range for the event.ingested field for the Index + private final IndexLongFieldRange eventIngestedRange; private final int priority; @@ -670,6 +677,7 @@ private IndexMetadata( final boolean isSystem, final boolean isHidden, final IndexLongFieldRange timestampRange, + final IndexLongFieldRange eventIngestedRange, final int priority, final long creationDate, final boolean ignoreDiskWatermarks, @@ -724,6 +732,7 @@ private IndexMetadata( assert isHidden == INDEX_HIDDEN_SETTING.get(settings); this.isHidden = isHidden; this.timestampRange = timestampRange; + this.eventIngestedRange = eventIngestedRange; this.priority = priority; this.creationDate = creationDate; this.ignoreDiskWatermarks = ignoreDiskWatermarks; @@ -780,6 +789,7 @@ IndexMetadata withMappingMetadata(MappingMetadata mapping) { this.isSystem, this.isHidden, this.timestampRange, + this.eventIngestedRange, this.priority, this.creationDate, this.ignoreDiskWatermarks, @@ -840,6 +850,7 @@ public IndexMetadata withInSyncAllocationIds(int shardId, Set inSyncSet) this.isSystem, this.isHidden, this.timestampRange, + this.eventIngestedRange, this.priority, this.creationDate, this.ignoreDiskWatermarks, @@ -898,6 +909,7 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { this.isSystem, 
this.isHidden, this.timestampRange, + this.eventIngestedRange, this.priority, this.creationDate, this.ignoreDiskWatermarks, @@ -919,13 +931,24 @@ public IndexMetadata withIncrementedPrimaryTerm(int shardId) { } /** - * @param timestampRange new timestamp range + * @param timestampRange new @timestamp range + * @param eventIngestedRange new 'event.ingested' range + * @param minClusterTransportVersion minimum transport version used between nodes of this cluster * @return copy of this instance with updated timestamp range */ - public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { - if (timestampRange.equals(this.timestampRange)) { + public IndexMetadata withTimestampRanges( + IndexLongFieldRange timestampRange, + IndexLongFieldRange eventIngestedRange, + TransportVersion minClusterTransportVersion + ) { + if (timestampRange.equals(this.timestampRange) && eventIngestedRange.equals(this.eventIngestedRange)) { return this; } + IndexLongFieldRange allowedEventIngestedRange = eventIngestedRange; + // remove this check when the EVENT_INGESTED_RANGE_IN_CLUSTER_STATE version is removed + if (minClusterTransportVersion.before(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + allowedEventIngestedRange = IndexLongFieldRange.UNKNOWN; + } return new IndexMetadata( this.index, this.version, @@ -956,6 +979,7 @@ public IndexMetadata withTimestampRange(IndexLongFieldRange timestampRange) { this.isSystem, this.isHidden, timestampRange, + allowedEventIngestedRange, this.priority, this.creationDate, this.ignoreDiskWatermarks, @@ -1010,6 +1034,7 @@ public IndexMetadata withIncrementedVersion() { this.isSystem, this.isHidden, this.timestampRange, + this.eventIngestedRange, this.priority, this.creationDate, this.ignoreDiskWatermarks, @@ -1360,6 +1385,10 @@ public IndexLongFieldRange getTimestampRange() { return timestampRange; } + public IndexLongFieldRange getEventIngestedRange() { + return eventIngestedRange; + } + /** * @return whether this index has a time series timestamp range */ @@ -1512,7 +1541,12 @@ private static class IndexMetadataDiff implements Diff { private final Diff> rolloverInfos; private final IndexVersion mappingsUpdatedVersion; private final boolean isSystem; + + // range for the @timestamp field for the Index private final IndexLongFieldRange timestampRange; + // range for the event.ingested field for the Index + private final IndexLongFieldRange eventIngestedRange; + private final IndexMetadataStats stats; private final Double indexWriteLoadForecast; private final Long shardSizeInBytesForecast; @@ -1551,6 +1585,7 @@ private static class IndexMetadataDiff implements Diff { mappingsUpdatedVersion = after.mappingsUpdatedVersion; isSystem = after.isSystem; timestampRange = after.timestampRange; + eventIngestedRange = after.eventIngestedRange; stats = after.stats; indexWriteLoadForecast = after.writeLoadForecast; shardSizeInBytesForecast = after.shardSizeInBytesForecast; @@ -1629,6 +1664,11 @@ private static class IndexMetadataDiff implements Diff { indexWriteLoadForecast = null; shardSizeInBytesForecast = null; } + if (in.getTransportVersion().onOrAfter(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + eventIngestedRange = IndexLongFieldRange.readFrom(in); + } else { + eventIngestedRange = IndexLongFieldRange.UNKNOWN; + } } @Override @@ -1670,6 +1710,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(indexWriteLoadForecast); out.writeOptionalLong(shardSizeInBytesForecast); } + if 
(out.getTransportVersion().onOrAfter(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + eventIngestedRange.writeTo(out); + } else { + assert eventIngestedRange == IndexLongFieldRange.UNKNOWN + : "eventIngestedRange should be UNKNOWN until all nodes are on the new version but is " + eventIngestedRange; + } } @Override @@ -1698,6 +1744,7 @@ public IndexMetadata apply(IndexMetadata part) { builder.rolloverInfos.putAllFromMap(rolloverInfos.apply(part.rolloverInfos)); builder.system(isSystem); builder.timestampRange(timestampRange); + builder.eventIngestedRange(eventIngestedRange); builder.stats(stats); builder.indexWriteLoadForecast(indexWriteLoadForecast); builder.shardSizeInBytesForecast(shardSizeInBytesForecast); @@ -1775,6 +1822,11 @@ public static IndexMetadata readFrom(StreamInput in, @Nullable Function mappingsMetadata = new HashMap<>(); DocumentMapper docMapper = documentMapperSupplier.get(); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java index 34f71d315f97a..be6d6f3ef1e53 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java @@ -886,6 +886,7 @@ static Tuple> closeRoutingTable( final IndexMetadata.Builder updatedMetadata = IndexMetadata.builder(indexMetadata).state(IndexMetadata.State.CLOSE); metadata.put( updatedMetadata.timestampRange(IndexLongFieldRange.NO_SHARDS) + .eventIngestedRange(IndexLongFieldRange.NO_SHARDS, currentState.getMinTransportVersion()) .settingsVersion(indexMetadata.getSettingsVersion() + 1) .settings(Settings.builder().put(indexMetadata.getSettings()).put(VERIFIED_BEFORE_CLOSE_SETTING.getKey(), true)) ); @@ -1132,6 +1133,7 @@ private ClusterState openIndices(final Index[] indices, final ClusterState curre .settingsVersion(indexMetadata.getSettingsVersion() + 1) .settings(updatedSettings) .timestampRange(IndexLongFieldRange.NO_SHARDS) + .eventIngestedRange(IndexLongFieldRange.NO_SHARDS, currentState.getMinTransportVersion()) .build(); // The index might be closed because we couldn't import it due to an old incompatible diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java index bce727f5790ff..a798ed58833b8 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java @@ -48,6 +48,7 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperService.MergeReason; import org.elasticsearch.index.mapper.RoutingFieldMapper; +import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.indices.IndexTemplateMissingException; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.indices.InvalidIndexTemplateException; @@ -1652,7 +1653,12 @@ private static void validateCompositeTemplate( final ClusterState stateWithIndex = ClusterState.builder(stateWithTemplate) .metadata( Metadata.builder(stateWithTemplate.metadata()) - .put(IndexMetadata.builder(temporaryIndexName).settings(finalResolvedSettings)) + .put( + IndexMetadata.builder(temporaryIndexName) + // necessary to pass asserts in ClusterState constructor + 
.eventIngestedRange(IndexLongFieldRange.UNKNOWN, state.getMinTransportVersion()) + .settings(finalResolvedSettings) + ) .build() ) .build(); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/IndexMetadataUpdater.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/IndexMetadataUpdater.java index e8231f8c09387..98885acd127e2 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/IndexMetadataUpdater.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/IndexMetadataUpdater.java @@ -9,6 +9,7 @@ package org.elasticsearch.cluster.routing.allocation; import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; @@ -104,9 +105,10 @@ public void relocationCompleted(ShardRouting removedRelocationSource) { * * @param oldMetadata {@link Metadata} object from before the routing nodes was changed. * @param newRoutingTable {@link RoutingTable} object after routing changes were applied. + * @param minClusterTransportVersion minimum TransportVersion used between nodes of this cluster * @return adapted {@link Metadata}, potentially the original one if no change was needed. */ - public Metadata applyChanges(Metadata oldMetadata, RoutingTable newRoutingTable) { + public Metadata applyChanges(Metadata oldMetadata, RoutingTable newRoutingTable, TransportVersion minClusterTransportVersion) { Map>> changesGroupedByIndex = shardChanges.entrySet() .stream() .collect(Collectors.groupingBy(e -> e.getKey().getIndex())); @@ -119,7 +121,14 @@ public Metadata applyChanges(Metadata oldMetadata, RoutingTable newRoutingTable) for (Map.Entry shardEntry : indexChanges.getValue()) { ShardId shardId = shardEntry.getKey(); Updates updates = shardEntry.getValue(); - updatedIndexMetadata = updateInSyncAllocations(newRoutingTable, oldIndexMetadata, updatedIndexMetadata, shardId, updates); + updatedIndexMetadata = updateInSyncAllocations( + newRoutingTable, + oldIndexMetadata, + updatedIndexMetadata, + shardId, + updates, + minClusterTransportVersion + ); updatedIndexMetadata = updates.increaseTerm ? 
updatedIndexMetadata.withIncrementedPrimaryTerm(shardId.id()) : updatedIndexMetadata; @@ -140,7 +149,8 @@ private static IndexMetadata updateInSyncAllocations( IndexMetadata oldIndexMetadata, IndexMetadata updatedIndexMetadata, ShardId shardId, - Updates updates + Updates updates, + TransportVersion minClusterTransportVersion ) { assert Sets.haveEmptyIntersection(updates.addedAllocationIds, updates.removedAllocationIds) : "allocation ids cannot be both added and removed in the same allocation round, added ids: " @@ -167,10 +177,13 @@ private static IndexMetadata updateInSyncAllocations( updatedIndexMetadata = updatedIndexMetadata.withInSyncAllocationIds(shardId.id(), Set.of()); } else { final String allocationId; + if (recoverySource == RecoverySource.ExistingStoreRecoverySource.FORCE_STALE_PRIMARY_INSTANCE) { allocationId = RecoverySource.ExistingStoreRecoverySource.FORCED_ALLOCATION_ID; - updatedIndexMetadata = updatedIndexMetadata.withTimestampRange( - updatedIndexMetadata.getTimestampRange().removeShard(shardId.id(), oldIndexMetadata.getNumberOfShards()) + updatedIndexMetadata = updatedIndexMetadata.withTimestampRanges( + updatedIndexMetadata.getTimestampRange().removeShard(shardId.id(), oldIndexMetadata.getNumberOfShards()), + updatedIndexMetadata.getEventIngestedRange().removeShard(shardId.id(), oldIndexMetadata.getNumberOfShards()), + minClusterTransportVersion ); } else { assert recoverySource instanceof RecoverySource.SnapshotRecoverySource diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java index 382e49135ea8d..af5f8cd7bd8c6 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java @@ -339,7 +339,7 @@ public RoutingChangesObserver changes() { * Returns updated {@link Metadata} based on the changes that were made to the routing nodes */ public Metadata updateMetadataWithRoutingChanges(RoutingTable newRoutingTable) { - Metadata metadata = indexMetadataUpdater.applyChanges(metadata(), newRoutingTable); + Metadata metadata = indexMetadataUpdater.applyChanges(metadata(), newRoutingTable, clusterState.getMinTransportVersion()); return resizeSourceIndexUpdater.applyChanges(metadata, newRoutingTable); } diff --git a/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java b/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java index 55c421b87196d..ae8f8cb28da11 100644 --- a/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java +++ b/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java @@ -44,6 +44,7 @@ import static java.time.temporal.ChronoField.MONTH_OF_YEAR; import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.time.temporal.ChronoField.SECOND_OF_MINUTE; +import static org.elasticsearch.common.util.ArrayUtils.prepend; public class DateFormatters { @@ -202,7 +203,11 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p new JavaTimeDateTimePrinter(STRICT_DATE_OPTIONAL_TIME_PRINTER), JAVA_TIME_PARSERS_ONLY ? 
new DateTimeParser[] { javaTimeParser } - : new DateTimeParser[] { new Iso8601DateTimeParser(Set.of(), false).withLocale(Locale.ROOT), javaTimeParser } + : new DateTimeParser[] { + new Iso8601DateTimeParser(Set.of(), false, null, DecimalSeparator.BOTH, TimezonePresence.OPTIONAL).withLocale( + Locale.ROOT + ), + javaTimeParser } ); } @@ -266,7 +271,13 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p JAVA_TIME_PARSERS_ONLY ? new DateTimeParser[] { javaTimeParser } : new DateTimeParser[] { - new Iso8601DateTimeParser(Set.of(HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), true).withLocale(Locale.ROOT), + new Iso8601DateTimeParser( + Set.of(HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), + true, + null, + DecimalSeparator.BOTH, + TimezonePresence.OPTIONAL + ).withLocale(Locale.ROOT), javaTimeParser } ); } @@ -316,7 +327,11 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p new JavaTimeDateTimePrinter(STRICT_DATE_OPTIONAL_TIME_PRINTER), JAVA_TIME_PARSERS_ONLY ? new DateTimeParser[] { javaTimeParser } - : new DateTimeParser[] { new Iso8601DateTimeParser(Set.of(), false).withLocale(Locale.ROOT), javaTimeParser } + : new DateTimeParser[] { + new Iso8601DateTimeParser(Set.of(), false, null, DecimalSeparator.BOTH, TimezonePresence.OPTIONAL).withLocale( + Locale.ROOT + ), + javaTimeParser } ); } @@ -739,24 +754,53 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p /* * A strict formatter that formats or parses a year and a month, such as '2011-12'. */ - private static final DateFormatter STRICT_YEAR_MONTH = newDateFormatter( - "strict_year_month", - new DateTimeFormatterBuilder().appendValue(ChronoField.YEAR, 4, 4, SignStyle.EXCEEDS_PAD) + private static final DateFormatter STRICT_YEAR_MONTH; + static { + DateTimeFormatter javaTimeFormatter = new DateTimeFormatterBuilder().appendValue(ChronoField.YEAR, 4, 4, SignStyle.EXCEEDS_PAD) .appendLiteral("-") .appendValue(MONTH_OF_YEAR, 2, 2, SignStyle.NOT_NEGATIVE) .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT) - ); + .withResolverStyle(ResolverStyle.STRICT); + DateTimeParser javaTimeParser = new JavaTimeDateTimeParser(javaTimeFormatter); + + STRICT_YEAR_MONTH = new JavaDateFormatter( + "strict_year_month", + new JavaTimeDateTimePrinter(javaTimeFormatter), + JAVA_TIME_PARSERS_ONLY + ? new DateTimeParser[] { javaTimeParser } + : new DateTimeParser[] { + new Iso8601DateTimeParser( + Set.of(MONTH_OF_YEAR), + false, + MONTH_OF_YEAR, + DecimalSeparator.BOTH, + TimezonePresence.FORBIDDEN + ).withLocale(Locale.ROOT), + javaTimeParser } + ); + } /* * A strict formatter that formats or parses a year, such as '2011'. */ - private static final DateFormatter STRICT_YEAR = newDateFormatter( - "strict_year", - new DateTimeFormatterBuilder().appendValue(ChronoField.YEAR, 4, 4, SignStyle.EXCEEDS_PAD) + private static final DateFormatter STRICT_YEAR; + static { + DateTimeFormatter javaTimeFormatter = new DateTimeFormatterBuilder().appendValue(ChronoField.YEAR, 4, 4, SignStyle.EXCEEDS_PAD) .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT) - ); + .withResolverStyle(ResolverStyle.STRICT); + DateTimeParser javaTimeParser = new JavaTimeDateTimeParser(javaTimeFormatter); + + STRICT_YEAR = new JavaDateFormatter( + "strict_year", + new JavaTimeDateTimePrinter(javaTimeFormatter), + JAVA_TIME_PARSERS_ONLY + ? 
new DateTimeParser[] { javaTimeParser } + : new DateTimeParser[] { + new Iso8601DateTimeParser(Set.of(), false, ChronoField.YEAR, DecimalSeparator.BOTH, TimezonePresence.FORBIDDEN) + .withLocale(Locale.ROOT), + javaTimeParser } + ); + } /* * A strict formatter that formats or parses an hour, minute and second, such as '09:43:25'. @@ -787,18 +831,39 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p * Returns a formatter that combines a full date and time, separated by a 'T' * (uuuu-MM-dd'T'HH:mm:ss.SSSZZ). */ - private static final DateFormatter STRICT_DATE_TIME = newDateFormatter( - "strict_date_time", - STRICT_DATE_PRINTER, - new DateTimeFormatterBuilder().append(STRICT_DATE_FORMATTER) - .appendZoneOrOffsetId() - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT), - new DateTimeFormatterBuilder().append(STRICT_DATE_FORMATTER) - .append(TIME_ZONE_FORMATTER_NO_COLON) - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT) - ); + private static final DateFormatter STRICT_DATE_TIME; + static { + DateTimeParser[] javaTimeParsers = new DateTimeParser[] { + new JavaTimeDateTimeParser( + new DateTimeFormatterBuilder().append(STRICT_DATE_FORMATTER) + .appendZoneOrOffsetId() + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ), + new JavaTimeDateTimeParser( + new DateTimeFormatterBuilder().append(STRICT_DATE_FORMATTER) + .append(TIME_ZONE_FORMATTER_NO_COLON) + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ) }; + + STRICT_DATE_TIME = new JavaDateFormatter( + "strict_date_time", + new JavaTimeDateTimePrinter(STRICT_DATE_PRINTER), + JAVA_TIME_PARSERS_ONLY + ? javaTimeParsers + : prepend( + new Iso8601DateTimeParser( + Set.of(MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), + false, + null, + DecimalSeparator.DOT, + TimezonePresence.MANDATORY + ).withLocale(Locale.ROOT), + javaTimeParsers + ) + ); + } private static final DateTimeFormatter STRICT_ORDINAL_DATE_TIME_NO_MILLIS_BASE = new DateTimeFormatterBuilder().appendValue( ChronoField.YEAR, @@ -841,21 +906,44 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p * Returns a formatter that combines a full date and time without millis, * separated by a 'T' (uuuu-MM-dd'T'HH:mm:ssZZ). 
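* e.g. {@code 2024-05-21T10:15:30Z}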
*/ - private static final DateFormatter STRICT_DATE_TIME_NO_MILLIS = newDateFormatter( - "strict_date_time_no_millis", - new DateTimeFormatterBuilder().append(STRICT_DATE_TIME_NO_MILLIS_FORMATTER) - .appendOffset("+HH:MM", "Z") - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT), - new DateTimeFormatterBuilder().append(STRICT_DATE_TIME_NO_MILLIS_FORMATTER) - .appendZoneOrOffsetId() - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT), - new DateTimeFormatterBuilder().append(STRICT_DATE_TIME_NO_MILLIS_FORMATTER) - .append(TIME_ZONE_FORMATTER_NO_COLON) - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT) - ); + private static final DateFormatter STRICT_DATE_TIME_NO_MILLIS; + static { + DateTimeParser[] javaTimeParsers = new DateTimeParser[] { + new JavaTimeDateTimeParser( + new DateTimeFormatterBuilder().append(STRICT_DATE_TIME_NO_MILLIS_FORMATTER) + .appendZoneOrOffsetId() + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ), + new JavaTimeDateTimeParser( + new DateTimeFormatterBuilder().append(STRICT_DATE_TIME_NO_MILLIS_FORMATTER) + .append(TIME_ZONE_FORMATTER_NO_COLON) + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ) }; + + STRICT_DATE_TIME_NO_MILLIS = new JavaDateFormatter( + "strict_date_time_no_millis", + new JavaTimeDateTimePrinter( + new DateTimeFormatterBuilder().append(STRICT_DATE_TIME_NO_MILLIS_FORMATTER) + .appendOffset("+HH:MM", "Z") + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ), + JAVA_TIME_PARSERS_ONLY + ? javaTimeParsers + : prepend( + new Iso8601DateTimeParser( + Set.of(MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), + false, + SECOND_OF_MINUTE, + DecimalSeparator.BOTH, + TimezonePresence.MANDATORY + ).withLocale(Locale.ROOT), + javaTimeParsers + ) + ); + } // NOTE: this is not a strict formatter to retain the joda time based behaviour, even though it's named like this private static final DateTimeFormatter STRICT_HOUR_MINUTE_SECOND_MILLIS_FORMATTER = new DateTimeFormatterBuilder().append( @@ -891,37 +979,75 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p * two digit minute of hour, two digit second of minute, and three digit * fraction of second (uuuu-MM-dd'T'HH:mm:ss.SSS). 
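* e.g. {@code 2024-05-21T10:15:30.123}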
*/ - private static final DateFormatter STRICT_DATE_HOUR_MINUTE_SECOND_FRACTION = newDateFormatter( - "strict_date_hour_minute_second_fraction", - new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) - .appendLiteral("T") - .append(STRICT_HOUR_MINUTE_SECOND_MILLIS_PRINTER) - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT), - new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) - .appendLiteral("T") - .append(STRICT_HOUR_MINUTE_SECOND_FORMATTER) - // this one here is lenient as well to retain joda time based bwc compatibility - .appendFraction(NANO_OF_SECOND, 1, 9, true) - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT) - ); + private static final DateFormatter STRICT_DATE_HOUR_MINUTE_SECOND_FRACTION; + static { + DateTimeParser javaTimeParser = new JavaTimeDateTimeParser( + new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) + .appendLiteral("T") + .append(STRICT_HOUR_MINUTE_SECOND_FORMATTER) + // this one here is lenient as well to retain joda time based bwc compatibility + .appendFraction(NANO_OF_SECOND, 1, 9, true) + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ); - private static final DateFormatter STRICT_DATE_HOUR_MINUTE_SECOND_MILLIS = newDateFormatter( - "strict_date_hour_minute_second_millis", - new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) - .appendLiteral("T") - .append(STRICT_HOUR_MINUTE_SECOND_MILLIS_PRINTER) - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT), - new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) - .appendLiteral("T") - .append(STRICT_HOUR_MINUTE_SECOND_FORMATTER) - // this one here is lenient as well to retain joda time based bwc compatibility - .appendFraction(NANO_OF_SECOND, 1, 9, true) - .toFormatter(Locale.ROOT) - .withResolverStyle(ResolverStyle.STRICT) - ); + STRICT_DATE_HOUR_MINUTE_SECOND_FRACTION = new JavaDateFormatter( + "strict_date_hour_minute_second_fraction", + new JavaTimeDateTimePrinter( + new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) + .appendLiteral("T") + .append(STRICT_HOUR_MINUTE_SECOND_MILLIS_PRINTER) + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ), + JAVA_TIME_PARSERS_ONLY + ? new DateTimeParser[] { javaTimeParser } + : new DateTimeParser[] { + new Iso8601DateTimeParser( + Set.of(MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE, NANO_OF_SECOND), + false, + null, + DecimalSeparator.DOT, + TimezonePresence.FORBIDDEN + ).withLocale(Locale.ROOT), + javaTimeParser } + ); + } + + private static final DateFormatter STRICT_DATE_HOUR_MINUTE_SECOND_MILLIS; + static { + DateTimeParser javaTimeParser = new JavaTimeDateTimeParser( + new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) + .appendLiteral("T") + .append(STRICT_HOUR_MINUTE_SECOND_FORMATTER) + // this one here is lenient as well to retain joda time based bwc compatibility + .appendFraction(NANO_OF_SECOND, 1, 9, true) + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ); + + STRICT_DATE_HOUR_MINUTE_SECOND_MILLIS = new JavaDateFormatter( + "strict_date_hour_minute_second_millis", + new JavaTimeDateTimePrinter( + new DateTimeFormatterBuilder().append(STRICT_YEAR_MONTH_DAY_FORMATTER) + .appendLiteral("T") + .append(STRICT_HOUR_MINUTE_SECOND_MILLIS_PRINTER) + .toFormatter(Locale.ROOT) + .withResolverStyle(ResolverStyle.STRICT) + ), + JAVA_TIME_PARSERS_ONLY + ? 
new DateTimeParser[] { javaTimeParser } + : new DateTimeParser[] { + new Iso8601DateTimeParser( + Set.of(MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE, NANO_OF_SECOND), + false, + null, + DecimalSeparator.DOT, + TimezonePresence.FORBIDDEN + ).withLocale(Locale.ROOT), + javaTimeParser } + ); + } /* * Returns a formatter for a two digit hour of day. (HH) @@ -1235,10 +1361,27 @@ private static DateFormatter newDateFormatter(String format, DateTimeFormatter p * two digit minute of hour, and two digit second of * minute. (uuuu-MM-dd'T'HH:mm:ss) */ - private static final DateFormatter STRICT_DATE_HOUR_MINUTE_SECOND = newDateFormatter( - "strict_date_hour_minute_second", - DateTimeFormatter.ofPattern("uuuu-MM-dd'T'HH:mm:ss", Locale.ROOT) - ); + private static final DateFormatter STRICT_DATE_HOUR_MINUTE_SECOND; + static { + DateTimeFormatter javaTimeFormatter = DateTimeFormatter.ofPattern("uuuu-MM-dd'T'HH:mm:ss", Locale.ROOT); + DateTimeParser javaTimeParser = new JavaTimeDateTimeParser(javaTimeFormatter); + + STRICT_DATE_HOUR_MINUTE_SECOND = new JavaDateFormatter( + "strict_date_hour_minute_second", + new JavaTimeDateTimePrinter(javaTimeFormatter), + JAVA_TIME_PARSERS_ONLY + ? new DateTimeParser[] { javaTimeParser } + : new DateTimeParser[] { + new Iso8601DateTimeParser( + Set.of(MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), + false, + SECOND_OF_MINUTE, + DecimalSeparator.BOTH, + TimezonePresence.FORBIDDEN + ).withLocale(Locale.ROOT), + javaTimeParser } + ); + } /* * A basic formatter for a full date as four digit year, two digit diff --git a/server/src/main/java/org/elasticsearch/common/time/DecimalSeparator.java b/server/src/main/java/org/elasticsearch/common/time/DecimalSeparator.java new file mode 100644 index 0000000000000..3598599e1f759 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/time/DecimalSeparator.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.time; + +enum DecimalSeparator { + DOT, + COMMA, + BOTH +} diff --git a/server/src/main/java/org/elasticsearch/common/time/Iso8601DateTimeParser.java b/server/src/main/java/org/elasticsearch/common/time/Iso8601DateTimeParser.java index cce4b13f4a166..027c1ec94a411 100644 --- a/server/src/main/java/org/elasticsearch/common/time/Iso8601DateTimeParser.java +++ b/server/src/main/java/org/elasticsearch/common/time/Iso8601DateTimeParser.java @@ -24,8 +24,14 @@ class Iso8601DateTimeParser implements DateTimeParser { // and we already account for . 
or , in decimals private final Locale locale; - Iso8601DateTimeParser(Set mandatoryFields, boolean optionalTime) { - parser = new Iso8601Parser(mandatoryFields, optionalTime, Map.of()); + Iso8601DateTimeParser( + Set mandatoryFields, + boolean optionalTime, + ChronoField maxAllowedField, + DecimalSeparator decimalSeparator, + TimezonePresence timezonePresence + ) { + parser = new Iso8601Parser(mandatoryFields, optionalTime, maxAllowedField, decimalSeparator, timezonePresence, Map.of()); timezone = null; locale = null; } @@ -57,7 +63,18 @@ public DateTimeParser withLocale(Locale locale) { } Iso8601DateTimeParser withDefaults(Map defaults) { - return new Iso8601DateTimeParser(new Iso8601Parser(parser.mandatoryFields(), parser.optionalTime(), defaults), timezone, locale); + return new Iso8601DateTimeParser( + new Iso8601Parser( + parser.mandatoryFields(), + parser.optionalTime(), + parser.maxAllowedField(), + parser.decimalSeparator(), + parser.timezonePresence(), + defaults + ), + timezone, + locale + ); } @Override diff --git a/server/src/main/java/org/elasticsearch/common/time/Iso8601Parser.java b/server/src/main/java/org/elasticsearch/common/time/Iso8601Parser.java index fe92ff62b6ddc..6e420df9c72dd 100644 --- a/server/src/main/java/org/elasticsearch/common/time/Iso8601Parser.java +++ b/server/src/main/java/org/elasticsearch/common/time/Iso8601Parser.java @@ -13,16 +13,18 @@ import java.time.DateTimeException; import java.time.ZoneId; import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoField; import java.util.EnumMap; import java.util.EnumSet; import java.util.Map; +import java.util.Objects; import java.util.Set; /** * Parses datetimes in ISO8601 format (and subsequences thereof). *
<p>
- * This is faster than the generic parsing in {@link java.time.format.DateTimeFormatter}, as this is hard-coded and specific to ISO-8601. + * This is faster than the generic parsing in {@link DateTimeFormatter}, as this is hard-coded and specific to ISO-8601. * Various public libraries provide their own variant of this mechanism. We use our own for a few reasons: * <ul> *     <li>
@@ -37,13 +39,14 @@ */ class Iso8601Parser { - private static final Set<ChronoField> VALID_MANDATORY_FIELDS = EnumSet.of( + private static final Set<ChronoField> VALID_SPECIFIED_FIELDS = EnumSet.of( ChronoField.YEAR, ChronoField.MONTH_OF_YEAR, ChronoField.DAY_OF_MONTH, ChronoField.HOUR_OF_DAY, ChronoField.MINUTE_OF_HOUR, - ChronoField.SECOND_OF_MINUTE + ChronoField.SECOND_OF_MINUTE, + ChronoField.NANO_OF_SECOND ); private static final Set<ChronoField> VALID_DEFAULT_FIELDS = EnumSet.of( @@ -57,31 +60,51 @@ class Iso8601Parser { private final Set<ChronoField> mandatoryFields; private final boolean optionalTime; + @Nullable + private final ChronoField maxAllowedField; + private final DecimalSeparator decimalSeparator; + private final TimezonePresence timezonePresence; private final Map<ChronoField, Integer> defaults; /** * Constructs a new {@code Iso8601Parser} object * - * @param mandatoryFields - * The set of fields that must be present for a valid parse. These should be specified in field order - * (eg if {@link ChronoField#DAY_OF_MONTH} is specified, {@link ChronoField#MONTH_OF_YEAR} should also be specified). - * {@link ChronoField#YEAR} is always mandatory. - * @param optionalTime - * {@code false} if the presence of time fields follows {@code mandatoryFields}, - * {@code true} if a time component is always optional, despite the presence of time fields in {@code mandatoryFields}. - * This makes it possible to specify 'time is optional, but if it is present, it must have these fields' - * by settings {@code optionalTime = true} and putting time fields such as {@link ChronoField#HOUR_OF_DAY} - * and {@link ChronoField#MINUTE_OF_HOUR} in {@code mandatoryFields}. - * @param defaults - * Map of default field values, if they are not present in the parsed string. + * @param mandatoryFields The set of fields that must be present for a valid parse. These should be specified in field order + * (eg if {@link ChronoField#DAY_OF_MONTH} is specified, + * {@link ChronoField#MONTH_OF_YEAR} should also be specified). + * {@link ChronoField#YEAR} is always mandatory. + * @param optionalTime {@code false} if the presence of time fields follows {@code mandatoryFields}, + * {@code true} if a time component is always optional, + * despite the presence of time fields in {@code mandatoryFields}. + * This makes it possible to specify 'time is optional, but if it is present, it must have these fields' + * by setting {@code optionalTime = true} and putting time fields such as {@link ChronoField#HOUR_OF_DAY} + * and {@link ChronoField#MINUTE_OF_HOUR} in {@code mandatoryFields}. + * @param maxAllowedField The most-specific field allowed in the parsed string, + * or {@code null} if everything up to nanoseconds is allowed. + * @param decimalSeparator The decimal separator that is allowed. + * @param timezonePresence Specifies if the timezone is optional, mandatory, or forbidden. + * @param defaults Map of default field values, if they are not present in the parsed string. 
*/ - Iso8601Parser(Set mandatoryFields, boolean optionalTime, Map defaults) { - checkChronoFields(mandatoryFields, VALID_MANDATORY_FIELDS); + Iso8601Parser( + Set mandatoryFields, + boolean optionalTime, + @Nullable ChronoField maxAllowedField, + DecimalSeparator decimalSeparator, + TimezonePresence timezonePresence, + Map defaults + ) { + checkChronoFields(mandatoryFields, VALID_SPECIFIED_FIELDS); + if (maxAllowedField != null && VALID_SPECIFIED_FIELDS.contains(maxAllowedField) == false) { + throw new IllegalArgumentException("Invalid chrono field specified " + maxAllowedField); + } checkChronoFields(defaults.keySet(), VALID_DEFAULT_FIELDS); this.mandatoryFields = EnumSet.of(ChronoField.YEAR); // year is always mandatory this.mandatoryFields.addAll(mandatoryFields); this.optionalTime = optionalTime; + this.maxAllowedField = maxAllowedField; + this.decimalSeparator = Objects.requireNonNull(decimalSeparator); + this.timezonePresence = Objects.requireNonNull(timezonePresence); this.defaults = defaults.isEmpty() ? Map.of() : new EnumMap<>(defaults); } @@ -103,6 +126,18 @@ Set mandatoryFields() { return mandatoryFields; } + ChronoField maxAllowedField() { + return maxAllowedField; + } + + DecimalSeparator decimalSeparator() { + return decimalSeparator; + } + + TimezonePresence timezonePresence() { + return timezonePresence; + } + private boolean isOptional(ChronoField field) { return mandatoryFields.contains(field) == false; } @@ -186,7 +221,7 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { : ParseResult.error(4); } - if (str.charAt(4) != '-') return ParseResult.error(4); + if (str.charAt(4) != '-' || maxAllowedField == ChronoField.YEAR) return ParseResult.error(4); // MONTHS Integer months = parseInt(str, 5, 7); @@ -208,7 +243,7 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { : ParseResult.error(7); } - if (str.charAt(7) != '-') return ParseResult.error(7); + if (str.charAt(7) != '-' || maxAllowedField == ChronoField.MONTH_OF_YEAR) return ParseResult.error(7); // DAYS Integer days = parseInt(str, 8, 10); @@ -230,7 +265,7 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { : ParseResult.error(10); } - if (str.charAt(10) != 'T') return ParseResult.error(10); + if (str.charAt(10) != 'T' || maxAllowedField == ChronoField.DAY_OF_MONTH) return ParseResult.error(10); if (len == 11) { return isOptional(ChronoField.HOUR_OF_DAY) ? new ParseResult( @@ -252,7 +287,7 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { Integer hours = parseInt(str, 11, 13); if (hours == null || hours > 23) return ParseResult.error(11); if (len == 13) { - return isOptional(ChronoField.MINUTE_OF_HOUR) + return isOptional(ChronoField.MINUTE_OF_HOUR) && timezonePresence != TimezonePresence.MANDATORY ? new ParseResult( withZoneOffset( years, @@ -285,13 +320,13 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { : ParseResult.error(13); } - if (str.charAt(13) != ':') return ParseResult.error(13); + if (str.charAt(13) != ':' || maxAllowedField == ChronoField.HOUR_OF_DAY) return ParseResult.error(13); // MINUTES + timezone Integer minutes = parseInt(str, 14, 16); if (minutes == null || minutes > 59) return ParseResult.error(14); if (len == 16) { - return isOptional(ChronoField.SECOND_OF_MINUTE) + return isOptional(ChronoField.SECOND_OF_MINUTE) && timezonePresence != TimezonePresence.MANDATORY ? 
new ParseResult( withZoneOffset( years, @@ -324,15 +359,17 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { : ParseResult.error(16); } - if (str.charAt(16) != ':') return ParseResult.error(16); + if (str.charAt(16) != ':' || maxAllowedField == ChronoField.MINUTE_OF_HOUR) return ParseResult.error(16); // SECONDS + timezone Integer seconds = parseInt(str, 17, 19); if (seconds == null || seconds > 59) return ParseResult.error(17); if (len == 19) { - return new ParseResult( - withZoneOffset(years, months, days, hours, minutes, seconds, defaultZero(ChronoField.NANO_OF_SECOND), defaultTimezone) - ); + return isOptional(ChronoField.NANO_OF_SECOND) && timezonePresence != TimezonePresence.MANDATORY + ? new ParseResult( + withZoneOffset(years, months, days, hours, minutes, seconds, defaultZero(ChronoField.NANO_OF_SECOND), defaultTimezone) + ) + : ParseResult.error(19); } if (isZoneId(str, 19)) { ZoneId timezone = parseZoneId(str, 19); @@ -343,11 +380,9 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { : ParseResult.error(19); } - char decSeparator = str.charAt(19); - if (decSeparator != '.' && decSeparator != ',') return ParseResult.error(19); + if (checkDecimalSeparator(str.charAt(19)) == false || maxAllowedField == ChronoField.SECOND_OF_MINUTE) return ParseResult.error(19); // NANOS + timezone - // nanos are always optional // the last number could be millis or nanos, or any combination in the middle // so we keep parsing numbers until we get to not a number int nanos = 0; @@ -364,7 +399,9 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { nanos *= NANO_MULTIPLICANDS[29 - pos]; if (len == pos) { - return new ParseResult(withZoneOffset(years, months, days, hours, minutes, seconds, nanos, defaultTimezone)); + return timezonePresence != TimezonePresence.MANDATORY + ? new ParseResult(withZoneOffset(years, months, days, hours, minutes, seconds, nanos, defaultTimezone)) + : ParseResult.error(pos); } if (isZoneId(str, pos)) { ZoneId timezone = parseZoneId(str, pos); @@ -377,6 +414,16 @@ private ParseResult parse(CharSequence str, @Nullable ZoneId defaultTimezone) { return ParseResult.error(pos); } + private boolean checkDecimalSeparator(char separator) { + boolean isDot = separator == '.'; + boolean isComma = separator == ','; + return switch (decimalSeparator) { + case DOT -> isDot; + case COMMA -> isComma; + case BOTH -> isDot || isComma; + }; + } + private static boolean isZoneId(CharSequence str, int pos) { // all region zoneIds must start with [A-Za-z] (see ZoneId#of) // this also covers Z and UT/UTC/GMT zone variants @@ -385,10 +432,14 @@ private static boolean isZoneId(CharSequence str, int pos) { } /** - * This parses the zone offset, which is of the format accepted by {@link java.time.ZoneId#of(String)}. + * This parses the zone offset, which is of the format accepted by {@link ZoneId#of(String)}. * It has fast paths for numerical offsets, but falls back on {@code ZoneId.of} for non-trivial zone ids. 
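The checkDecimalSeparator method above is a pure predicate over the configured mode; the same acceptance rules as a standalone sketch, with a hypothetical helper name:

    static boolean acceptsSeparator(DecimalSeparator mode, char c) {
        return switch (mode) {
            case DOT -> c == '.';
            case COMMA -> c == ',';
            case BOTH -> c == '.' || c == ',';
        };
    }
    // acceptsSeparator(DecimalSeparator.DOT, ',') == false, acceptsSeparator(DecimalSeparator.BOTH, ',') == true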
*/ private ZoneId parseZoneId(CharSequence str, int pos) { + if (timezonePresence == TimezonePresence.FORBIDDEN) { + return null; + } + int len = str.length(); char first = str.charAt(pos); diff --git a/server/src/main/java/org/elasticsearch/common/time/JavaDateFormatter.java b/server/src/main/java/org/elasticsearch/common/time/JavaDateFormatter.java index e8d729f9e9977..79b0c44d39108 100644 --- a/server/src/main/java/org/elasticsearch/common/time/JavaDateFormatter.java +++ b/server/src/main/java/org/elasticsearch/common/time/JavaDateFormatter.java @@ -18,7 +18,6 @@ import java.time.temporal.TemporalAccessor; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -149,19 +148,24 @@ static DateFormatter combined(String input, List formatters) { assert formatters.isEmpty() == false; DateTimePrinter printer = null; - List parsers = new ArrayList<>(formatters.size()); - List roundUpParsers = new ArrayList<>(formatters.size()); + List parsers = new ArrayList<>(formatters.size()); + List roundUpParsers = new ArrayList<>(formatters.size()); for (DateFormatter formatter : formatters) { JavaDateFormatter javaDateFormatter = (JavaDateFormatter) formatter; if (printer == null) { printer = javaDateFormatter.printer; } - Collections.addAll(parsers, javaDateFormatter.parsers); - Collections.addAll(roundUpParsers, javaDateFormatter.roundupParsers); + parsers.add(javaDateFormatter.parsers); + roundUpParsers.add(javaDateFormatter.roundupParsers); } - return new JavaDateFormatter(input, printer, roundUpParsers.toArray(DateTimeParser[]::new), parsers.toArray(DateTimeParser[]::new)); + return new JavaDateFormatter( + input, + printer, + roundUpParsers.stream().flatMap(Arrays::stream).toArray(DateTimeParser[]::new), + parsers.stream().flatMap(Arrays::stream).toArray(DateTimeParser[]::new) + ); } private JavaDateFormatter(String format, DateTimePrinter printer, DateTimeParser[] roundupParsers, DateTimeParser[] parsers) { diff --git a/server/src/main/java/org/elasticsearch/common/time/TimezonePresence.java b/server/src/main/java/org/elasticsearch/common/time/TimezonePresence.java new file mode 100644 index 0000000000000..fd8cdcc28976d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/common/time/TimezonePresence.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.common.time; + +enum TimezonePresence { + OPTIONAL, + MANDATORY, + FORBIDDEN +} diff --git a/server/src/main/java/org/elasticsearch/common/util/ArrayUtils.java b/server/src/main/java/org/elasticsearch/common/util/ArrayUtils.java index 2f1264fa88247..0b48a298fe59a 100644 --- a/server/src/main/java/org/elasticsearch/common/util/ArrayUtils.java +++ b/server/src/main/java/org/elasticsearch/common/util/ArrayUtils.java @@ -68,6 +68,21 @@ public static T[] concat(T[] one, T[] other) { return target; } + /** + * Copy the given element and array into a new array of size {@code array.length + 1}. 
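The new ArrayUtils.prepend shown below mirrors the existing append; a minimal usage sketch with illustrative values:

    String[] tail = { "b", "c" };
    String[] all = ArrayUtils.prepend("a", tail); // ["a", "b", "c"], tail itself is unchanged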
+ * @param added first element in the newly created array + * @param array array to copy to the end of new returned array copy + * @return copy that contains added element and array + * @param type of the array elements + */ + public static T[] prepend(T added, T[] array) { + @SuppressWarnings("unchecked") + T[] updated = (T[]) Array.newInstance(array.getClass().getComponentType(), array.length + 1); + updated[0] = added; + System.arraycopy(array, 0, updated, 1, array.length); + return updated; + } + /** * Copy the given array and the added element into a new array of size {@code array.length + 1}. * @param array array to copy to the beginning of new returned array copy @@ -76,9 +91,7 @@ public static T[] concat(T[] one, T[] other) { * @param type of the array elements */ public static T[] append(T[] array, T added) { - @SuppressWarnings("unchecked") - final T[] updated = (T[]) Array.newInstance(added.getClass(), array.length + 1); - System.arraycopy(array, 0, updated, 0, array.length); + T[] updated = Arrays.copyOf(array, array.length + 1); updated[array.length] = added; return updated; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java index 690b580d0c322..861f5ecd56f5a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java @@ -54,11 +54,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException return new ES813FlatVectorReader(format.fieldsReader(state)); } - public static class ES813FlatVectorWriter extends KnnVectorsWriter { + static class ES813FlatVectorWriter extends KnnVectorsWriter { private final FlatVectorsWriter writer; - public ES813FlatVectorWriter(FlatVectorsWriter writer) { + ES813FlatVectorWriter(FlatVectorsWriter writer) { super(); this.writer = writer; } @@ -94,11 +94,11 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE } } - public static class ES813FlatVectorReader extends KnnVectorsReader { + static class ES813FlatVectorReader extends KnnVectorsReader { private final FlatVectorsReader reader; - public ES813FlatVectorReader(FlatVectorsReader reader) { + ES813FlatVectorReader(FlatVectorsReader reader) { super(); this.reader = reader; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java new file mode 100644 index 0000000000000..86bc58c5862ee --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +public class ES815BitFlatVectorFormat extends KnnVectorsFormat { + + static final String NAME = "ES815BitFlatVectorFormat"; + + private final FlatVectorsFormat format = new ES815BitFlatVectorsFormat(); + + /** + * Sole constructor + */ + public ES815BitFlatVectorFormat() { + super(NAME); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES813FlatVectorFormat.ES813FlatVectorWriter(format.fieldsWriter(state)); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES813FlatVectorFormat.ES813FlatVectorReader(format.fieldsReader(state)); + } + + @Override + public String toString() { + return NAME; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java new file mode 100644 index 0000000000000..659cc89bfe46d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; + +import java.io.IOException; + +class ES815BitFlatVectorsFormat extends FlatVectorsFormat { + + private final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE); + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState segmentWriteState) throws IOException { + return delegate.fieldsWriter(segmentWriteState); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState segmentReadState) throws IOException { + return delegate.fieldsReader(segmentReadState); + } + + static class FlatBitVectorScorer implements FlatVectorsScorer { + + static final FlatBitVectorScorer INSTANCE = new FlatBitVectorScorer(); + + static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } + + @Override + public String toString() { + return super.toString(); + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues + ) throws IOException { + assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes; + assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN; + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) { + assert randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues == false; + return switch (vectorSimilarityFunction) { + case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingScorerSupplier(randomAccessVectorValuesBytes); + }; + } + throw new IllegalArgumentException("Unsupported vector type or similarity function"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessVectorValues randomAccessVectorValues, + byte[] bytes + ) { + assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes; + assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN; + if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) { + checkDimensions(bytes.length, randomAccessVectorValuesBytes.dimension()); + return switch (vectorSimilarityFunction) { + case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingVectorScorer( + randomAccessVectorValuesBytes, + bytes + ); + }; + } + throw new IllegalArgumentException("Unsupported vector type or similarity function"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction vectorSimilarityFunction, + 
RandomAccessVectorValues randomAccessVectorValues, + float[] floats + ) { + throw new IllegalArgumentException("Unsupported vector type"); + } + } + + static float hammingScore(byte[] a, byte[] b) { + return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); + } + + static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final byte[] query; + private final RandomAccessVectorValues.Bytes byteValues; + + HammingVectorScorer(RandomAccessVectorValues.Bytes byteValues, byte[] query) { + super(byteValues); + this.query = query; + this.byteValues = byteValues; + } + + @Override + public float score(int i) throws IOException { + return hammingScore(byteValues.vectorValue(i), query); + } + } + + static class HammingScorerSupplier implements RandomVectorScorerSupplier { + private final RandomAccessVectorValues.Bytes byteValues, byteValues1, byteValues2; + + HammingScorerSupplier(RandomAccessVectorValues.Bytes byteValues) throws IOException { + this.byteValues = byteValues; + this.byteValues1 = byteValues.copy(); + this.byteValues2 = byteValues.copy(); + } + + @Override + public RandomVectorScorer scorer(int i) throws IOException { + byte[] query = byteValues1.vectorValue(i); + return new HammingVectorScorer(byteValues2, query); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new HammingScorerSupplier(byteValues); + } + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java new file mode 100644 index 0000000000000..f7884c0b73688 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
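The hammingScore helper above normalizes the count of identical bits into [0, 1]; the same arithmetic spelled out without Lucene's VectorUtil, as a sketch:

    static float hammingScoreSketch(byte[] a, byte[] b) {
        int totalBits = a.length * Byte.SIZE;
        int differingBits = 0;
        for (int i = 0; i < a.length; i++) {
            differingBits += Integer.bitCount((a[i] ^ b[i]) & 0xFF); // per-byte xor bit count
        }
        return (totalBits - differingBits) / (float) totalBits; // 1.0 = identical, 0.0 = all bits differ
    }

For a single byte 0b0000_1111 scored against 0b0000_0000, four of eight bits differ, giving a score of 0.5.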
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +public class ES815HnswBitVectorsFormat extends KnnVectorsFormat { + + static final String NAME = "ES815HnswBitVectorsFormat"; + + static final int MAXIMUM_MAX_CONN = 512; + static final int MAXIMUM_BEAM_WIDTH = 3200; + + private final int maxConn; + private final int beamWidth; + + private final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat(); + + public ES815HnswBitVectorsFormat() { + this(16, 100); + } + + public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public String toString() { + return "ES815HnswBitVectorsFormat(name=ES815HnswBitVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java index 72bd15c3c3daa..34c518a93404b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java @@ -1022,7 +1022,11 @@ private String originalName() { @Override protected SyntheticSourceMode syntheticSourceMode() { - return SyntheticSourceMode.NATIVE; + if (fieldType.stored() || hasDocValues) { + return SyntheticSourceMode.NATIVE; + } + + return SyntheticSourceMode.FALLBACK; } @Override @@ -1044,6 +1048,7 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader(String simpleName) "field [" + fullPath() + "] of type [" + typeName() + "] doesn't support synthetic source because it declares a normalizer" ); } + if (fieldType.stored()) { return new StringStoredFieldFieldLoader( fullPath(), @@ -1057,33 +1062,29 @@ protected void write(XContentBuilder b, Object value) throws IOException { } }; } - if (hasDocValues == false) { - throw new IllegalArgumentException( - "field [" - + fullPath() - + "] of type [" - + typeName() - + "] doesn't support synthetic source because it doesn't have doc values and isn't stored" - ); - } - return new SortedSetDocValuesSyntheticFieldLoader( - fullPath(), - 
simpleName, - fieldType().ignoreAbove == Defaults.IGNORE_ABOVE ? null : originalName(), - false - ) { - @Override - protected BytesRef convert(BytesRef value) { - return value; - } + if (hasDocValues) { + return new SortedSetDocValuesSyntheticFieldLoader( + fullPath(), + simpleName, + fieldType().ignoreAbove == Defaults.IGNORE_ABOVE ? null : originalName(), + false + ) { - @Override - protected BytesRef preserve(BytesRef value) { - // Preserve must make a deep copy because convert gets a shallow copy from the iterator - return BytesRef.deepCopyOf(value); - } - }; + @Override + protected BytesRef convert(BytesRef value) { + return value; + } + + @Override + protected BytesRef preserve(BytesRef value) { + // Preserve must make a deep copy because convert gets a shallow copy from the iterator + return BytesRef.deepCopyOf(value); + } + }; + } + + return super.syntheticFieldLoader(); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index ab5e731c1430a..f7d9b2b4cbd28 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -25,7 +25,8 @@ public Set getFeatures() { PassThroughObjectMapper.PASS_THROUGH_PRIORITY, RangeFieldMapper.NULL_VALUES_OFF_BY_ONE_FIX, SourceFieldMapper.SYNTHETIC_SOURCE_FALLBACK, - DenseVectorFieldMapper.INT4_QUANTIZATION + DenseVectorFieldMapper.INT4_QUANTIZATION, + DenseVectorFieldMapper.BIT_VECTORS ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java index 140a7849754a8..1e5143a58f20a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java @@ -1988,7 +1988,11 @@ public void doValidate(MappingLookup lookup) { @Override protected SyntheticSourceMode syntheticSourceMode() { - return SyntheticSourceMode.NATIVE; + if (hasDocValues) { + return SyntheticSourceMode.NATIVE; + } + + return SyntheticSourceMode.FALLBACK; } @Override @@ -1996,21 +2000,16 @@ public SourceLoader.SyntheticFieldLoader syntheticFieldLoader() { if (hasScript()) { return SourceLoader.SyntheticFieldLoader.NOTHING; } - if (hasDocValues == false) { - throw new IllegalArgumentException( - "field [" - + fullPath() - + "] of type [" - + typeName() - + "] doesn't support synthetic source because it doesn't have doc values" - ); - } if (copyTo.copyToFields().isEmpty() != true) { throw new IllegalArgumentException( "field [" + fullPath() + "] of type [" + typeName() + "] doesn't support synthetic source because it declares copy_to" ); } - return type.syntheticFieldLoader(fullPath(), leafName(), ignoreMalformed.value()); + if (hasDocValues) { + return type.syntheticFieldLoader(fullPath(), leafName(), ignoreMalformed.value()); + } + + return super.syntheticFieldLoader(); } // For testing only: diff --git a/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java index eb57e068bd89f..85407fe824275 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/flattened/FlattenedFieldMapper.java @@ -275,6 +275,10 @@ public String typeName() { return CONTENT_TYPE; 
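The keyword and number mapper changes above share one pattern: claim NATIVE synthetic source only when the field can actually be reloaded (stored field or doc values), and otherwise defer to the generic fallback instead of throwing. Condensed as a sketch of the keyword rule (the number rule drops the stored() clause):

    SyntheticSourceMode mode = (fieldType.stored() || hasDocValues)
        ? SyntheticSourceMode.NATIVE    // loader reads back from the stored field or doc values
        : SyntheticSourceMode.FALLBACK; // generic fallback path handled by the superclass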
} + public String rootName() { + return this.rootName; + } + public String key() { return key; } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 15cd10b4f67dc..3a50fe6f28a6a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -31,6 +31,7 @@ import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.common.ParsingException; @@ -41,6 +42,8 @@ import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; +import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -100,6 +103,7 @@ static boolean isNotUnitVector(float magnitude) { } public static final NodeFeature INT4_QUANTIZATION = new NodeFeature("mapper.vectors.int4_quantization"); + public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors"); public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0; public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION; @@ -109,6 +113,7 @@ static boolean isNotUnitVector(float magnitude) { public static final String CONTENT_TYPE = "dense_vector"; public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions + public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector public static final int MAGNITUDE_BYTES = 4; @@ -134,17 +139,28 @@ public static class Builder extends FieldMapper.Builder { throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]"); } int dims = XContentMapValues.nodeIntegerValue(o); - if (dims < 1 || dims > MAX_DIMS_COUNT) { + int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT; + int minDims = elementType.getValue() == ElementType.BIT ? 
Byte.SIZE : 1; + if (dims < minDims || dims > maxDims) { throw new MapperParsingException( "The number of dimensions for field [" + n - + "] should be in the range [1, " - + MAX_DIMS_COUNT + + "] should be in the range [" + + minDims + + ", " + + maxDims + "] but was [" + dims + "]" ); } + if (elementType.getValue() == ElementType.BIT) { + if (dims % Byte.SIZE != 0) { + throw new MapperParsingException( + "The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]" + ); + } + } return dims; }, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null) .setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current)); @@ -171,13 +187,27 @@ public Builder(String name, IndexVersion indexVersionCreated) { "similarity", false, m -> toType(m).fieldType().similarity, - (Supplier) () -> indexedByDefault && indexed.getValue() ? VectorSimilarity.COSINE : null, + (Supplier) () -> { + if (indexedByDefault && indexed.getValue()) { + return elementType.getValue() == ElementType.BIT ? VectorSimilarity.L2_NORM : VectorSimilarity.COSINE; + } + return null; + }, VectorSimilarity.class - ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null); + ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null).addValidator(vectorSim -> { + if (vectorSim == null) { + return; + } + if (elementType.getValue() == ElementType.BIT && vectorSim != VectorSimilarity.L2_NORM) { + throw new IllegalArgumentException( + "The [" + VectorSimilarity.L2_NORM + "] similarity is the only supported similarity for bit vectors" + ); + } + }); this.indexOptions = new Parameter<>( "index_options", true, - () -> defaultInt8Hnsw && elementType.getValue() != ElementType.BYTE && this.indexed.getValue() + () -> defaultInt8Hnsw && elementType.getValue() == ElementType.FLOAT && this.indexed.getValue() ? 
new Int8HnswIndexOptions( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, @@ -266,7 +296,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) { public enum ElementType { - BYTE(1) { + BYTE { @Override public String toString() { @@ -371,7 +401,7 @@ void checkVectorMagnitude( } @Override - public double computeDotProduct(VectorData vectorData) { + public double computeSquaredMagnitude(VectorData vectorData) { return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector()); } @@ -428,7 +458,7 @@ private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVec byte[] decodedVector = HexFormat.of().parseHex(context.parser().text()); fieldMapper.checkDimensionMatches(decodedVector.length, context); VectorData vectorData = VectorData.fromBytes(decodedVector); - double squaredMagnitude = computeDotProduct(vectorData); + double squaredMagnitude = computeSquaredMagnitude(vectorData); checkVectorMagnitude( fieldMapper.fieldType().similarity, errorByteElementsAppender(decodedVector), @@ -463,7 +493,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie @Override int getNumBytes(int dimensions) { - return dimensions * elementBytes; + return dimensions; } @Override @@ -494,7 +524,7 @@ int parseDimensionCount(DocumentParserContext context) throws IOException { } }, - FLOAT(4) { + FLOAT { @Override public String toString() { @@ -596,7 +626,7 @@ void checkVectorMagnitude( } @Override - public double computeDotProduct(VectorData vectorData) { + public double computeSquaredMagnitude(VectorData vectorData) { return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector()); } @@ -656,7 +686,7 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper @Override int getNumBytes(int dimensions) { - return dimensions * elementBytes; + return dimensions * Float.BYTES; } @Override @@ -665,13 +695,249 @@ ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) { ? 
ByteBuffer.wrap(new byte[numBytes]).order(ByteOrder.LITTLE_ENDIAN) : ByteBuffer.wrap(new byte[numBytes]); } - }; + }, - final int elementBytes; + BIT { - ElementType(int elementBytes) { - this.elementBytes = elementBytes; - } + @Override + public String toString() { + return "bit"; + } + + @Override + public void writeValue(ByteBuffer byteBuffer, float value) { + byteBuffer.put((byte) value); + } + + @Override + public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { + b.value(byteBuffer.get()); + } + + private KnnByteVectorField createKnnVectorField(String name, byte[] vector, VectorSimilarityFunction function) { + if (vector == null) { + throw new IllegalArgumentException("vector value must not be null"); + } + FieldType denseVectorFieldType = new FieldType(); + denseVectorFieldType.setVectorAttributes(vector.length, VectorEncoding.BYTE, function); + denseVectorFieldType.freeze(); + return new KnnByteVectorField(name, vector, denseVectorFieldType); + } + + @Override + IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext) { + return new VectorIndexFieldData.Builder( + denseVectorFieldType.name(), + CoreValuesSourceType.KEYWORD, + denseVectorFieldType.indexVersionCreated, + this, + denseVectorFieldType.dims, + denseVectorFieldType.indexed, + r -> r + ); + } + + @Override + public void checkVectorBounds(float[] vector) { + checkNanAndInfinite(vector); + + StringBuilder errorBuilder = null; + + for (int index = 0; index < vector.length; ++index) { + float value = vector[index]; + + if (value % 1.0f != 0.0f) { + errorBuilder = new StringBuilder( + "element_type [" + + this + + "] vectors only support non-decimal values but found decimal value [" + + value + + "] at dim [" + + index + + "];" + ); + break; + } + + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + errorBuilder = new StringBuilder( + "element_type [" + + this + + "] vectors only support integers between [" + + Byte.MIN_VALUE + + ", " + + Byte.MAX_VALUE + + "] but found [" + + value + + "] at dim [" + + index + + "];" + ); + break; + } + } + + if (errorBuilder != null) { + throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString()); + } + } + + @Override + void checkVectorMagnitude( + VectorSimilarity similarity, + Function appender, + float squaredMagnitude + ) {} + + @Override + public double computeSquaredMagnitude(VectorData vectorData) { + int count = 0; + int i = 0; + byte[] byteBits = vectorData.asByteVector(); + for (int upperBound = byteBits.length & -8; i < upperBound; i += 8) { + count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(byteBits, i)); + } + + while (i < byteBits.length) { + count += Integer.bitCount(byteBits[i] & 255); + ++i; + } + return count; + } + + private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + int index = 0; + byte[] vector = new byte[fieldMapper.fieldType().dims / Byte.SIZE]; + for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser() + .nextToken()) { + fieldMapper.checkDimensionExceeded(index, context); + ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); + final int value; + if (context.parser().numberType() != XContentParser.NumberType.INT) { + float floatValue = context.parser().floatValue(true); + if (floatValue % 1.0f != 0.0f) { + throw new IllegalArgumentException( + "element_type [" + + this + + "] 
vectors only support non-decimal values but found decimal value [" + + floatValue + + "] at dim [" + + index + + "];" + ); + } + value = (int) floatValue; + } else { + value = context.parser().intValue(true); + } + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new IllegalArgumentException( + "element_type [" + + this + + "] vectors only support integers between [" + + Byte.MIN_VALUE + + ", " + + Byte.MAX_VALUE + + "] but found [" + + value + + "] at dim [" + + index + + "];" + ); + } + if (index >= vector.length) { + throw new IllegalArgumentException( + "The number of dimensions for field [" + + fieldMapper.fieldType().name() + + "] should be [" + + fieldMapper.fieldType().dims + + "] but found [" + + (index + 1) * Byte.SIZE + + "]" + ); + } + vector[index++] = (byte) value; + } + fieldMapper.checkDimensionMatches(index * Byte.SIZE, context); + return VectorData.fromBytes(vector); + } + + private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + byte[] decodedVector = HexFormat.of().parseHex(context.parser().text()); + fieldMapper.checkDimensionMatches(decodedVector.length * Byte.SIZE, context); + return VectorData.fromBytes(decodedVector); + } + + @Override + VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + XContentParser.Token token = context.parser().currentToken(); + return switch (token) { + case START_ARRAY -> parseVectorArray(context, fieldMapper); + case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper); + default -> throw new ParsingException( + context.parser().getTokenLocation(), + format("Unsupported type [%s] for provided value [%s]", token, context.parser().text()) + ); + }; + } + + @Override + public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + VectorData vectorData = parseKnnVector(context, fieldMapper); + Field field = createKnnVectorField( + fieldMapper.fieldType().name(), + vectorData.asByteVector(), + fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this) + ); + context.doc().addWithKey(fieldMapper.fieldType().name(), field); + } + + @Override + int getNumBytes(int dimensions) { + assert dimensions % Byte.SIZE == 0; + return dimensions / Byte.SIZE; + } + + @Override + ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) { + return ByteBuffer.wrap(new byte[numBytes]); + } + + @Override + int parseDimensionCount(DocumentParserContext context) throws IOException { + XContentParser.Token currentToken = context.parser().currentToken(); + return switch (currentToken) { + case START_ARRAY -> { + int index = 0; + for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { + index++; + } + yield index * Byte.SIZE; + } + case VALUE_STRING -> { + byte[] decodedVector = HexFormat.of().parseHex(context.parser().text()); + yield decodedVector.length * Byte.SIZE; + } + default -> throw new ParsingException( + context.parser().getTokenLocation(), + format("Unsupported type [%s] for provided value [%s]", currentToken, context.parser().text()) + ); + }; + } + + @Override + public void checkDimensions(int dvDims, int qvDims) { + if (dvDims != qvDims * Byte.SIZE) { + throw new IllegalArgumentException( + "The query vector has a different number of dimensions [" + + qvDims * Byte.SIZE + + "] than the document vectors [" + + dvDims + + "]." 
+ ); + } + } + }; public abstract void writeValue(ByteBuffer byteBuffer, float value); @@ -695,6 +961,14 @@ abstract void checkVectorMagnitude( float squaredMagnitude ); + public void checkDimensions(int dvDims, int qvDims) { + if (dvDims != qvDims) { + throw new IllegalArgumentException( + "The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]." + ); + } + } + int parseDimensionCount(DocumentParserContext context) throws IOException { int index = 0; for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { @@ -775,7 +1049,7 @@ static Function errorByteElementsAppender(byte[] v return sb -> appendErrorElements(sb, vector); } - public abstract double computeDotProduct(VectorData vectorData); + public abstract double computeSquaredMagnitude(VectorData vectorData); public static ElementType fromString(String name) { return valueOf(name.trim().toUpperCase(Locale.ROOT)); @@ -786,7 +1060,9 @@ public static ElementType fromString(String name) { ElementType.BYTE.toString(), ElementType.BYTE, ElementType.FLOAT.toString(), - ElementType.FLOAT + ElementType.FLOAT, + ElementType.BIT.toString(), + ElementType.BIT ); public enum VectorSimilarity { @@ -795,6 +1071,7 @@ public enum VectorSimilarity { float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { case BYTE, FLOAT -> 1f / (1f + similarity * similarity); + case BIT -> (dim - similarity) / dim; }; } @@ -806,8 +1083,10 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi COSINE { @Override float score(float similarity, ElementType elementType, int dim) { + assert elementType != ElementType.BIT; return switch (elementType) { case BYTE, FLOAT -> (1 + similarity) / 2f; + default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -824,6 +1103,7 @@ float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15)); case FLOAT -> (1 + similarity) / 2f; + default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -837,6 +1117,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { case BYTE, FLOAT -> similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1; + default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -863,7 +1144,7 @@ abstract static class IndexOptions implements ToXContent { this.type = type; } - abstract KnnVectorsFormat getVectorsFormat(); + abstract KnnVectorsFormat getVectorsFormat(ElementType elementType); boolean supportsElementType(ElementType elementType) { return true; @@ -1002,7 +1283,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - KnnVectorsFormat getVectorsFormat() { + KnnVectorsFormat getVectorsFormat(ElementType elementType) { + assert elementType == ElementType.FLOAT; return new ES813Int8FlatVectorFormat(confidenceInterval, 7, false); } @@ -1021,7 +1303,7 @@ public int hashCode() { @Override boolean supportsElementType(ElementType elementType) { - return elementType != ElementType.BYTE; + return elementType == ElementType.FLOAT; } @Override @@ -1047,7 +1329,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - KnnVectorsFormat getVectorsFormat() { + KnnVectorsFormat getVectorsFormat(ElementType elementType) { + if (elementType.equals(ElementType.BIT)) { + return new ES815BitFlatVectorFormat(); + } return new ES813FlatVectorFormat(); } @@ -1083,7 +1368,8 @@ static class Int4HnswIndexOptions extends IndexOptions { } @Override - public KnnVectorsFormat getVectorsFormat() { + public KnnVectorsFormat getVectorsFormat(ElementType elementType) { + assert elementType == ElementType.FLOAT; return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 4, true); } @@ -1126,7 +1412,7 @@ public String toString() { @Override boolean supportsElementType(ElementType elementType) { - return elementType != ElementType.BYTE; + return elementType == ElementType.FLOAT; } @Override @@ -1153,7 +1439,8 @@ static class Int4FlatIndexOptions extends IndexOptions { } @Override - public KnnVectorsFormat getVectorsFormat() { + public KnnVectorsFormat getVectorsFormat(ElementType elementType) { + assert elementType == ElementType.FLOAT; return new ES813Int8FlatVectorFormat(confidenceInterval, 4, true); } @@ -1186,7 +1473,7 @@ public String toString() { @Override boolean supportsElementType(ElementType elementType) { - return elementType != ElementType.BYTE; + return elementType == ElementType.FLOAT; } @Override @@ -1216,7 +1503,8 @@ static class Int8HnswIndexOptions extends IndexOptions { } @Override - public KnnVectorsFormat getVectorsFormat() { + public KnnVectorsFormat getVectorsFormat(ElementType elementType) { + assert elementType == ElementType.FLOAT; return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 7, false); } @@ -1261,7 +1549,7 @@ public String toString() { @Override boolean supportsElementType(ElementType elementType) { - return elementType != ElementType.BYTE; + return elementType == ElementType.FLOAT; } @Override @@ -1291,7 +1579,10 @@ static class HnswIndexOptions extends IndexOptions { } @Override - public KnnVectorsFormat getVectorsFormat() { + public KnnVectorsFormat getVectorsFormat(ElementType elementType) { + if (elementType == ElementType.BIT) { + return new ES815HnswBitVectorsFormat(m, efConstruction); + } return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null); } @@ -1412,48 +1703,6 @@ public Query termQuery(Object value, SearchExecutionContext context) { throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support 
term queries"); } - public Query createKnnQuery( - byte[] queryVector, - int numCands, - Query filter, - Float similarityThreshold, - BitSetProducer parentFilter - ) { - if (isIndexed() == false) { - throw new IllegalArgumentException( - "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" - ); - } - - if (queryVector.length != dims) { - throw new IllegalArgumentException( - "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" - ); - } - - if (elementType != ElementType.BYTE) { - throw new IllegalArgumentException( - "only [" + ElementType.BYTE + "] elements are supported when querying field [" + name() + "]" - ); - } - - if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { - float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); - elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); - } - Query knnQuery = parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter) - : new ESKnnByteVectorQuery(name(), queryVector, numCands, filter); - if (similarityThreshold != null) { - knnQuery = new VectorSimilarityQuery( - knnQuery, - similarityThreshold, - similarity.score(similarityThreshold, elementType, dims) - ); - } - return knnQuery; - } - public Query createExactKnnQuery(VectorData queryVector) { if (isIndexed() == false) { throw new IllegalArgumentException( @@ -1463,15 +1712,17 @@ public Query createExactKnnQuery(VectorData queryVector) { return switch (elementType) { case BYTE -> createExactKnnByteQuery(queryVector.asByteVector()); case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector()); + case BIT -> createExactKnnBitQuery(queryVector.asByteVector()); }; } + private Query createExactKnnBitQuery(byte[] queryVector) { + elementType.checkDimensions(dims, queryVector.length); + return new DenseVectorQuery.Bytes(queryVector, name()); + } + private Query createExactKnnByteQuery(byte[] queryVector) { - if (queryVector.length != dims) { - throw new IllegalArgumentException( - "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" - ); - } + elementType.checkDimensions(dims, queryVector.length); if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); @@ -1480,11 +1731,7 @@ private Query createExactKnnByteQuery(byte[] queryVector) { } private Query createExactKnnFloatQuery(float[] queryVector) { - if (queryVector.length != dims) { - throw new IllegalArgumentException( - "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" - ); - } + elementType.checkDimensions(dims, queryVector.length); elementType.checkVectorBounds(queryVector); if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); @@ -1521,21 +1768,39 @@ public Query createKnnQuery( return switch (getElementType()) { case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter); case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, 
similarityThreshold, parentFilter); + case BIT -> createKnnBitQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter); }; } - private Query createKnnByteQuery( + private Query createKnnBitQuery( byte[] queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter ) { - if (queryVector.length != dims) { - throw new IllegalArgumentException( - "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" + elementType.checkDimensions(dims, queryVector.length); + Query knnQuery = parentFilter != null + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter) + : new ESKnnByteVectorQuery(name(), queryVector, numCands, filter); + if (similarityThreshold != null) { + knnQuery = new VectorSimilarityQuery( + knnQuery, + similarityThreshold, + similarity.score(similarityThreshold, elementType, dims) ); } + return knnQuery; + } + + private Query createKnnByteQuery( + byte[] queryVector, + int numCands, + Query filter, + Float similarityThreshold, + BitSetProducer parentFilter + ) { + elementType.checkDimensions(dims, queryVector.length); if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); @@ -1561,11 +1826,7 @@ private Query createKnnFloatQuery( Float similarityThreshold, BitSetProducer parentFilter ) { - if (queryVector.length != dims) { - throw new IllegalArgumentException( - "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" - ); - } + elementType.checkDimensions(dims, queryVector.length); elementType.checkVectorBounds(queryVector); if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); @@ -1701,7 +1962,7 @@ private void parseBinaryDocValuesVectorAndIndex(DocumentParserContext context) t vectorData.addToBuffer(byteBuffer); if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) { // encode vector magnitude at the end - double dotProduct = elementType.computeDotProduct(vectorData); + double dotProduct = elementType.computeSquaredMagnitude(vectorData); float vectorMagnitude = (float) Math.sqrt(dotProduct); byteBuffer.putFloat(vectorMagnitude); } @@ -1780,9 +2041,9 @@ private static IndexOptions parseIndexOptions(String fieldName, Object propNode) public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) { final KnnVectorsFormat format; if (indexOptions == null) { - format = defaultFormat; + format = fieldType().elementType == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat; } else { - format = indexOptions.getVectorsFormat(); + format = indexOptions.getVectorsFormat(fieldType().elementType); } // It's legal to reuse the same format name as this is the same on-disk format. 
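Net effect of the selection in getKnnVectorsFormatForField above, condensed as a sketch (not a verbatim excerpt from the patch):

    KnnVectorsFormat format = (indexOptions == null)
        ? (elementType == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat) // bit gets its own default
        : indexOptions.getVectorsFormat(elementType); // the int8/int4 quantized options only support FLOAT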
return new KnnVectorsFormat(format.getName()) { diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index d66b429e6dd95..f35ba3a0fd5b8 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -17,6 +17,8 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField; @@ -58,12 +60,14 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { return switch (elementType) { case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); + case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); }; } else { BinaryDocValues values = DocValues.getBinary(reader, field); return switch (elementType) { case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims); case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); + case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims); }; } } catch (IOException e) { diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java index 2c361fe025dfa..ca6dfa5ef6e51 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/ExponentialDecayFunctionBuilder.java @@ -76,10 +76,7 @@ public int hashCode() { @Override public boolean equals(Object obj) { - if (super.equals(obj)) { - return true; - } - return obj != null && getClass() != obj.getClass(); + return obj == this || (obj != null && obj.getClass() == this.getClass()); } } diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/GaussDecayFunctionBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/GaussDecayFunctionBuilder.java index 4415c87e9815e..1cc9335b5963e 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/GaussDecayFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/GaussDecayFunctionBuilder.java @@ -83,10 +83,7 @@ public int hashCode() { @Override public boolean equals(Object obj) { - if (super.equals(obj)) { - return true; - } - return obj != null && getClass() != obj.getClass(); + return obj == this || (obj != null && obj.getClass() == this.getClass()); } } } diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/LinearDecayFunctionBuilder.java 
b/server/src/main/java/org/elasticsearch/index/query/functionscore/LinearDecayFunctionBuilder.java index ff22e1d57f832..70c3c4458a217 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/LinearDecayFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/LinearDecayFunctionBuilder.java @@ -86,10 +86,7 @@ public int hashCode() { @Override public boolean equals(Object obj) { - if (super.equals(obj)) { - return true; - } - return obj != null && getClass() != obj.getClass(); + return obj == this || (obj != null && obj.getClass() == this.getClass()); } } } diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index b3f19b1b7a81d..881f4602be1c7 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -2230,10 +2230,19 @@ public RecoveryState recoveryState() { @Override public ShardLongFieldRange getTimestampRange() { + return determineShardLongFieldRange(DataStream.TIMESTAMP_FIELD_NAME); + } + + @Override + public ShardLongFieldRange getEventIngestedRange() { + return determineShardLongFieldRange(IndexMetadata.EVENT_INGESTED_FIELD_NAME); + } + + private ShardLongFieldRange determineShardLongFieldRange(String fieldName) { if (mapperService() == null) { return ShardLongFieldRange.UNKNOWN; // no mapper service, no idea if the field even exists } - final MappedFieldType mappedFieldType = mapperService().fieldType(DataStream.TIMESTAMP_FIELD_NAME); + final MappedFieldType mappedFieldType = mapperService().fieldType(fieldName); if (mappedFieldType instanceof DateFieldMapper.DateFieldType == false) { return ShardLongFieldRange.UNKNOWN; // field missing or not a date } @@ -2243,10 +2252,10 @@ public ShardLongFieldRange getTimestampRange() { final ShardLongFieldRange rawTimestampFieldRange; try { - rawTimestampFieldRange = getEngine().getRawFieldRange(DataStream.TIMESTAMP_FIELD_NAME); + rawTimestampFieldRange = getEngine().getRawFieldRange(fieldName); assert rawTimestampFieldRange != null; } catch (IOException | AlreadyClosedException e) { - logger.debug("exception obtaining range for timestamp field", e); + logger.debug("exception obtaining range for field " + fieldName, e); return ShardLongFieldRange.UNKNOWN; } if (rawTimestampFieldRange == ShardLongFieldRange.UNKNOWN) { @@ -3337,7 +3346,7 @@ private void executeRecovery( markAsRecovering(reason, recoveryState); // mark the shard as recovering on the cluster state thread threadPool.generic().execute(ActionRunnable.wrap(ActionListener.wrap(r -> { if (r) { - recoveryListener.onRecoveryDone(recoveryState, getTimestampRange()); + recoveryListener.onRecoveryDone(recoveryState, getTimestampRange(), getEventIngestedRange()); } }, e -> recoveryListener.onRecoveryFailure(new RecoveryFailedException(recoveryState, null, e), true)), action)); } diff --git a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java index d409c3009ef5b..dd5ad26c58b12 100644 --- a/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java +++ b/server/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java @@ -46,6 +46,7 @@ import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner; import org.elasticsearch.core.Nullable; import 
org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.env.ShardLockObtainFailedException; import org.elasticsearch.gateway.GatewayService; @@ -417,7 +418,10 @@ protected void doRun() throws Exception { // lock is released so it's guaranteed to be deleted by the time we get the lock indicesService.processPendingDeletes(index, indexSettings, timeout); } catch (ShardLockObtainFailedException exc) { - logger.warn("[{}] failed to lock all shards for index - timed out after [{}]]", index, timeout); + logger.warn( + Strings.format("[%s] failed to lock all shards for index - timed out after [%s]", index, timeout), + exc + ); } catch (InterruptedException e) { logger.warn("[{}] failed to lock all shards for index - interrupted", index); } @@ -905,6 +909,7 @@ private void updateShard(ShardRouting shardRouting, Shard shard, ClusterState cl + state + "], mark shard as started", shard.getTimestampRange(), + shard.getEventIngestedRange(), ActionListener.noop(), clusterState ); @@ -966,12 +971,17 @@ private RecoveryListener(final ShardRouting shardRouting, final long primaryTerm } @Override - public void onRecoveryDone(final RecoveryState state, ShardLongFieldRange timestampMillisFieldRange) { + public void onRecoveryDone( + final RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ) { shardStateAction.shardStarted( shardRouting, primaryTerm, "after " + state.getRecoverySource(), timestampMillisFieldRange, + eventIngestedMillisFieldRange, ActionListener.noop() ); } @@ -1123,6 +1133,13 @@ public interface Shard { @Nullable ShardLongFieldRange getTimestampRange(); + /** + * @return the range of the {@code event.ingested} field for this shard, or {@link ShardLongFieldRange#EMPTY} if this field is not + * found, or {@link ShardLongFieldRange#UNKNOWN} if its range is not fixed. + */ + @Nullable + ShardLongFieldRange getEventIngestedRange(); + /** * Updates the shard state based on an incoming cluster state: * - Updates and persists the new routing value.
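The logging fix in the IndicesClusterStateService hunk above is easy to miss: the old call formatted index and timeout but never passed exc at all, so the stack trace was lost. Pre-formatting the message leaves the two-argument warn(String, Throwable) overload free for the exception. A stand-alone sketch with plain Log4j, substituting String.format for the internal Strings.format:

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

class WarnWithCauseSketch {
    private static final Logger logger = LogManager.getLogger(WarnWithCauseSketch.class);

    static void onLockTimeout(String index, Object timeout, Exception exc) {
        // The original parameterized call filled both {} placeholders but never
        // supplied the exception; formatting the message eagerly lets the
        // throwable ride in the second slot, preserving the stack trace.
        logger.warn(String.format("[%s] failed to lock all shards for index - timed out after [%s]", index, timeout), exc);
    }

    public static void main(String[] args) {
        onLockTimeout("my-index", "30s", new IllegalStateException("shard lock held"));
    }
}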
diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java b/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java index 3447cc73a4288..ac618ac9308c4 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetService.java @@ -517,7 +517,11 @@ private static void logGlobalCheckpointWarning(Logger logger, long startingSeqNo } public interface RecoveryListener { - void onRecoveryDone(RecoveryState state, ShardLongFieldRange timestampMillisFieldRange); + void onRecoveryDone( + RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ); void onRecoveryFailure(RecoveryFailedException e, boolean sendShardFailure); } diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java b/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java index dda7203fa7b0e..3232099831d8b 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java @@ -323,7 +323,7 @@ public void markAsDone() { indexShard.postRecovery("peer recovery done", ActionListener.runBefore(new ActionListener<>() { @Override public void onResponse(Void unused) { - listener.onRecoveryDone(state(), indexShard.getTimestampRange()); + listener.onRecoveryDone(state(), indexShard.getTimestampRange(), indexShard.getEventIngestedRange()); } @Override diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchResponseMetrics.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchResponseMetrics.java index 00f1f5d5804d6..2e2be59689b65 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchResponseMetrics.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchResponseMetrics.java @@ -8,14 +8,37 @@ package org.elasticsearch.rest.action.search; +import org.elasticsearch.telemetry.metric.LongCounter; import org.elasticsearch.telemetry.metric.LongHistogram; import org.elasticsearch.telemetry.metric.MeterRegistry; +import java.util.Map; + public class SearchResponseMetrics { + public enum ResponseCountTotalStatus { + SUCCESS("success"), + PARTIAL_FAILURE("partial_failure"), + FAILURE("failure"); + + private final String displayName; + + ResponseCountTotalStatus(String displayName) { + this.displayName = displayName; + } + + public String getDisplayName() { + return displayName; + } + } + + public static final String RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME = "status"; + public static final String TOOK_DURATION_TOTAL_HISTOGRAM_NAME = "es.search_response.took_durations.histogram"; + public static final String RESPONSE_COUNT_TOTAL_COUNTER_NAME = "es.search_response.response_count.total"; private final LongHistogram tookDurationTotalMillisHistogram; + private final LongCounter responseCountTotalCounter; public SearchResponseMetrics(MeterRegistry meterRegistry) { this( @@ -23,16 +46,31 @@ public SearchResponseMetrics(MeterRegistry meterRegistry) { TOOK_DURATION_TOTAL_HISTOGRAM_NAME, "The SearchResponse.took durations in milliseconds, expressed as a histogram", "millis" + ), + meterRegistry.registerLongCounter( + RESPONSE_COUNT_TOTAL_COUNTER_NAME, + "The cumulative total of search responses with an attribute to describe " + + "success, partial failure, or failure, expressed as a
single total counter and individual " + + "attribute counters", + "count" ) ); } - private SearchResponseMetrics(LongHistogram tookDurationTotalMillisHistogram) { + private SearchResponseMetrics(LongHistogram tookDurationTotalMillisHistogram, LongCounter responseCountTotalCounter) { this.tookDurationTotalMillisHistogram = tookDurationTotalMillisHistogram; + this.responseCountTotalCounter = responseCountTotalCounter; } public long recordTookTime(long tookTime) { tookDurationTotalMillisHistogram.record(tookTime); return tookTime; } + + public void incrementResponseCount(ResponseCountTotalStatus responseCountTotalStatus) { + responseCountTotalCounter.incrementBy( + 1L, + Map.of(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME, responseCountTotalStatus.getDisplayName()) + ); + } } diff --git a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java index bccdd5782f277..ad7d74824a1d4 100644 --- a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java @@ -56,7 +56,7 @@ public static class ByteDenseVectorFunction extends DenseVectorFunction { */ public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List queryVector) { super(scoreScript, field); - DenseVector.checkDimensions(field.get().getDims(), queryVector.size()); + field.getElementType().checkDimensions(field.get().getDims(), queryVector.size()); this.queryVector = new byte[queryVector.size()]; float[] validateValues = new float[queryVector.size()]; int queryMagnitude = 0; @@ -168,7 +168,7 @@ public static final class L1Norm { public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); function = switch (field.getElementType()) { - case BYTE -> { + case BYTE, BIT -> { if (queryVector instanceof List) { yield new ByteL1Norm(scoreScript, field, (List) queryVector); } else if (queryVector instanceof String s) { @@ -219,8 +219,8 @@ public static final class Hamming { @SuppressWarnings("unchecked") public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); - if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) { - throw new IllegalArgumentException("hamming distance is only supported for byte vectors"); + if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) { + throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors"); } if (queryVector instanceof List) { function = new ByteHammingDistance(scoreScript, field, (List) queryVector); @@ -278,7 +278,7 @@ public static final class L2Norm { public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); function = switch (field.getElementType()) { - case BYTE -> { + case BYTE, BIT -> { if (queryVector instanceof List) { yield new ByteL2Norm(scoreScript, field, (List) queryVector); } else if (queryVector instanceof String s) { @@ -342,7 +342,7 @@ public static final class DotProduct { public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); function = switch 
(field.getElementType()) { - case BYTE -> { + case BYTE, BIT -> { if (queryVector instanceof List) { yield new ByteDotProduct(scoreScript, field, (List) queryVector); } else if (queryVector instanceof String s) { @@ -406,7 +406,7 @@ public static final class CosineSimilarity { public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); function = switch (field.getElementType()) { - case BYTE -> { + case BYTE, BIT -> { if (queryVector instanceof List) { yield new ByteCosineSimilarity(scoreScript, field, (List) queryVector); } else if (queryVector instanceof String s) { diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java new file mode 100644 index 0000000000000..10420543ad181 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.util.BytesRef; + +import java.util.List; + +public class BitBinaryDenseVector extends ByteBinaryDenseVector { + + public BitBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) { + super(vectorValue, docVector, dims); + } + + @Override + public void checkDimensions(int qvDims) { + if (qvDims != dims) { + throw new IllegalArgumentException( + "The query vector has a different number of dimensions [" + + qvDims * Byte.SIZE + + "] than the document vectors [" + + dims * Byte.SIZE + + "]." 
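// In BitBinaryDenseVector, begun above, `dims` holds the stored byte length
// (the mapping's bit count divided by eight), so the mismatch message scales
// both sides by Byte.SIZE to report dimensions in the bit units the user
// declared. A stand-alone restatement of that unit bookkeeping, with
// illustrative names:

class BitDimensionCheckSketch {
    static void checkDimensions(int storedBytes, int queryBytes) {
        if (queryBytes != storedBytes) {
            throw new IllegalArgumentException(
                "The query vector has a different number of dimensions ["
                    + queryBytes * Byte.SIZE
                    + "] than the document vectors ["
                    + storedBytes * Byte.SIZE
                    + "]."
            );
        }
    }

    public static void main(String[] args) {
        checkDimensions(8, 8); // a 64-bit query against a 64-bit field: fine
        checkDimensions(8, 4); // a 32-bit query: fails, and the message speaks in bits
    }
}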
+ ); + } + } + + @Override + public int l1Norm(byte[] queryVector) { + return hamming(queryVector); + } + + @Override + public double l1Norm(List queryVector) { + return hamming(queryVector); + } + + @Override + public double l2Norm(byte[] queryVector) { + return Math.sqrt(hamming(queryVector)); + } + + @Override + public double l2Norm(List queryVector) { + return Math.sqrt(hamming(queryVector)); + } + + @Override + public int dotProduct(byte[] queryVector) { + throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + } + + @Override + public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) { + throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors."); + } + + @Override + public double dotProduct(List queryVector) { + throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + } + + @Override + public double cosineSimilarity(byte[] queryVector, float qvMagnitude) { + throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors."); + } + + @Override + public double cosineSimilarity(List queryVector) { + throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors."); + } + + @Override + public double dotProduct(float[] queryVector) { + throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + } + + @Override + public int getDims() { + return dims * Byte.SIZE; + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..cb123c54dfecf --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVectorDocValuesField.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; + +public class BitBinaryDenseVectorDocValuesField extends ByteBinaryDenseVectorDocValuesField { + + public BitBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) { + super(input, name, elementType, dims / 8); + } + + @Override + protected DenseVector getVector() { + return new BitBinaryDenseVector(vectorValue, value, dims); + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java new file mode 100644 index 0000000000000..ce9d990c75851 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.script.field.vectors; + +import java.util.List; + +public class BitKnnDenseVector extends ByteKnnDenseVector { + + public BitKnnDenseVector(byte[] vector) { + super(vector); + } + + @Override + public void checkDimensions(int qvDims) { + if (qvDims != docVector.length) { + throw new IllegalArgumentException( + "The query vector has a different number of dimensions [" + + qvDims * Byte.SIZE + + "] than the document vectors [" + + docVector.length * Byte.SIZE + + "]." + ); + } + } + + @Override + public float getMagnitude() { + if (magnitudeCalculated == false) { + magnitude = DenseVector.getBitMagnitude(docVector, docVector.length); + magnitudeCalculated = true; + } + return magnitude; + } + + @Override + public int l1Norm(byte[] queryVector) { + return hamming(queryVector); + } + + @Override + public double l1Norm(List queryVector) { + return hamming(queryVector); + } + + @Override + public double l2Norm(byte[] queryVector) { + return Math.sqrt(hamming(queryVector)); + } + + @Override + public double l2Norm(List queryVector) { + return Math.sqrt(hamming(queryVector)); + } + + @Override + public int dotProduct(byte[] queryVector) { + throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + } + + @Override + public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) { + throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors."); + } + + @Override + public double dotProduct(List queryVector) { + throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + } + + @Override + public double cosineSimilarity(byte[] queryVector, float qvMagnitude) { + throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors."); + } + + @Override + public double cosineSimilarity(List queryVector) { + throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors."); + } + + @Override + public double dotProduct(float[] queryVector) { + throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + } + + @Override + public int getDims() { + return docVector.length * Byte.SIZE; + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVectorDocValuesField.java new file mode 100644 index 0000000000000..10421d992727e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVectorDocValuesField.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.ByteVectorValues; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; + +public class BitKnnDenseVectorDocValuesField extends ByteKnnDenseVectorDocValuesField { + + public BitKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) { + super(input, name, dims / 8, DenseVectorFieldMapper.ElementType.BIT); + } + + @Override + protected DenseVector getVector() { + return new BitKnnDenseVector(vector); + } + +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java index c009397452c8a..f2ff8fbccd2fb 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java @@ -21,7 +21,7 @@ public class ByteBinaryDenseVector implements DenseVector { private final BytesRef docVector; private final byte[] vectorValue; - private final int dims; + protected final int dims; private float[] floatDocVector; private boolean magnitudeDecoded; diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVectorDocValuesField.java index b767cd72c4341..c7ce8cd5e937f 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVectorDocValuesField.java @@ -17,11 +17,11 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField { - private final BinaryDocValues input; - private final int dims; - private final byte[] vectorValue; - private boolean decoded; - private BytesRef value; + protected final BinaryDocValues input; + protected final int dims; + protected final byte[] vectorValue; + protected boolean decoded; + protected BytesRef value; public ByteBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) { super(name, elementType); @@ -50,13 +50,17 @@ public boolean isEmpty() { return value == null; } + protected DenseVector getVector() { + return new ByteBinaryDenseVector(vectorValue, value, dims); + } + @Override public DenseVector get() { if (isEmpty()) { return DenseVector.EMPTY; } decodeVectorIfNecessary(); - return new ByteBinaryDenseVector(vectorValue, value, dims); + return getVector(); } @Override @@ -65,7 +69,7 @@ public DenseVector get(DenseVector defaultValue) { return defaultValue; } decodeVectorIfNecessary(); - return new ByteBinaryDenseVector(vectorValue, value, dims); + return getVector(); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java index a2a9ba1c1d750..a41e166d1d8f3 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVectorDocValuesField.java @@ -23,7 +23,11 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField protected final int dims; public ByteKnnDenseVectorDocValuesField(@Nullable 
ByteVectorValues input, String name, int dims) { - super(name, ElementType.BYTE); + this(input, name, dims, ElementType.BYTE); + } + + protected ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims, ElementType elementType) { + super(name, elementType); this.dims = dims; this.input = input; } @@ -57,13 +61,17 @@ public boolean isEmpty() { return vector == null; } + protected DenseVector getVector() { + return new ByteKnnDenseVector(vector); + } + @Override public DenseVector get() { if (isEmpty()) { return DenseVector.EMPTY; } - return new ByteKnnDenseVector(vector); + return getVector(); } @Override @@ -72,7 +80,7 @@ public DenseVector get(DenseVector defaultValue) { return defaultValue; } - return new ByteKnnDenseVector(vector); + return getVector(); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java index a768e8add6663..d93daecf695a8 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/DenseVector.java @@ -8,6 +8,7 @@ package org.elasticsearch.script.field.vectors; +import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.VectorUtil; import java.util.List; @@ -25,6 +26,10 @@ class of the argument and checks dimensionality. */ public interface DenseVector { + default void checkDimensions(int qvDims) { + checkDimensions(getDims(), qvDims); + } + float[] getVector(); float getMagnitude(); @@ -38,13 +43,13 @@ public interface DenseVector { @SuppressWarnings("unchecked") default double dotProduct(Object queryVector) { if (queryVector instanceof float[] floats) { - checkDimensions(getDims(), floats.length); + checkDimensions(floats.length); return dotProduct(floats); } else if (queryVector instanceof List list) { - checkDimensions(getDims(), list.size()); + checkDimensions(list.size()); return dotProduct((List) list); } else if (queryVector instanceof byte[] bytes) { - checkDimensions(getDims(), bytes.length); + checkDimensions(bytes.length); return dotProduct(bytes); } @@ -60,13 +65,13 @@ default double dotProduct(Object queryVector) { @SuppressWarnings("unchecked") default double l1Norm(Object queryVector) { if (queryVector instanceof float[] floats) { - checkDimensions(getDims(), floats.length); + checkDimensions(floats.length); return l1Norm(floats); } else if (queryVector instanceof List list) { - checkDimensions(getDims(), list.size()); + checkDimensions(list.size()); return l1Norm((List) list); } else if (queryVector instanceof byte[] bytes) { - checkDimensions(getDims(), bytes.length); + checkDimensions(bytes.length); return l1Norm(bytes); } @@ -80,11 +85,11 @@ default double l1Norm(Object queryVector) { @SuppressWarnings("unchecked") default int hamming(Object queryVector) { if (queryVector instanceof List list) { - checkDimensions(getDims(), list.size()); + checkDimensions(list.size()); return hamming((List) list); } if (queryVector instanceof byte[] bytes) { - checkDimensions(getDims(), bytes.length); + checkDimensions(bytes.length); return hamming(bytes); } @@ -100,13 +105,13 @@ default int hamming(Object queryVector) { @SuppressWarnings("unchecked") default double l2Norm(Object queryVector) { if (queryVector instanceof float[] floats) { - checkDimensions(getDims(), floats.length); + checkDimensions(floats.length); return l2Norm(floats); } else if (queryVector instanceof List list) { - 
checkDimensions(getDims(), list.size()); + checkDimensions(list.size()); return l2Norm((List) list); } else if (queryVector instanceof byte[] bytes) { - checkDimensions(getDims(), bytes.length); + checkDimensions(bytes.length); return l2Norm(bytes); } @@ -150,13 +155,13 @@ default double cosineSimilarity(float[] queryVector) { @SuppressWarnings("unchecked") default double cosineSimilarity(Object queryVector) { if (queryVector instanceof float[] floats) { - checkDimensions(getDims(), floats.length); + checkDimensions(floats.length); return cosineSimilarity(floats); } else if (queryVector instanceof List list) { - checkDimensions(getDims(), list.size()); + checkDimensions(list.size()); return cosineSimilarity((List) list); } else if (queryVector instanceof byte[] bytes) { - checkDimensions(getDims(), bytes.length); + checkDimensions(bytes.length); return cosineSimilarity(bytes); } @@ -184,6 +189,20 @@ static float getMagnitude(byte[] vector, int dims) { return (float) Math.sqrt(mag); } + static float getBitMagnitude(byte[] vector, int dims) { + int count = 0; + int i = 0; + for (int upperBound = dims & -8; i < upperBound; i += 8) { + count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(vector, i)); + } + + while (i < dims) { + count += Integer.bitCount(vector[i] & 255); + ++i; + } + return (float) Math.sqrt(count); + } + static float getMagnitude(float[] vector) { return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector)); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java index 795f51a729ed6..0cebf3d79d754 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java @@ -316,7 +316,10 @@ protected void doClose() {} protected void doPostCollection() throws IOException {} protected final InternalAggregations buildEmptySubAggregations() { - List aggs = new ArrayList<>(); + if (subAggregators.length == 0) { + return InternalAggregations.EMPTY; + } + List aggs = new ArrayList<>(subAggregators.length); for (Aggregator aggregator : subAggregators) { aggs.add(aggregator.buildEmptyAggregation()); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java index b65f6b01de348..07e72404eefe9 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java @@ -34,7 +34,6 @@ import java.util.Optional; import java.util.stream.Collectors; -import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableMap; import static org.elasticsearch.common.xcontent.XContentParserUtils.parseTypedKeysObject; @@ -71,7 +70,7 @@ public Iterator iterator() { * The list of {@link InternalAggregation}s. 
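The getBitMagnitude helper above is a popcount dressed up as a magnitude: for a bit vector the squared magnitude is just the number of set bits, counted eight bytes at a time through the native-order VarHandle with a scalar loop for the tail. The same popcount underlies the bit-vector norms earlier in the patch, where l1Norm is the hamming distance and l2Norm its square root. A self-contained sketch without the VarHandle fast path:

class BitVectorNormsSketch {
    // Squared magnitude of a bit vector is its popcount: each set bit contributes 1 * 1.
    static float bitMagnitude(byte[] vector) {
        int count = 0;
        for (byte b : vector) {
            count += Integer.bitCount(b & 0xFF);
        }
        return (float) Math.sqrt(count);
    }

    // Hamming distance counts the positions where the two bit vectors differ.
    static int hamming(byte[] a, byte[] b) {
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    public static void main(String[] args) {
        byte[] doc = { (byte) 0b1010_1010 };
        byte[] query = { 0b0101_0101 };
        System.out.println(bitMagnitude(doc));               // sqrt(4) = 2.0: four bits set
        System.out.println(hamming(doc, query));             // 8: every bit differs
        System.out.println(Math.sqrt(hamming(doc, query)));  // the l2Norm of the pair
    }
}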
*/ public List asList() { - return unmodifiableList(aggregations); + return aggregations; } /** @@ -263,7 +262,7 @@ public static InternalAggregations reduce(List aggregation } // handle special case when there is just one aggregation if (aggregationsList.size() == 1) { - final List internalAggregations = aggregationsList.iterator().next().asList(); + final List internalAggregations = aggregationsList.get(0).asList(); final List reduced = new ArrayList<>(internalAggregations.size()); for (InternalAggregation aggregation : internalAggregations) { if (aggregation.mustReduceOnSingleInternalAgg()) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java b/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java index 6dd691bbf5aaa..de19c26daff92 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java @@ -132,7 +132,7 @@ static Object resolvePropertyFromPath(List path, List aggs = new ArrayList<>(); - for (InternalAggregation agg : getAggregations().asList()) { + for (InternalAggregation agg : getAggregations()) { PipelineTree subTree = pipelineTree.subTree(agg.getName()); aggs.add(agg.reducePipelines(agg, reduceContext, subTree)); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalDateHistogram.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalDateHistogram.java index e75b2d2002b0f..e0de42cebcc7d 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalDateHistogram.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalDateHistogram.java @@ -327,7 +327,6 @@ public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) } private List reduceBuckets(final PriorityQueue> pq, AggregationReduceContext reduceContext) { - int consumeBucketCount = 0; List reducedBuckets = new ArrayList<>(); if (pq.size() > 0) { // list of buckets coming from different shards that have the same key @@ -340,13 +339,7 @@ private List reduceBuckets(final PriorityQueue= minDocCount || reduceContext.isFinalReduce() == false) { - if (consumeBucketCount++ >= REPORT_EMPTY_EVERY) { - reduceContext.consumeBucketsAndMaybeBreak(consumeBucketCount); - consumeBucketCount = 0; - } - reducedBuckets.add(reduced); - } + maybeAddBucket(reduceContext, reducedBuckets, reduced); currentBuckets.clear(); key = top.current().key; } @@ -364,19 +357,21 @@ private List reduceBuckets(final PriorityQueue= minDocCount || reduceContext.isFinalReduce() == false) { - reducedBuckets.add(reduced); - if (consumeBucketCount++ >= REPORT_EMPTY_EVERY) { - reduceContext.consumeBucketsAndMaybeBreak(consumeBucketCount); - consumeBucketCount = 0; - } - } + maybeAddBucket(reduceContext, reducedBuckets, reduced); } } - reduceContext.consumeBucketsAndMaybeBreak(consumeBucketCount); return reducedBuckets; } + private void maybeAddBucket(AggregationReduceContext reduceContext, List reducedBuckets, Bucket reduced) { + if (reduced.getDocCount() >= minDocCount || reduceContext.isFinalReduce() == false) { + reduceContext.consumeBucketsAndMaybeBreak(1); + reducedBuckets.add(reduced); + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(reduced)); + } + } + /** * Reduce a list of same-keyed buckets (from multiple shards) to a 
single bucket. This * requires all buckets to have the same key. diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalHistogram.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalHistogram.java index 7b264ccb022e5..098bd5ebc7b3d 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalHistogram.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/histogram/InternalHistogram.java @@ -291,7 +291,6 @@ public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) } private List reduceBuckets(PriorityQueue> pq, AggregationReduceContext reduceContext) { - int consumeBucketCount = 0; List reducedBuckets = new ArrayList<>(); if (pq.size() > 0) { // list of buckets coming from different shards that have the same key @@ -305,13 +304,7 @@ private List reduceBuckets(PriorityQueue> pq, // The key changes, reduce what we already buffered and reset the buffer for current buckets. // Using Double.compare instead of != to handle NaN correctly. final Bucket reduced = reduceBucket(currentBuckets, reduceContext); - if (reduced.getDocCount() >= minDocCount || reduceContext.isFinalReduce() == false) { - reducedBuckets.add(reduced); - if (consumeBucketCount++ >= REPORT_EMPTY_EVERY) { - reduceContext.consumeBucketsAndMaybeBreak(consumeBucketCount); - consumeBucketCount = 0; - } - } + maybeAddBucket(reduceContext, reducedBuckets, reduced); currentBuckets.clear(); key = top.current().key; } @@ -329,20 +322,21 @@ private List reduceBuckets(PriorityQueue> pq, if (currentBuckets.isEmpty() == false) { final Bucket reduced = reduceBucket(currentBuckets, reduceContext); - if (reduced.getDocCount() >= minDocCount || reduceContext.isFinalReduce() == false) { - reducedBuckets.add(reduced); - if (consumeBucketCount++ >= REPORT_EMPTY_EVERY) { - reduceContext.consumeBucketsAndMaybeBreak(consumeBucketCount); - consumeBucketCount = 0; - } - } + maybeAddBucket(reduceContext, reducedBuckets, reduced); } } - - reduceContext.consumeBucketsAndMaybeBreak(consumeBucketCount); return reducedBuckets; } + private void maybeAddBucket(AggregationReduceContext reduceContext, List reducedBuckets, Bucket reduced) { + if (reduced.getDocCount() >= minDocCount || reduceContext.isFinalReduce() == false) { + reduceContext.consumeBucketsAndMaybeBreak(1); + reducedBuckets.add(reduced); + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(reduced)); + } + } + private Bucket reduceBucket(List buckets, AggregationReduceContext context) { assert buckets.isEmpty() == false; try (BucketReducer reducer = new BucketReducer<>(buckets.get(0), context, buckets.size())) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java index 7e749b06442f6..38bcc912c29d4 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; import static java.util.Collections.singletonList; @@ -146,9 +147,11 @@ private State aggStateForResult(long owningBucketOrdinal) { return state; } + private static final List NULL_ITEM_LIST = singletonList(null); + @Override public 
InternalAggregation buildEmptyAggregation() { - return new InternalScriptedMetric(name, singletonList(null), reduceScript, metadata()); + return new InternalScriptedMetric(name, NULL_ITEM_LIST, reduceScript, metadata()); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/fetch/FetchContext.java b/server/src/main/java/org/elasticsearch/search/fetch/FetchContext.java index 65d49f771a045..85ce8a9fdc5d0 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/FetchContext.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/FetchContext.java @@ -189,7 +189,7 @@ public FetchDocValuesContext docValuesContext() { searchContext.getSearchExecutionContext(), Collections.singletonList(new FieldAndFormat(name, null)) ); - } else if (searchContext.docValuesContext().fields().stream().map(ff -> ff.field).anyMatch(name::equals) == false) { + } else if (searchContext.docValuesContext().fields().stream().map(ff -> ff.field).noneMatch(name::equals)) { dvContext.fields().add(new FieldAndFormat(name, null)); } } diff --git a/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchDocValuesContext.java b/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchDocValuesContext.java index 2ae7f6d07bbb9..6a1e071c48269 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchDocValuesContext.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/subphase/FetchDocValuesContext.java @@ -10,6 +10,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.query.SearchExecutionContext; +import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; @@ -24,7 +25,7 @@ */ public class FetchDocValuesContext { - private final Collection fields; + private final List fields; /** * Create a new FetchDocValuesContext using the provided input list. @@ -40,7 +41,7 @@ public FetchDocValuesContext(SearchExecutionContext searchExecutionContext, List fieldToFormats.put(fieldName, new FieldAndFormat(fieldName, field.format, field.includeUnmapped)); } } - this.fields = fieldToFormats.values(); + this.fields = new ArrayList<>(fieldToFormats.values()); int maxAllowedDocvalueFields = searchExecutionContext.getIndexSettings().getMaxDocvalueFields(); if (fields.size() > maxAllowedDocvalueFields) { throw new IllegalArgumentException( @@ -58,7 +59,7 @@ public FetchDocValuesContext(SearchExecutionContext searchExecutionContext, List /** * Returns the required docvalue fields. 
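The Collection-to-List switch in FetchDocValuesContext above is not cosmetic: the FetchContext hunk earlier now appends to the field list after construction, and a LinkedHashMap's values() view does not support add(), so the constructor copies it into an ArrayList while keeping the order-preserving dedup by field name. A small sketch of the idiom, with a hypothetical Field record standing in for FieldAndFormat:

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class DedupedFieldListSketch {
    record Field(String name, String format) {} // stand-in for FieldAndFormat

    static List<Field> dedupe(List<Field> requested) {
        // Dedup by name; LinkedHashMap keeps first-seen request order.
        Map<String, Field> byName = new LinkedHashMap<>();
        for (Field f : requested) {
            byName.put(f.name(), f);
        }
        // Copy into an ArrayList so callers may add entries later;
        // byName.values() itself would reject add().
        return new ArrayList<>(byName.values());
    }

    public static void main(String[] args) {
        List<Field> fields = dedupe(List.of(new Field("a", null), new Field("b", null), new Field("a", "epoch_millis")));
        fields.add(new Field("c", null)); // legal on the copy
        System.out.println(fields);
    }
}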
*/ - public Collection fields() { + public List fields() { return this.fields; } } diff --git a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java index 53cc5beaf2e40..8265ea5c2f584 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest; import org.elasticsearch.action.support.IndicesOptions; @@ -1331,7 +1332,11 @@ public ClusterState execute(ClusterState currentState) { ensureValidIndexName(currentState, snapshotIndexMetadata, renamedIndexName); shardLimitValidator.validateShardLimit(snapshotIndexMetadata.getSettings(), currentState); - final IndexMetadata.Builder indexMdBuilder = restoreToCreateNewIndex(snapshotIndexMetadata, renamedIndexName); + final IndexMetadata.Builder indexMdBuilder = restoreToCreateNewIndex( + snapshotIndexMetadata, + renamedIndexName, + currentState.getMinTransportVersion() + ); if (request.includeAliases() == false && snapshotIndexMetadata.getAliases().isEmpty() == false && isSystemIndex(snapshotIndexMetadata) == false) { @@ -1349,7 +1354,11 @@ && isSystemIndex(snapshotIndexMetadata) == false) { } else { // Index exists and it's closed - open it in metadata and start recovery validateExistingClosedIndex(currentIndexMetadata, snapshotIndexMetadata, renamedIndexName, partial); - final IndexMetadata.Builder indexMdBuilder = restoreOverClosedIndex(snapshotIndexMetadata, currentIndexMetadata); + final IndexMetadata.Builder indexMdBuilder = restoreOverClosedIndex( + snapshotIndexMetadata, + currentIndexMetadata, + currentState.getMinTransportVersion() + ); if (request.includeAliases() == false && isSystemIndex(snapshotIndexMetadata) == false) { // Remove all snapshot aliases @@ -1726,17 +1735,26 @@ private static IndexMetadata convertLegacyIndex( return convertedIndexMetadataBuilder.build(); } - private static IndexMetadata.Builder restoreToCreateNewIndex(IndexMetadata snapshotIndexMetadata, String renamedIndexName) { + private static IndexMetadata.Builder restoreToCreateNewIndex( + IndexMetadata snapshotIndexMetadata, + String renamedIndexName, + TransportVersion minClusterTransportVersion + ) { return IndexMetadata.builder(snapshotIndexMetadata) .state(IndexMetadata.State.OPEN) .index(renamedIndexName) .settings( Settings.builder().put(snapshotIndexMetadata.getSettings()).put(IndexMetadata.SETTING_INDEX_UUID, UUIDs.randomBase64UUID()) ) - .timestampRange(IndexLongFieldRange.NO_SHARDS); + .timestampRange(IndexLongFieldRange.NO_SHARDS) + .eventIngestedRange(IndexLongFieldRange.NO_SHARDS, minClusterTransportVersion); } - private static IndexMetadata.Builder restoreOverClosedIndex(IndexMetadata snapshotIndexMetadata, IndexMetadata currentIndexMetadata) { + private static IndexMetadata.Builder restoreOverClosedIndex( + IndexMetadata snapshotIndexMetadata, + IndexMetadata currentIndexMetadata, + TransportVersion minTransportVersion + ) { final IndexMetadata.Builder indexMdBuilder = IndexMetadata.builder(snapshotIndexMetadata) .state(IndexMetadata.State.OPEN) .version(Math.max(snapshotIndexMetadata.getVersion(), 1 + currentIndexMetadata.getVersion())) @@ -1745,6 +1763,7 
@@ private static IndexMetadata.Builder restoreOverClosedIndex(IndexMetadata snapsh .settingsVersion(Math.max(snapshotIndexMetadata.getSettingsVersion(), 1 + currentIndexMetadata.getSettingsVersion())) .aliasesVersion(Math.max(snapshotIndexMetadata.getAliasesVersion(), 1 + currentIndexMetadata.getAliasesVersion())) .timestampRange(IndexLongFieldRange.NO_SHARDS) + .eventIngestedRange(IndexLongFieldRange.NO_SHARDS, minTransportVersion) .index(currentIndexMetadata.getIndex().getName()) .settings( Settings.builder() diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index a9d00d1c441fa..da2a0c4b90f30 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -1,3 +1,5 @@ org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat +org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponseTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponseTests.java index 440ac3f286878..b71d16ee530f8 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponseTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponseTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.action.admin.cluster.reroute; +import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.Version; import org.elasticsearch.cluster.ClusterName; @@ -28,6 +29,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.test.AbstractChunkedSerializingTestCase; import org.elasticsearch.test.ESTestCase; @@ -190,6 +192,9 @@ public void testToXContentWithDeprecatedClusterState() { "system": false, "timestamp_range": { "shards": [] + }, + "event_ingested_range": { + "unknown":true } } }, @@ -271,6 +276,9 @@ public void testToXContentWithDeprecatedClusterStateAndMetadata() { "system" : false, "timestamp_range" : { "shards" : [ ] + }, + "event_ingested_range" : { + "unknown" : true } } }, @@ -354,6 +362,7 @@ private static ClusterState createClusterState() { .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) .build() ) + .eventIngestedRange(IndexLongFieldRange.UNKNOWN, TransportVersion.current()) .build(), false ) diff --git a/server/src/test/java/org/elasticsearch/cluster/ClusterStateTests.java b/server/src/test/java/org/elasticsearch/cluster/ClusterStateTests.java index 70bdfd6072a27..23f3395c6c49e 100644 --- a/server/src/test/java/org/elasticsearch/cluster/ClusterStateTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/ClusterStateTests.java @@ -45,6 +45,7 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; +import 
org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.indices.SystemIndices; @@ -314,6 +315,9 @@ public void testToXContent() throws IOException { "timestamp_range": { "shards": [] }, + "event_ingested_range": { + "unknown": true + }, "stats": { "write_load": { "loads": [-1.0], @@ -579,6 +583,9 @@ public void testToXContent_FlatSettingTrue_ReduceMappingFalse() throws IOExcepti "timestamp_range" : { "shards" : [ ] }, + "event_ingested_range" : { + "unknown" : true + }, "stats" : { "write_load" : { "loads" : [ @@ -854,6 +861,9 @@ public void testToXContent_FlatSettingFalse_ReduceMappingTrue() throws IOExcepti "timestamp_range" : { "shards" : [ ] }, + "event_ingested_range" : { + "unknown" : true + }, "stats" : { "write_load" : { "loads" : [ @@ -1024,6 +1034,9 @@ public void testToXContentSameTypeName() throws IOException { "system" : false, "timestamp_range" : { "shards" : [ ] + }, + "event_ingested_range" : { + "shards" : [ ] } } }, @@ -1095,6 +1108,7 @@ private ClusterState buildClusterState() throws IOException { .putRolloverInfo(new RolloverInfo("rolloveAlias", new ArrayList<>(), 1L)) .stats(new IndexMetadataStats(IndexWriteLoad.builder(1).build(), 120, 1)) .indexWriteLoadForecast(8.0) + .eventIngestedRange(IndexLongFieldRange.UNKNOWN, TransportVersions.V_8_0_0) .build(); return ClusterState.builder(ClusterName.DEFAULT) diff --git a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStartedClusterStateTaskExecutorTests.java b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStartedClusterStateTaskExecutorTests.java index ea2bc79542e4a..d5ea160427952 100644 --- a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStartedClusterStateTaskExecutorTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStartedClusterStateTaskExecutorTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.cluster.action.shard; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.cluster.ClusterState; @@ -69,7 +70,14 @@ public void testEmptyTaskListProducesSameClusterState() throws Exception { public void testNonExistentIndexMarkedAsSuccessful() throws Exception { final ClusterState clusterState = stateWithNoShard(); final StartedShardUpdateTask entry = new StartedShardUpdateTask( - new StartedShardEntry(new ShardId("test", "_na", 0), "aId", randomNonNegativeLong(), "test", ShardLongFieldRange.UNKNOWN), + new StartedShardEntry( + new ShardId("test", "_na", 0), + "aId", + randomNonNegativeLong(), + "test", + ShardLongFieldRange.UNKNOWN, + ShardLongFieldRange.UNKNOWN + ), createTestListener() ); @@ -91,6 +99,7 @@ public void testNonExistentShardsAreMarkedAsSuccessful() throws Exception { String.valueOf(i), 0L, "allocation id", + ShardLongFieldRange.UNKNOWN, ShardLongFieldRange.UNKNOWN ), createTestListener() @@ -105,6 +114,7 @@ public void testNonExistentShardsAreMarkedAsSuccessful() throws Exception { String.valueOf(i), 0L, "shard id", + ShardLongFieldRange.UNKNOWN, ShardLongFieldRange.UNKNOWN ), createTestListener() @@ -133,7 +143,14 @@ public void testNonInitializingShardAreMarkedAsSuccessful() throws Exception { } final long primaryTerm = indexMetadata.primaryTerm(shardId.id()); return new StartedShardUpdateTask( - new StartedShardEntry(shardId, allocationId, primaryTerm, "test", 
ShardLongFieldRange.UNKNOWN), + new StartedShardEntry( + shardId, + allocationId, + primaryTerm, + "test", + ShardLongFieldRange.UNKNOWN, + ShardLongFieldRange.UNKNOWN + ), createTestListener() ); }) @@ -153,7 +170,14 @@ public void testStartPrimary() throws Exception { final String primaryAllocationId = primaryShard.allocationId().getId(); final var task = new StartedShardUpdateTask( - new StartedShardEntry(shardId, primaryAllocationId, primaryTerm, "test", ShardLongFieldRange.UNKNOWN), + new StartedShardEntry( + shardId, + primaryAllocationId, + primaryTerm, + "test", + ShardLongFieldRange.UNKNOWN, + ShardLongFieldRange.UNKNOWN + ), createTestListener() ); @@ -180,7 +204,14 @@ public void testStartReplica() throws Exception { final ShardRouting replicaShard = clusterState.routingTable().shardRoutingTable(shardId).replicaShards().iterator().next(); final String replicaAllocationId = replicaShard.allocationId().getId(); final var task = new StartedShardUpdateTask( - new StartedShardEntry(shardId, replicaAllocationId, primaryTerm, "test", ShardLongFieldRange.UNKNOWN), + new StartedShardEntry( + shardId, + replicaAllocationId, + primaryTerm, + "test", + ShardLongFieldRange.UNKNOWN, + ShardLongFieldRange.UNKNOWN + ), createTestListener() ); @@ -208,7 +239,14 @@ public void testDuplicateStartsAreOkay() throws Exception { final List tasks = IntStream.range(0, randomIntBetween(2, 10)) .mapToObj( i -> new StartedShardUpdateTask( - new StartedShardEntry(shardId, allocationId, primaryTerm, "test", ShardLongFieldRange.UNKNOWN), + new StartedShardEntry( + shardId, + allocationId, + primaryTerm, + "test", + ShardLongFieldRange.UNKNOWN, + ShardLongFieldRange.UNKNOWN + ), createTestListener() ) ) @@ -249,6 +287,7 @@ public void testPrimaryTermsMismatchOnPrimary() throws Exception { primaryAllocationId, primaryTerm - 1, "primary terms does not match on primary", + ShardLongFieldRange.UNKNOWN, ShardLongFieldRange.UNKNOWN ), createTestListener() @@ -270,6 +309,7 @@ public void testPrimaryTermsMismatchOnPrimary() throws Exception { primaryAllocationId, primaryTerm, "primary terms match on primary", + ShardLongFieldRange.UNKNOWN, ShardLongFieldRange.UNKNOWN ), createTestListener() @@ -312,7 +352,14 @@ public void testPrimaryTermsMismatchOnReplica() throws Exception { .getId(); final StartedShardUpdateTask task = new StartedShardUpdateTask( - new StartedShardEntry(shardId, replicaAllocationId, replicaPrimaryTerm, "test on replica", ShardLongFieldRange.UNKNOWN), + new StartedShardEntry( + shardId, + replicaAllocationId, + replicaPrimaryTerm, + "test on replica", + ShardLongFieldRange.UNKNOWN, + ShardLongFieldRange.UNKNOWN + ), createTestListener() ); @@ -339,13 +386,18 @@ public void testExpandsTimestampRangeForPrimary() throws Exception { final String primaryAllocationId = primaryShard.allocationId().getId(); assertThat(indexMetadata.getTimestampRange(), sameInstance(IndexLongFieldRange.NO_SHARDS)); + assertThat(indexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.NO_SHARDS)); final ShardLongFieldRange shardTimestampRange = randomBoolean() ? ShardLongFieldRange.UNKNOWN : randomBoolean() ? ShardLongFieldRange.EMPTY : ShardLongFieldRange.of(1606407943000L, 1606407944000L); + final ShardLongFieldRange shardEventIngestedRange = randomBoolean() ? ShardLongFieldRange.UNKNOWN + : randomBoolean() ? 
ShardLongFieldRange.EMPTY + : ShardLongFieldRange.of(1606407943000L, 1606407944000L); + final var task = new StartedShardUpdateTask( - new StartedShardEntry(shardId, primaryAllocationId, primaryTerm, "test", shardTimestampRange), + new StartedShardEntry(shardId, primaryAllocationId, primaryTerm, "test", shardTimestampRange, shardEventIngestedRange), createTestListener() ); @@ -369,6 +421,21 @@ public void testExpandsTimestampRangeForPrimary() throws Exception { assertThat(timestampRange.getMin(), equalTo(shardTimestampRange.getMin())); assertThat(timestampRange.getMax(), equalTo(shardTimestampRange.getMax())); } + + final var eventIngestedRange = resultingState.metadata().index(indexName).getEventIngestedRange(); + if (clusterState.getMinTransportVersion().before(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + assertThat(eventIngestedRange, sameInstance(IndexLongFieldRange.UNKNOWN)); + } else { + if (shardEventIngestedRange == ShardLongFieldRange.UNKNOWN) { + assertThat(eventIngestedRange, sameInstance(IndexLongFieldRange.UNKNOWN)); + } else if (shardEventIngestedRange == ShardLongFieldRange.EMPTY) { + assertThat(eventIngestedRange, sameInstance(IndexLongFieldRange.EMPTY)); + } else { + assertTrue(eventIngestedRange.isComplete()); + assertThat(eventIngestedRange.getMin(), equalTo(shardEventIngestedRange.getMin())); + assertThat(eventIngestedRange.getMax(), equalTo(shardEventIngestedRange.getMax())); + } + } } public void testExpandsTimestampRangeForReplica() throws Exception { @@ -380,15 +447,20 @@ public void testExpandsTimestampRangeForReplica() throws Exception { final long primaryTerm = indexMetadata.primaryTerm(shardId.id()); assertThat(indexMetadata.getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(indexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); final ShardLongFieldRange shardTimestampRange = randomBoolean() ? ShardLongFieldRange.UNKNOWN : randomBoolean() ? ShardLongFieldRange.EMPTY : ShardLongFieldRange.of(1606407943000L, 1606407944000L); + final ShardLongFieldRange shardEventIngestedRange = randomBoolean() ? ShardLongFieldRange.UNKNOWN + : randomBoolean() ? 
ShardLongFieldRange.EMPTY + : ShardLongFieldRange.of(1606407888888L, 1606407999999L); + final ShardRouting replicaShard = clusterState.routingTable().shardRoutingTable(shardId).replicaShards().iterator().next(); final String replicaAllocationId = replicaShard.allocationId().getId(); final var task = new StartedShardUpdateTask( - new StartedShardEntry(shardId, replicaAllocationId, primaryTerm, "test", shardTimestampRange), + new StartedShardEntry(shardId, replicaAllocationId, primaryTerm, "test", shardTimestampRange, shardEventIngestedRange), createTestListener() ); final var resultingState = executeTasks(clusterState, List.of(task)); @@ -401,7 +473,9 @@ public void testExpandsTimestampRangeForReplica() throws Exception { is(ShardRoutingState.STARTED) ); - assertThat(resultingState.metadata().index(indexName).getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + final IndexMetadata latestIndexMetadata = resultingState.metadata().index(indexName); + assertThat(latestIndexMetadata.getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(latestIndexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); } private ClusterState executeTasks(final ClusterState state, final List tasks) throws Exception { diff --git a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java index 100f3bbcc7829..cada467ea3ad6 100644 --- a/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.util.SetOnce; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.replication.ClusterStateCreationUtils; import org.elasticsearch.cluster.ClusterState; @@ -61,6 +62,8 @@ import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.elasticsearch.test.ClusterServiceUtils.setState; +import static org.elasticsearch.test.TransportVersionUtils.getFirstVersion; +import static org.elasticsearch.test.TransportVersionUtils.getPreviousVersion; import static org.elasticsearch.test.TransportVersionUtils.randomCompatibleVersion; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; @@ -391,17 +394,24 @@ public void testDeduplicateRemoteShardStarted() throws InterruptedException { expectedRequests++; shardStateAction.clearRemoteShardRequestDeduplicator(); } - shardStateAction.shardStarted(startedShard, primaryTerm, "started", ShardLongFieldRange.EMPTY, new ActionListener<>() { - @Override - public void onResponse(Void aVoid) { - latch.countDown(); - } + shardStateAction.shardStarted( + startedShard, + primaryTerm, + "started", + ShardLongFieldRange.EMPTY, + ShardLongFieldRange.EMPTY, + new ActionListener<>() { + @Override + public void onResponse(Void aVoid) { + latch.countDown(); + } - @Override - public void onFailure(Exception e) { - latch.countDown(); + @Override + public void onFailure(Exception e) { + latch.countDown(); + } } - }); + ); } CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear(); assertThat(capturedRequests, arrayWithSize(expectedRequests)); @@ -482,7 +492,14 @@ public void 
testShardStarted() throws InterruptedException { final ShardRouting shardRouting = getRandomShardRouting(index); final long primaryTerm = clusterService.state().metadata().index(shardRouting.index()).primaryTerm(shardRouting.id()); final TestListener listener = new TestListener(); - shardStateAction.shardStarted(shardRouting, primaryTerm, "testShardStarted", ShardLongFieldRange.UNKNOWN, listener); + shardStateAction.shardStarted( + shardRouting, + primaryTerm, + "testShardStarted", + ShardLongFieldRange.UNKNOWN, + ShardLongFieldRange.UNKNOWN, + listener + ); final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear(); assertThat(capturedRequests[0].request(), instanceOf(ShardStateAction.StartedShardEntry.class)); @@ -578,7 +595,37 @@ public void testStartedShardEntrySerialization() throws Exception { final TransportVersion version = randomFrom(randomCompatibleVersion(random())); final ShardLongFieldRange timestampRange = ShardLongFieldRangeWireTests.randomRange(); - final StartedShardEntry startedShardEntry = new StartedShardEntry(shardId, allocationId, primaryTerm, message, timestampRange); + final ShardLongFieldRange eventIngestedRange = ShardLongFieldRangeWireTests.randomRange(); + var startedShardEntry = new StartedShardEntry(shardId, allocationId, primaryTerm, message, timestampRange, eventIngestedRange); + try (StreamInput in = serialize(startedShardEntry, version).streamInput()) { + in.setTransportVersion(version); + final StartedShardEntry deserialized = new StartedShardEntry(in); + assertThat(deserialized.shardId, equalTo(shardId)); + assertThat(deserialized.allocationId, equalTo(allocationId)); + assertThat(deserialized.primaryTerm, equalTo(primaryTerm)); + assertThat(deserialized.message, equalTo(message)); + assertThat(deserialized.timestampRange, equalTo(timestampRange)); + if (version.before(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE)) { + assertThat(deserialized.eventIngestedRange, equalTo(ShardLongFieldRange.UNKNOWN)); + } else { + assertThat(deserialized.eventIngestedRange, equalTo(eventIngestedRange)); + } + } + } + + public void testStartedShardEntrySerializationWithOlderTransportVersion() throws Exception { + final ShardId shardId = new ShardId(randomRealisticUnicodeOfLengthBetween(10, 100), UUID.randomUUID().toString(), between(0, 1000)); + final String allocationId = randomRealisticUnicodeOfCodepointLengthBetween(10, 100); + final long primaryTerm = randomIntBetween(0, 100); + final String message = randomRealisticUnicodeOfCodepointLengthBetween(10, 100); + final TransportVersion version = randomFrom( + getFirstVersion(), + getPreviousVersion(TransportVersions.MINIMUM_COMPATIBLE), + getPreviousVersion(TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE) + ); + final ShardLongFieldRange timestampRange = ShardLongFieldRangeWireTests.randomRange(); + final ShardLongFieldRange eventIngestedRange = ShardLongFieldRangeWireTests.randomRange(); + var startedShardEntry = new StartedShardEntry(shardId, allocationId, primaryTerm, message, timestampRange, eventIngestedRange); try (StreamInput in = serialize(startedShardEntry, version).streamInput()) { in.setTransportVersion(version); final StartedShardEntry deserialized = new StartedShardEntry(in); @@ -587,6 +634,7 @@ public void testStartedShardEntrySerialization() throws Exception { assertThat(deserialized.primaryTerm, equalTo(primaryTerm)); assertThat(deserialized.message, equalTo(message)); assertThat(deserialized.timestampRange, equalTo(timestampRange)); + 
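+ // a stream written with a transport version older than EVENT_INGESTED_RANGE_IN_CLUSTER_STATE
+ // never carries the event.ingested range, so deserialization must fall back to
+ // ShardLongFieldRange.UNKNOWN regardless of the range the entry was constructed with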
assertThat(deserialized.eventIngestedRange, equalTo(ShardLongFieldRange.UNKNOWN)); } } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 4a19284f3c4f9..527fd1f95b728 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -25,17 +25,23 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; +import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardLongFieldRange; import org.elasticsearch.indices.IndicesModule; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.junit.Before; @@ -53,6 +59,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; @@ -113,6 +120,14 @@ public void testIndexMetadataSerialization() throws IOException { .indexWriteLoadForecast(indexWriteLoadForecast) .shardSizeInBytesForecast(shardSizeInBytesForecast) .putInferenceFields(inferenceFields) + .eventIngestedRange( + randomFrom( + IndexLongFieldRange.UNKNOWN, + IndexLongFieldRange.EMPTY, + IndexLongFieldRange.NO_SHARDS, + IndexLongFieldRange.NO_SHARDS.extendWithShardRange(0, 1, ShardLongFieldRange.of(5000000, 5500000)) + ) + ) .build(); assertEquals(system, metadata.isSystem()); @@ -174,7 +189,99 @@ public void testIndexMetadataSerialization() throws IOException { assertEquals(metadata.getForecastedWriteLoad(), deserialized.getForecastedWriteLoad()); assertEquals(metadata.getForecastedShardSizeInBytes(), deserialized.getForecastedShardSizeInBytes()); assertEquals(metadata.getInferenceFields(), deserialized.getInferenceFields()); + assertEquals(metadata.getEventIngestedRange(), deserialized.getEventIngestedRange()); + } + } + + public void testIndexMetadataFromXContentParsingWithoutEventIngestedField() throws IOException { + Integer numShard = randomFrom(1, 2, 4, 8, 16); + int numberOfReplicas = randomIntBetween(0, 10); + final boolean system = randomBoolean(); + Map customMap = new HashMap<>(); + customMap.put(randomAlphaOfLength(5), randomAlphaOfLength(10)); + customMap.put(randomAlphaOfLength(10), randomAlphaOfLength(15)); + IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null; + Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null; + Long shardSizeInBytesForecast = randomBoolean() ? 
randomLongBetween(1024, 10240) : null; + Map inferenceFields = randomInferenceFields(); + + IndexMetadata metadata = IndexMetadata.builder("foo") + .settings(indexSettings(numShard, numberOfReplicas).put("index.version.created", 1)) + .creationDate(randomLong()) + .primaryTerm(0, 2) + .setRoutingNumShards(32) + .system(system) + .putCustom("my_custom", customMap) + .putRolloverInfo( + new RolloverInfo( + randomAlphaOfLength(5), + List.of( + new MaxAgeCondition(TimeValue.timeValueMillis(randomNonNegativeLong())), + new MaxDocsCondition(randomNonNegativeLong()), + new MaxSizeCondition(ByteSizeValue.ofBytes(randomNonNegativeLong())), + new MaxPrimaryShardSizeCondition(ByteSizeValue.ofBytes(randomNonNegativeLong())), + new MaxPrimaryShardDocsCondition(randomNonNegativeLong()), + new OptimalShardCountCondition(3) + ), + randomNonNegativeLong() + ) + ) + .stats(indexStats) + .indexWriteLoadForecast(indexWriteLoadForecast) + .shardSizeInBytesForecast(shardSizeInBytesForecast) + .putInferenceFields(inferenceFields) + .eventIngestedRange( + randomFrom( + IndexLongFieldRange.UNKNOWN, + IndexLongFieldRange.EMPTY, + IndexLongFieldRange.NO_SHARDS, + IndexLongFieldRange.NO_SHARDS.extendWithShardRange(0, 1, ShardLongFieldRange.of(5000000, 5500000)) + ) + ) + .build(); + assertEquals(system, metadata.isSystem()); + + final XContentBuilder builder = JsonXContent.contentBuilder(); + builder.startObject(); + IndexMetadata.FORMAT.toXContent(builder, metadata); + builder.endObject(); + + // convert XContent to a map and remove the IndexMetadata.KEY_EVENT_INGESTED_RANGE entry + // to simulate IndexMetadata from an older cluster version (before TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE) + Map indexMetadataMap = XContentHelper.convertToMap(BytesReference.bytes(builder), true, XContentType.JSON).v2(); + + @SuppressWarnings("unchecked") + Map inner = (Map) indexMetadataMap.get("foo"); + assertTrue(inner.containsKey(IndexMetadata.KEY_EVENT_INGESTED_RANGE)); + inner.remove(IndexMetadata.KEY_EVENT_INGESTED_RANGE); + // validate that the IndexMetadata.KEY_EVENT_INGESTED_RANGE has been removed before calling fromXContent + assertFalse(inner.containsKey(IndexMetadata.KEY_EVENT_INGESTED_RANGE)); + + IndexMetadata fromXContentMeta; + XContentParserConfiguration config = XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry()) + .withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + try (XContentParser xContentParser = XContentHelper.mapToXContentParser(config, indexMetadataMap);) { + fromXContentMeta = IndexMetadata.fromXContent(xContentParser); } + + assertEquals(IndexLongFieldRange.NO_SHARDS, fromXContentMeta.getTimestampRange()); + // should come back as UNKNOWN when missing from IndexMetadata XContent + assertEquals(IndexLongFieldRange.UNKNOWN, fromXContentMeta.getEventIngestedRange()); + + // check a few other fields to ensure the parsing worked as expected + assertEquals( + "expected: " + Strings.toString(metadata) + "\nactual : " + Strings.toString(fromXContentMeta), + metadata, + fromXContentMeta + ); + assertEquals(metadata.hashCode(), fromXContentMeta.hashCode()); + assertEquals(metadata.getNumberOfReplicas(), fromXContentMeta.getNumberOfReplicas()); + assertEquals(metadata.getNumberOfShards(), fromXContentMeta.getNumberOfShards()); + assertEquals(metadata.getCreationVersion(), fromXContentMeta.getCreationVersion()); + Map expectedCustom = Map.of("my_custom", new DiffableStringMap(customMap)); + assertEquals(metadata.getCustomData(), expectedCustom); + 
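+ // the custom metadata must also survive the fromXContent round trip unchanged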
assertEquals(metadata.getCustomData(), fromXContentMeta.getCustomData()); + assertEquals(metadata.getStats(), fromXContentMeta.getStats()); } public void testGetRoutingFactor() { diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java index 43d64522ee6fb..8a487e5653627 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java @@ -10,6 +10,8 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.alias.Alias; import org.elasticsearch.action.admin.indices.create.CreateIndexClusterStateUpdateRequest; @@ -47,6 +49,7 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.query.SearchExecutionContextHelper; +import org.elasticsearch.index.shard.IndexLongFieldRange; import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.InvalidAliasNameException; import org.elasticsearch.indices.InvalidIndexNameException; @@ -1096,11 +1099,51 @@ public void testBuildIndexMetadata() { Settings indexSettings = indexSettings(IndexVersion.current(), 1, 0).build(); List aliases = List.of(AliasMetadata.builder("alias1").build()); - IndexMetadata indexMetadata = buildIndexMetadata("test", aliases, () -> null, indexSettings, 4, sourceIndexMetadata, false); + IndexMetadata indexMetadata = buildIndexMetadata( + "test", + aliases, + () -> null, + indexSettings, + 4, + sourceIndexMetadata, + false, + TransportVersion.current() + ); + + assertThat(indexMetadata.getAliases().size(), is(1)); + assertThat(indexMetadata.getAliases().keySet().iterator().next(), is("alias1")); + assertThat("The source index primary term must be used", indexMetadata.primaryTerm(0), is(3L)); + assertThat(indexMetadata.getTimestampRange(), equalTo(IndexLongFieldRange.NO_SHARDS)); + assertThat(indexMetadata.getEventIngestedRange(), equalTo(IndexLongFieldRange.NO_SHARDS)); + } + + public void testBuildIndexMetadataWithTransportVersionBeforeEventIngestedRangeAdded() { + IndexMetadata sourceIndexMetadata = IndexMetadata.builder("parent") + .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build()) + .numberOfShards(1) + .numberOfReplicas(0) + .primaryTerm(0, 3L) + .build(); + + Settings indexSettings = indexSettings(IndexVersion.current(), 1, 0).build(); + List aliases = List.of(AliasMetadata.builder("alias1").build()); + IndexMetadata indexMetadata = buildIndexMetadata( + "test", + aliases, + () -> null, + indexSettings, + 4, + sourceIndexMetadata, + false, + randomFrom(TransportVersions.V_7_0_0, TransportVersions.V_8_0_0) + ); assertThat(indexMetadata.getAliases().size(), is(1)); assertThat(indexMetadata.getAliases().keySet().iterator().next(), is("alias1")); assertThat("The source index primary term must be used", indexMetadata.primaryTerm(0), is(3L)); + assertThat(indexMetadata.getTimestampRange(), equalTo(IndexLongFieldRange.NO_SHARDS)); + // on versions before event.ingested was added to cluster state, it should default to UNKNOWN, not NO_SHARDS + 
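+ // (UNKNOWN is the safe default in a mixed-version cluster: nodes on the older version never
+ // report per-shard event.ingested ranges, so the index-level range cannot claim to be complete)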
assertThat(indexMetadata.getEventIngestedRange(), equalTo(IndexLongFieldRange.UNKNOWN)); } public void testGetIndexNumberOfRoutingShardsWithNullSourceIndex() { diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java index 81611841b6d83..75b7ec97f2886 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/ToAndFromJsonMetadataTests.java @@ -357,6 +357,9 @@ public void testToXContentAPI_SameTypeName() throws IOException { "system" : false, "timestamp_range" : { "shards" : [ ] + }, + "event_ingested_range" : { + "shards" : [ ] } } }, @@ -562,6 +565,9 @@ public void testToXContentAPI_FlatSettingTrue_ReduceMappingFalse() throws IOExce "system" : false, "timestamp_range" : { "shards" : [ ] + }, + "event_ingested_range" : { + "shards" : [ ] } } }, @@ -673,6 +679,9 @@ public void testToXContentAPI_FlatSettingFalse_ReduceMappingTrue() throws IOExce "system" : false, "timestamp_range" : { "shards" : [ ] + }, + "event_ingested_range" : { + "shards" : [ ] } } }, @@ -810,6 +819,9 @@ public void testToXContentAPIReservedMetadata() throws IOException { "system" : false, "timestamp_range" : { "shards" : [ ] + }, + "event_ingested_range" : { + "shards" : [ ] } } }, diff --git a/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java b/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java index 5484998fef2e9..1bae3ca59f3d9 100644 --- a/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/serialization/ClusterSerializationTests.java @@ -36,6 +36,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.shard.IndexLongFieldRange; +import org.elasticsearch.index.shard.ShardLongFieldRange; import org.elasticsearch.snapshots.Snapshot; import org.elasticsearch.snapshots.SnapshotId; import org.elasticsearch.test.TransportVersionUtils; @@ -48,6 +50,7 @@ import java.util.List; import java.util.Map; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; @@ -55,9 +58,102 @@ public class ClusterSerializationTests extends ESAllocationTestCase { public void testClusterStateSerialization() throws Exception { - Metadata metadata = Metadata.builder() - .put(IndexMetadata.builder("test").settings(settings(IndexVersion.current())).numberOfShards(10).numberOfReplicas(1)) - .build(); + IndexLongFieldRange eventIngestedRangeInput = randomFrom( + IndexLongFieldRange.UNKNOWN, + IndexLongFieldRange.NO_SHARDS, + IndexLongFieldRange.EMPTY, + IndexLongFieldRange.NO_SHARDS.extendWithShardRange(0, 1, ShardLongFieldRange.of(100000, 200000)) + ); + + IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder("test") + .settings(settings(IndexVersion.current())) + .numberOfShards(10) + .numberOfReplicas(1) + .eventIngestedRange(eventIngestedRangeInput, TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE); + + ClusterStateTestRecord result = createAndSerializeClusterState(indexMetadataBuilder, TransportVersion.current()); + + 
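+ // first check that the basic cluster state fields survived the round trip, then drill
+ // into how the event.ingested range was carried over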
assertThat(result.serializedClusterState().getClusterName().value(), equalTo(result.clusterState().getClusterName().value())); + assertThat(result.serializedClusterState().routingTable().toString(), equalTo(result.clusterState().routingTable().toString())); + + IndexLongFieldRange eventIngestedRangeOutput = result.serializedClusterState().getMetadata().index("test").getEventIngestedRange(); + assertThat(eventIngestedRangeInput, equalTo(eventIngestedRangeOutput)); + + if (eventIngestedRangeInput.containsAllShardRanges() && eventIngestedRangeInput != IndexLongFieldRange.EMPTY) { + assertThat(eventIngestedRangeOutput.getMin(), equalTo(100000L)); + assertThat(eventIngestedRangeOutput.getMax(), equalTo(200000L)); + } + } + + public void testClusterStateSerializationWithTimestampRangesWithOlderTransportVersion() throws Exception { + TransportVersion versionBeforeEventIngestedInClusterState = randomFrom( + TransportVersions.V_7_0_0, + TransportVersions.V_8_0_0, + TransportVersions.ML_INFERENCE_GOOGLE_VERTEX_AI_EMBEDDINGS_ADDED // version before EVENT_INGESTED_RANGE_IN_CLUSTER_STATE + ); + { + IndexLongFieldRange eventIngestedRangeInput = randomFrom( + IndexLongFieldRange.UNKNOWN, + IndexLongFieldRange.NO_SHARDS, + IndexLongFieldRange.EMPTY, + IndexLongFieldRange.NO_SHARDS.extendWithShardRange(0, 1, ShardLongFieldRange.of(100000, 200000)) + ); + + IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder("test") + .settings(settings(IndexVersion.current())) + .numberOfShards(10) + .numberOfReplicas(1) + .eventIngestedRange(eventIngestedRangeInput, versionBeforeEventIngestedInClusterState); + + ClusterStateTestRecord result = createAndSerializeClusterState(indexMetadataBuilder, versionBeforeEventIngestedInClusterState); + + assertThat(result.serializedClusterState().getClusterName().value(), equalTo(result.clusterState().getClusterName().value())); + assertThat(result.serializedClusterState().routingTable().toString(), equalTo(result.clusterState().routingTable().toString())); + + IndexLongFieldRange eventIngestedRangeOutput = result.serializedClusterState() + .getMetadata() + .index("test") + .getEventIngestedRange(); + // should always come back as UNKNOWN when an older transport version is passed in + assertSame(IndexLongFieldRange.UNKNOWN, eventIngestedRangeOutput); + } + { + // UNKNOWN is the only allowed state for event.ingested range in older versions, so this serialization test should fail + IndexLongFieldRange eventIngestedRangeInput = randomFrom( + IndexLongFieldRange.NO_SHARDS, + IndexLongFieldRange.EMPTY, + IndexLongFieldRange.NO_SHARDS.extendWithShardRange(0, 1, ShardLongFieldRange.of(100000, 200000)) + ); + + IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder("test") + .settings(settings(IndexVersion.current())) + .numberOfShards(10) + .numberOfReplicas(1) + .eventIngestedRange(eventIngestedRangeInput, TransportVersion.current()); + + AssertionError assertionError = expectThrows( + AssertionError.class, + () -> createAndSerializeClusterState(indexMetadataBuilder, versionBeforeEventIngestedInClusterState) + ); + + assertThat( + assertionError.getMessage(), + containsString("eventIngestedRange should be UNKNOWN until all nodes are on the new version") + ); + } + } + + /** + * @param clusterState original ClusterState created by helper method + * @param serializedClusterState serialized version of the clusterState + */ + private record ClusterStateTestRecord(ClusterState clusterState, ClusterState serializedClusterState) {} + + private static 
ClusterStateTestRecord createAndSerializeClusterState( + IndexMetadata.Builder indexMetadataBuilder, + TransportVersion transportVersion + ) throws IOException { + Metadata metadata = Metadata.builder().put(indexMetadataBuilder).build(); RoutingTable routingTable = RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY) .addAsNew(metadata.index("test")) @@ -82,15 +178,17 @@ public void testClusterStateSerialization() throws Exception { .routingTable(strategy.reroute(clusterState, "reroute", ActionListener.noop()).routingTable()) .build(); - ClusterState serializedClusterState = ClusterState.Builder.fromBytes( - ClusterState.Builder.toBytes(clusterState), - newNode("node1"), + BytesStreamOutput outStream = new BytesStreamOutput(); + outStream.setTransportVersion(transportVersion); + clusterState.writeTo(outStream); + StreamInput inStream = new NamedWriteableAwareStreamInput( + outStream.bytes().streamInput(), new NamedWriteableRegistry(ClusterModule.getNamedWriteables()) ); + inStream.setTransportVersion(transportVersion); + ClusterState serializedClusterState = ClusterState.readFrom(inStream, null); - assertThat(serializedClusterState.getClusterName().value(), equalTo(clusterState.getClusterName().value())); - - assertThat(serializedClusterState.routingTable().toString(), equalTo(clusterState.routingTable().toString())); + return new ClusterStateTestRecord(clusterState, serializedClusterState); } public void testRoutingTableSerialization() throws Exception { diff --git a/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java b/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java index 3b0935e8f7b5c..e10cca58f8b78 100644 --- a/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java +++ b/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java @@ -50,14 +50,17 @@ private void assertParseException(String input, String format) { } private void assertParseException(String input, String format, int errorIndex) { - assertParseException(input, format, equalTo(errorIndex)); + assertParseException(input, DateFormatter.forPattern(format), equalTo(errorIndex)); } - private void assertParseException(String input, String format, Matcher indexMatcher) { - DateFormatter javaTimeFormatter = DateFormatter.forPattern(format); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> javaTimeFormatter.parse(input)); + private void assertParseException(String input, DateFormatter formatter, int errorIndex) { + assertParseException(input, formatter, equalTo(errorIndex)); + } + + private void assertParseException(String input, DateFormatter formatter, Matcher indexMatcher) { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> formatter.parse(input)); assertThat(e.getMessage(), containsString(input)); - assertThat(e.getMessage(), containsString(format)); + assertThat(e.getMessage(), containsString(formatter.pattern())); assertThat(e.getCause(), instanceOf(DateTimeParseException.class)); assertThat(((DateTimeParseException) e.getCause()).getErrorIndex(), indexMatcher); } @@ -811,6 +814,20 @@ public void testDecimalPointParsing() { assertParseException("2001-01-01T00:00:00.123,456Z", "date_optional_time", 23); // This should fail, but java is ok with this because the field has the same value // assertJavaTimeParseException("2001-01-01T00:00:00.123,123Z", "strict_date_optional_time_nanos"); + + // for historical reasons, + // despite the use of a locale with , 
separator these formatters still expect only . decimals + DateFormatter formatter = DateFormatter.forPattern("strict_date_time").withLocale(Locale.FRANCE); + assertParses("2020-01-01T12:00:00.0Z", formatter); + assertParseException("2020-01-01T12:00:00,0Z", formatter, 19); + + formatter = DateFormatter.forPattern("strict_date_hour_minute_second_fraction").withLocale(Locale.GERMANY); + assertParses("2020-01-01T12:00:00.0", formatter); + assertParseException("2020-01-01T12:00:00,0", formatter, 19); + + formatter = DateFormatter.forPattern("strict_date_hour_minute_second_millis").withLocale(Locale.ITALY); + assertParses("2020-01-01T12:00:00.0", formatter); + assertParseException("2020-01-01T12:00:00,0", formatter, 19); } public void testTimeZoneFormatting() { diff --git a/server/src/test/java/org/elasticsearch/common/time/Iso8601ParserTests.java b/server/src/test/java/org/elasticsearch/common/time/Iso8601ParserTests.java index 185c9aa983aaa..18d4e3b624465 100644 --- a/server/src/test/java/org/elasticsearch/common/time/Iso8601ParserTests.java +++ b/server/src/test/java/org/elasticsearch/common/time/Iso8601ParserTests.java @@ -33,6 +33,12 @@ import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.time.temporal.ChronoField.SECOND_OF_MINUTE; import static java.time.temporal.ChronoField.YEAR; +import static org.elasticsearch.common.time.DecimalSeparator.BOTH; +import static org.elasticsearch.common.time.DecimalSeparator.COMMA; +import static org.elasticsearch.common.time.DecimalSeparator.DOT; +import static org.elasticsearch.common.time.TimezonePresence.FORBIDDEN; +import static org.elasticsearch.common.time.TimezonePresence.MANDATORY; +import static org.elasticsearch.common.time.TimezonePresence.OPTIONAL; import static org.elasticsearch.test.LambdaMatchers.transformedMatch; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -42,7 +48,7 @@ public class Iso8601ParserTests extends ESTestCase { private static Iso8601Parser defaultParser() { - return new Iso8601Parser(Set.of(), true, Map.of()); + return new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, Map.of()); } private static Matcher hasResult(DateTime dateTime) { @@ -77,68 +83,193 @@ public void testOutOfRange() { public void testMandatoryFields() { assertThat( - new Iso8601Parser(Set.of(YEAR), true, Map.of()).tryParse("2023", null), + new Iso8601Parser(Set.of(YEAR), true, null, BOTH, OPTIONAL, Map.of()).tryParse("2023", null), hasResult(new DateTime(2023, null, null, null, null, null, null, null, null)) ); - assertThat(new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR), true, Map.of()).tryParse("2023", null), hasError(4)); + assertThat( + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR), true, null, BOTH, OPTIONAL, Map.of()).tryParse("2023", null), + hasError(4) + ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR), true, Map.of()).tryParse("2023-06", null), + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR), true, null, BOTH, OPTIONAL, Map.of()).tryParse("2023-06", null), hasResult(new DateTime(2023, 6, null, null, null, null, null, null, null)) ); - assertThat(new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH), true, Map.of()).tryParse("2023-06", null), hasError(7)); + assertThat( + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH), true, null, BOTH, OPTIONAL, Map.of()).tryParse("2023-06", null), + hasError(7) + ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH), true, Map.of()).tryParse("2023-06-20", null), + new 
Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH), true, null, BOTH, OPTIONAL, Map.of()).tryParse("2023-06-20", null), hasResult(new DateTime(2023, 6, 20, null, null, null, null, null, null)) ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY), false, Map.of()).tryParse("2023-06-20", null), + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY), false, null, BOTH, OPTIONAL, Map.of()).tryParse( + "2023-06-20", + null + ), hasError(10) ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY), false, Map.of()).tryParse("2023-06-20T15", null), - hasResult(new DateTime(2023, 6, 20, 15, 0, 0, 0, null, null)) - ); - assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR), false, Map.of()).tryParse( + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY), false, null, BOTH, OPTIONAL, Map.of()).tryParse( "2023-06-20T15", null ), + hasResult(new DateTime(2023, 6, 20, 15, 0, 0, 0, null, null)) + ); + assertThat( + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR), false, null, BOTH, OPTIONAL, Map.of()) + .tryParse("2023-06-20T15", null), hasError(13) ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR), false, Map.of()).tryParse( - "2023-06-20T15Z", - null - ), + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR), false, null, BOTH, OPTIONAL, Map.of()) + .tryParse("2023-06-20T15Z", null), hasError(13) ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR), false, Map.of()).tryParse( - "2023-06-20T15:48", - null - ), + new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR), false, null, BOTH, OPTIONAL, Map.of()) + .tryParse("2023-06-20T15:48", null), hasResult(new DateTime(2023, 6, 20, 15, 48, 0, 0, null, null)) ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), false, Map.of()) - .tryParse("2023-06-20T15:48", null), + new Iso8601Parser( + Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), + false, + null, + BOTH, + OPTIONAL, + Map.of() + ).tryParse("2023-06-20T15:48", null), hasError(16) ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), false, Map.of()) - .tryParse("2023-06-20T15:48Z", null), + new Iso8601Parser( + Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), + false, + null, + BOTH, + OPTIONAL, + Map.of() + ).tryParse("2023-06-20T15:48Z", null), hasError(16) ); assertThat( - new Iso8601Parser(Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), false, Map.of()) - .tryParse("2023-06-20T15:48:09", null), + new Iso8601Parser( + Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE), + false, + null, + BOTH, + OPTIONAL, + Map.of() + ).tryParse("2023-06-20T15:48:09", null), hasResult(new DateTime(2023, 6, 20, 15, 48, 9, 0, null, null)) ); + + assertThat( + new Iso8601Parser( + Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE, NANO_OF_SECOND), + false, + null, + BOTH, + OPTIONAL, + Map.of() + ).tryParse("2023-06-20T15:48:09", null), + hasError(19) + ); + assertThat( + new Iso8601Parser( + 
Set.of(YEAR, MONTH_OF_YEAR, DAY_OF_MONTH, HOUR_OF_DAY, MINUTE_OF_HOUR, SECOND_OF_MINUTE, NANO_OF_SECOND), + false, + null, + BOTH, + OPTIONAL, + Map.of() + ).tryParse("2023-06-20T15:48:09.5", null), + hasResult(new DateTime(2023, 6, 20, 15, 48, 9, 500_000_000, null, null)) + ); + } + + public void testMaxAllowedField() { + assertThat( + new Iso8601Parser(Set.of(), false, YEAR, BOTH, FORBIDDEN, Map.of()).tryParse("2023", null), + hasResult(new DateTime(2023, null, null, null, null, null, null, null, null)) + ); + assertThat(new Iso8601Parser(Set.of(), false, YEAR, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01", null), hasError(4)); + + assertThat( + new Iso8601Parser(Set.of(), false, MONTH_OF_YEAR, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01", null), + hasResult(new DateTime(2023, 1, null, null, null, null, null, null, null)) + ); + assertThat(new Iso8601Parser(Set.of(), false, MONTH_OF_YEAR, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01", null), hasError(7)); + + assertThat( + new Iso8601Parser(Set.of(), false, DAY_OF_MONTH, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01", null), + hasResult(new DateTime(2023, 1, 1, null, null, null, null, null, null)) + ); + assertThat(new Iso8601Parser(Set.of(), false, DAY_OF_MONTH, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T", null), hasError(10)); + assertThat( + new Iso8601Parser(Set.of(), false, DAY_OF_MONTH, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12", null), + hasError(10) + ); + + assertThat( + new Iso8601Parser(Set.of(), false, HOUR_OF_DAY, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, null, null)) + ); + assertThat( + new Iso8601Parser(Set.of(), false, HOUR_OF_DAY, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00", null), + hasError(13) + ); + + assertThat( + new Iso8601Parser(Set.of(), false, MINUTE_OF_HOUR, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, null, null)) + ); + assertThat( + new Iso8601Parser(Set.of(), false, MINUTE_OF_HOUR, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00:00", null), + hasError(16) + ); + + assertThat( + new Iso8601Parser(Set.of(), false, SECOND_OF_MINUTE, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00:00", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, null, null)) + ); + assertThat( + new Iso8601Parser(Set.of(), false, SECOND_OF_MINUTE, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00:00.5", null), + hasError(19) + ); + } + + public void testTimezoneForbidden() { + assertThat(new Iso8601Parser(Set.of(), false, null, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12Z", null), hasError(13)); + assertThat(new Iso8601Parser(Set.of(), false, null, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00Z", null), hasError(16)); + assertThat( + new Iso8601Parser(Set.of(), false, null, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00:00Z", null), + hasError(19) + ); + + // a default timezone should still make it through + ZoneOffset zoneId = ZoneOffset.ofHours(2); + assertThat( + new Iso8601Parser(Set.of(), false, null, BOTH, FORBIDDEN, Map.of()).tryParse("2023-01-01T12:00:00", zoneId), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, zoneId, zoneId)) + ); + } + + public void testTimezoneMandatory() { + assertThat(new Iso8601Parser(Set.of(), false, null, BOTH, MANDATORY, Map.of()).tryParse("2023-01-01T12", null), hasError(13)); + assertThat(new Iso8601Parser(Set.of(), false, null, BOTH, MANDATORY, Map.of()).tryParse("2023-01-01T12:00", 
null), hasError(16)); assertThat(new Iso8601Parser(Set.of(), false, null, BOTH, MANDATORY, Map.of()).tryParse("2023-01-01T12:00:00", null), hasError(19)); + + assertThat( + new Iso8601Parser(Set.of(), false, null, BOTH, MANDATORY, Map.of()).tryParse("2023-01-01T12:00:00Z", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, ZoneOffset.UTC, ZoneOffset.UTC)) + ); + } public void testParseNanos() { @@ -188,6 +319,41 @@ public void testParseNanos() { assertThat(defaultParser().tryParse("2023-01-01T12:00:00.0000000005", null), hasError(29)); } + public void testParseDecimalSeparator() { + assertThat( + new Iso8601Parser(Set.of(), false, null, BOTH, OPTIONAL, Map.of()).tryParse("2023-01-01T12:00:00.0", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, null, null)) + ); + assertThat( + new Iso8601Parser(Set.of(), false, null, BOTH, OPTIONAL, Map.of()).tryParse("2023-01-01T12:00:00,0", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, null, null)) + ); + + assertThat( + new Iso8601Parser(Set.of(), false, null, DOT, OPTIONAL, Map.of()).tryParse("2023-01-01T12:00:00.0", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, null, null)) + ); + assertThat(new Iso8601Parser(Set.of(), false, null, DOT, OPTIONAL, Map.of()).tryParse("2023-01-01T12:00:00,0", null), hasError(19)); + + assertThat( + new Iso8601Parser(Set.of(), false, null, COMMA, OPTIONAL, Map.of()).tryParse("2023-01-01T12:00:00.0", null), + hasError(19) + ); + assertThat( + new Iso8601Parser(Set.of(), false, null, COMMA, OPTIONAL, Map.of()).tryParse("2023-01-01T12:00:00,0", null), + hasResult(new DateTime(2023, 1, 1, 12, 0, 0, 0, null, null)) + ); + + assertThat( + new Iso8601Parser(Set.of(), false, null, BOTH, OPTIONAL, Map.of()).tryParse("2023-01-01T12:00:00+0", null), + hasError(19) + ); + } + private static Matcher hasTimezone(ZoneId offset) { return transformedMatch(r -> r.result().query(TemporalQueries.zone()), equalTo(offset)); } @@ -351,7 +517,7 @@ public void testDefaults() { ); assertThat( - new Iso8601Parser(Set.of(), true, defaults).tryParse("2023", null), + new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, defaults).tryParse("2023", null), hasResult( new DateTime( 2023, @@ -367,7 +533,7 @@ public void testDefaults() { ) ); assertThat( - new Iso8601Parser(Set.of(), true, defaults).tryParse("2023-01", null), + new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, defaults).tryParse("2023-01", null), hasResult( new DateTime( 2023, @@ -383,7 +549,7 @@ public void testDefaults() { ) ); assertThat( - new Iso8601Parser(Set.of(), true, defaults).tryParse("2023-01-01", null), + new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, defaults).tryParse("2023-01-01", null), hasResult( new DateTime( 2023, @@ -399,7 +565,7 @@ public void testDefaults() { ) ); assertThat( - new Iso8601Parser(Set.of(), true, defaults).tryParse("2023-01-01T00", null), + new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, defaults).tryParse("2023-01-01T00", null), hasResult( new DateTime( 2023, @@ -415,15 +581,15 @@ public void testDefaults() { ) ); assertThat( - new Iso8601Parser(Set.of(), true, defaults).tryParse("2023-01-01T00:00", null), + new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, defaults).tryParse("2023-01-01T00:00", null), hasResult(new DateTime(2023, 1, 1, 0, 0, defaults.get(SECOND_OF_MINUTE), defaults.get(NANO_OF_SECOND), null, null)) ); assertThat( - new
Iso8601Parser(Set.of(), true, defaults).tryParse("2023-01-01T00:00:00", null), + new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, defaults).tryParse("2023-01-01T00:00:00", null), hasResult(new DateTime(2023, 1, 1, 0, 0, 0, defaults.get(NANO_OF_SECOND), null, null)) ); assertThat( - new Iso8601Parser(Set.of(), true, defaults).tryParse("2023-01-01T00:00:00.0", null), + new Iso8601Parser(Set.of(), true, null, BOTH, OPTIONAL, defaults).tryParse("2023-01-01T00:00:00.0", null), hasResult(new DateTime(2023, 1, 1, 0, 0, 0, 0, null, null)) ); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseKnnBitVectorsFormatTestCase.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseKnnBitVectorsFormatTestCase.java new file mode 100644 index 0000000000000..ba4d5275214b6 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BaseKnnBitVectorsFormatTestCase.java @@ -0,0 +1,149 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase; +import org.elasticsearch.common.logging.LogConfigurator; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +abstract class BaseKnnBitVectorsFormatTestCase extends BaseIndexFileFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Override + protected void addRandomFields(Document doc) { + doc.add(new KnnByteVectorField("v2", randomVector(30), similarityFunction)); + } + + protected VectorSimilarityFunction similarityFunction; + + protected VectorSimilarityFunction randomSimilarity() { + return VectorSimilarityFunction.values()[random().nextInt(VectorSimilarityFunction.values().length)]; + } + + byte[] randomVector(int dims) { + byte[] vector = new byte[dims]; + random().nextBytes(vector); + return vector; + } + + public void testRandom() throws Exception { + IndexWriterConfig iwc = newIndexWriterConfig(); + if (random().nextBoolean()) { + iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT))); + } + String fieldName = "field"; + try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter(dir, iwc)) { + int numDoc = atLeast(100); + int dimension = atLeast(10); + if (dimension % 2 != 0) { + dimension++; + 
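+ // atLeast(10) may return an odd value; bump it so the randomly chosen vector length is always even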
} + byte[] scratch = new byte[dimension]; + int numValues = 0; + byte[][] values = new byte[numDoc][]; + for (int i = 0; i < numDoc; i++) { + if (random().nextInt(7) != 3) { + // usually index a vector value for a doc + values[i] = randomVector(dimension); + ++numValues; + } + if (random().nextBoolean() && values[i] != null) { + // sometimes use a shared scratch array + System.arraycopy(values[i], 0, scratch, 0, scratch.length); + add(iw, fieldName, i, scratch, similarityFunction); + } else { + add(iw, fieldName, i, values[i], similarityFunction); + } + if (random().nextInt(10) == 2) { + // sometimes delete a random document + int idToDelete = random().nextInt(i + 1); + iw.deleteDocuments(new Term("id", Integer.toString(idToDelete))); + // and remember that it was deleted + if (values[idToDelete] != null) { + values[idToDelete] = null; + --numValues; + } + } + if (random().nextInt(10) == 3) { + iw.commit(); + } + } + int numDeletes = 0; + try (IndexReader reader = DirectoryReader.open(iw)) { + int valueCount = 0, totalSize = 0; + for (LeafReaderContext ctx : reader.leaves()) { + ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName); + if (vectorValues == null) { + continue; + } + totalSize += vectorValues.size(); + StoredFields storedFields = ctx.reader().storedFields(); + int docId; + while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { + byte[] v = vectorValues.vectorValue(); + assertEquals(dimension, v.length); + String idString = storedFields.document(docId).getField("id").stringValue(); + int id = Integer.parseInt(idString); + if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) { + assertArrayEquals(idString, values[id], v); + ++valueCount; + } else { + ++numDeletes; + assertNull(values[id]); + } + } + } + assertEquals(numValues, valueCount); + assertEquals(numValues, totalSize - numDeletes); + } + } + } + + private void add(IndexWriter iw, String field, int id, byte[] vector, VectorSimilarityFunction similarity) throws IOException { + add(iw, field, id, random().nextInt(100), vector, similarity); + } + + private void add(IndexWriter iw, String field, int id, int sortKey, byte[] vector, VectorSimilarityFunction similarityFunction) + throws IOException { + Document doc = new Document(); + if (vector != null) { + doc.add(new KnnByteVectorField(field, vector, similarityFunction)); + } + doc.add(new NumericDocValuesField("sortkey", sortKey)); + String idString = Integer.toString(id); + doc.add(new StringField("id", idString, Field.Store.YES)); + Term idTerm = new Term("id", idString); + iw.updateDocument(idTerm, doc); + } + +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java index 2f9148e80988e..b4f82e91c39c1 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormatTests.java @@ -12,8 +12,15 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99Codec; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.elasticsearch.common.logging.LogConfigurator; public class ES813FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + @Override 
protected Codec getCodec() { return new Lucene99Codec() { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java index 07a922efd21a6..7bb2e9e0284f1 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormatTests.java @@ -12,8 +12,15 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99Codec; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.elasticsearch.common.logging.LogConfigurator; public class ES813Int8FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + @Override protected Codec getCodec() { return new Lucene99Codec() { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormatTests.java new file mode 100644 index 0000000000000..c9a5a8e76a041 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormatTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99Codec; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.junit.Before; + +public class ES815BitFlatVectorFormatTests extends BaseKnnBitVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return new Lucene99Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES815BitFlatVectorFormat(); + } + }; + } + + @Before + public void init() { + similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + } + +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java new file mode 100644 index 0000000000000..3525d5b619565 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99Codec; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.junit.Before; + +public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase { + + @Override + protected Codec getCodec() { + return new Lucene99Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES815HnswBitVectorsFormat(); + } + }; + } + + @Before + public void init() { + similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldMapperTests.java index e06ed1736cca2..afebe1a008468 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/KeywordFieldMapperTests.java @@ -646,6 +646,12 @@ protected boolean supportsIgnoreMalformed() { return false; } + @Override + protected BlockReaderSupport getSupportedReaders(MapperService mapper, String loaderFieldName) { + MappedFieldType ft = mapper.fieldType(loaderFieldName); + return new BlockReaderSupport(ft.hasDocValues(), ft.hasDocValues() || ft.isStored(), mapper, loaderFieldName); + } + @Override protected Function loadBlockExpected() { return v -> ((BytesRef) v).utf8ToString(); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/TextFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/TextFieldMapperTests.java index 50d15be2256ed..8330cf1f5f794 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/TextFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/TextFieldMapperTests.java @@ -1336,12 +1336,7 @@ public void testBlockLoaderParentFromRowStrideReader() throws IOException { private void testBlockLoaderFromParent(boolean columnReader, boolean syntheticSource) throws IOException { boolean storeParent = randomBoolean(); - KeywordFieldSyntheticSourceSupport kwdSupport = new KeywordFieldSyntheticSourceSupport( - null, - storeParent, - null, - false == storeParent - ); + KeywordFieldSyntheticSourceSupport kwdSupport = new KeywordFieldSyntheticSourceSupport(null, storeParent, null, false); SyntheticSourceExample example = kwdSupport.example(5); CheckedConsumer buildFields = b -> { b.startObject("field"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java index ff5baf8ba0877..1df42368041ac 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java @@ -236,8 +236,8 @@ public long cost() { public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, IndexVersion indexVersion) { int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION) - ? elementType.elementBytes * values.length + DenseVectorFieldMapper.MAGNITUDE_BYTES - : elementType.elementBytes * values.length; + ? 
elementType.getNumBytes(values.length) + DenseVectorFieldMapper.MAGNITUDE_BYTES + : elementType.getNumBytes(values.length); double dotProduct = 0f; ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes); for (float value : values) { diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 6c3f2e19ad4b1..5397e4cd335ff 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -71,11 +71,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase { private final ElementType elementType; private final boolean indexed; private final boolean indexOptionsSet; + private final int dims; public DenseVectorFieldMapperTests() { - this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT); + this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT); this.indexed = randomBoolean(); this.indexOptionsSet = this.indexed && randomBoolean(); + this.dims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4; } @Override @@ -89,7 +91,7 @@ protected void minimalMapping(XContentBuilder b, IndexVersion indexVersion) thro } private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws IOException { - b.field("type", "dense_vector").field("dims", 4); + b.field("type", "dense_vector").field("dims", dims); if (elementType != ElementType.FLOAT) { b.field("element_type", elementType.toString()); } @@ -108,7 +110,7 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I b.endObject(); } if (indexed) { - b.field("similarity", "dot_product"); + b.field("similarity", elementType == ElementType.BIT ? "l2_norm" : "dot_product"); if (indexOptionsSet) { b.startObject("index_options"); b.field("type", "hnsw"); @@ -121,52 +123,86 @@ private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws I @Override protected Object getSampleValueForDocument() { - return elementType == ElementType.BYTE ? List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1) : List.of(0.5, 0.5, 0.5, 0.5); + return elementType == ElementType.FLOAT ? 
List.of(0.5, 0.5, 0.5, 0.5) : List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1); } @Override protected void registerParameters(ParameterChecker checker) throws IOException { checker.registerConflictCheck( "dims", - fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4)), - fieldMapping(b -> b.field("type", "dense_vector").field("dims", 5)) + fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims)), + fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims + 8)) ); checker.registerConflictCheck( "similarity", - fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")), - fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "l2_norm")) + fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")), + fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "l2_norm")) ); checker.registerConflictCheck( "index", - fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")), - fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", false)) + fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")), + fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", false)) ); checker.registerConflictCheck( "element_type", fieldMapping( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .field("similarity", "dot_product") .field("element_type", "byte") ), fieldMapping( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .field("similarity", "dot_product") .field("element_type", "float") ) ); + checker.registerConflictCheck( + "element_type", + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .field("similarity", "l2_norm") + .field("element_type", "float") + ), + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .field("similarity", "l2_norm") + .field("element_type", "bit") + ) + ); + checker.registerConflictCheck( + "element_type", + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .field("similarity", "l2_norm") + .field("element_type", "byte") + ), + fieldMapping( + b -> b.field("type", "dense_vector") + .field("dims", dims) + .field("index", true) + .field("similarity", "l2_norm") + .field("element_type", "bit") + ) + ); checker.registerUpdateCheck( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "flat") .endObject(), b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "int8_flat") @@ -175,13 +211,13 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ); checker.registerUpdateCheck( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "flat") .endObject(), b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) 
.startObject("index_options") .field("type", "hnsw") @@ -190,13 +226,13 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ); checker.registerUpdateCheck( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "flat") .endObject(), b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "int8_hnsw") @@ -205,13 +241,13 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ); checker.registerUpdateCheck( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "int8_flat") .endObject(), b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "hnsw") @@ -220,13 +256,13 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ); checker.registerUpdateCheck( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "int8_flat") .endObject(), b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "int8_hnsw") @@ -235,13 +271,13 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ); checker.registerUpdateCheck( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "hnsw") .endObject(), b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "int8_hnsw") @@ -252,7 +288,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException { "index_options", fieldMapping( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "hnsw") @@ -260,7 +296,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException { ), fieldMapping( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("index", true) .startObject("index_options") .field("type", "flat") @@ -353,7 +389,7 @@ public void testMergeDims() throws IOException { mapping = mapping(b -> { b.startObject("field"); b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("similarity", "cosine") .field("index", true) .startObject("index_options") @@ -648,7 +684,7 @@ public void testInvalidParameters() { () -> createDocumentMapper( fieldMapping( b -> b.field("type", "dense_vector") - .field("dims", 4) + .field("dims", dims) .field("element_type", "byte") .field("similarity", "l2_norm") .field("index", true) @@ -1020,6 +1056,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { } yield floats; } + case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8); }; } @@ -1196,7 +1233,7 @@ public void testKnnVectorsFormat() throws IOException { boolean setEfConstruction = randomBoolean(); MapperService mapperService = createMapperService(fieldMapping(b -> { b.field("type", "dense_vector"); - b.field("dims", 4); + b.field("dims", dims); b.field("index", true); b.field("similarity", "dot_product"); 
b.startObject("index_options"); @@ -1234,7 +1271,7 @@ public void testKnnQuantizedFlatVectorsFormat() throws IOException { for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) { MapperService mapperService = createMapperService(fieldMapping(b -> { b.field("type", "dense_vector"); - b.field("dims", 4); + b.field("dims", dims); b.field("index", true); b.field("similarity", "dot_product"); b.startObject("index_options"); @@ -1275,7 +1312,7 @@ public void testKnnQuantizedHNSWVectorsFormat() throws IOException { float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true); MapperService mapperService = createMapperService(fieldMapping(b -> { b.field("type", "dense_vector"); - b.field("dims", 4); + b.field("dims", dims); b.field("index", true); b.field("similarity", "dot_product"); b.startObject("index_options"); @@ -1316,7 +1353,7 @@ public void testKnnHalfByteQuantizedHNSWVectorsFormat() throws IOException { float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true); MapperService mapperService = createMapperService(fieldMapping(b -> { b.field("type", "dense_vector"); - b.field("dims", 4); + b.field("dims", dims); b.field("index", true); b.field("similarity", "dot_product"); b.startObject("index_options"); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index f178e66955fdc..96917a42cff65 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -185,10 +185,12 @@ public void testCreateNestedKnnQuery() { queryVector[i] = randomByte(); floatQueryVector[i] = queryVector[i]; } - Query query = field.createKnnQuery(queryVector, 10, null, null, producer); + VectorData vectorData = new VectorData(null, queryVector); + Query query = field.createKnnQuery(vectorData, 10, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); - query = field.createKnnQuery(floatQueryVector, 10, null, null, producer); + vectorData = new VectorData(floatQueryVector, null); + query = field.createKnnQuery(vectorData, 10, null, null, producer); assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class)); } } @@ -321,7 +323,8 @@ public void testCreateKnnQueryMaxDims() { for (int i = 0; i < 4096; i++) { queryVector[i] = randomByte(); } - Query query = fieldWith4096dims.createKnnQuery(queryVector, 10, null, null, null); + VectorData vectorData = new VectorData(null, queryVector); + Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, null, null, null); assertThat(query, instanceOf(KnnByteVectorQuery.class)); } } @@ -359,7 +362,10 @@ public void testByteCreateKnnQuery() { ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); - e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new byte[] { 0, 0, 0 }, 10, null, null, null)); + e = expectThrows( + IllegalArgumentException.class, + () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, null, null, null) + ); assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude.")); } } diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java 
b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index d272aaab1b231..9d53b95e01db3 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -4922,7 +4922,11 @@ public void testShardExposesWriteLoadStats() throws Exception { final var recoveryFinishedLatch = new CountDownLatch(1); final var recoveryListener = new PeerRecoveryTargetService.RecoveryListener() { @Override - public void onRecoveryDone(RecoveryState state, ShardLongFieldRange timestampMillisFieldRange) { + public void onRecoveryDone( + RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ) { recoveryFinishedLatch.countDown(); } diff --git a/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java b/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java index 49de52357d0ba..2fdeda052381b 100644 --- a/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java +++ b/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java @@ -166,7 +166,7 @@ private void doTestStoreDirectory(Path tempDir, String typeSettingValue, IndexMo assertTrue(type + " " + directory.toString(), directory instanceof NIOFSDirectory); break; case MMAPFS: - assertTrue(type + " " + directory.toString(), directory instanceof MMapDirectory); + assertTrue(type + " " + directory.getClass().getName() + " " + directory, directory instanceof MMapDirectory); break; case FS: if (Constants.JRE_IS_64BIT && MMapDirectory.UNMAP_SUPPORTED) { diff --git a/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java b/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java index 60d73f873bbd4..0b67da4067fc9 100644 --- a/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java +++ b/server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java @@ -489,6 +489,7 @@ public ClusterState applyStartedShards(ClusterState clusterState, Map future = new PlainActionFuture<>(); RecoveryTarget recoveryTarget = new RecoveryTarget(shard, null, 0L, null, null, new PeerRecoveryTargetService.RecoveryListener() { @Override - public void onRecoveryDone(RecoveryState state, ShardLongFieldRange timestampMillisFieldRange) { + public void onRecoveryDone( + RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ) { future.onResponse(null); } diff --git a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java index d5ac683569eba..fc8f1988a732b 100644 --- a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java +++ b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java @@ -444,7 +444,11 @@ public long addDocuments(Iterable> expectThrows(Exception.class, () -> group.recoverReplica(replica, (shard, sourceNode) -> { return new RecoveryTarget(shard, sourceNode, 0L, null, null, new PeerRecoveryTargetService.RecoveryListener() { @Override - public void onRecoveryDone(RecoveryState state, ShardLongFieldRange timestampMillisFieldRange) { + public void onRecoveryDone( + RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ) { throw new AssertionError("recovery must 
fail"); } diff --git a/server/src/test/java/org/elasticsearch/recovery/RecoveriesCollectionTests.java b/server/src/test/java/org/elasticsearch/recovery/RecoveriesCollectionTests.java index 1540d3223ae72..fb159f8fb208d 100644 --- a/server/src/test/java/org/elasticsearch/recovery/RecoveriesCollectionTests.java +++ b/server/src/test/java/org/elasticsearch/recovery/RecoveriesCollectionTests.java @@ -31,7 +31,11 @@ public class RecoveriesCollectionTests extends ESIndexLevelReplicationTestCase { static final PeerRecoveryTargetService.RecoveryListener listener = new PeerRecoveryTargetService.RecoveryListener() { @Override - public void onRecoveryDone(RecoveryState state, ShardLongFieldRange timestampMillisFieldRange) { + public void onRecoveryDone( + RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ) { } @@ -69,7 +73,11 @@ public void testRecoveryTimeout() throws Exception { shards.addReplica(), new PeerRecoveryTargetService.RecoveryListener() { @Override - public void onRecoveryDone(RecoveryState state, ShardLongFieldRange timestampMillisFieldRange) { + public void onRecoveryDone( + RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ) { latch.countDown(); } diff --git a/server/src/test/java/org/elasticsearch/reservedstate/ReservedClusterStateHandlerTests.java b/server/src/test/java/org/elasticsearch/reservedstate/ReservedClusterStateHandlerTests.java index c3b35dc429ebc..096d40ff5b979 100644 --- a/server/src/test/java/org/elasticsearch/reservedstate/ReservedClusterStateHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/reservedstate/ReservedClusterStateHandlerTests.java @@ -16,6 +16,8 @@ import java.io.IOException; +import static org.hamcrest.Matchers.is; + public class ReservedClusterStateHandlerTests extends ESTestCase { public void testValidation() { ReservedClusterStateHandler handler = new ReservedClusterStateHandler<>() { @@ -36,9 +38,9 @@ public ValidRequest fromXContent(XContentParser parser) throws IOException { }; handler.validate(new ValidRequest()); - assertEquals( - "Validation error", - expectThrows(IllegalStateException.class, () -> handler.validate(new InvalidRequest())).getMessage() + assertThat( + expectThrows(IllegalStateException.class, () -> handler.validate(new InvalidRequest())).getMessage(), + is("Validation error") ); } diff --git a/server/src/test/java/org/elasticsearch/reservedstate/action/ReservedClusterSettingsActionTests.java b/server/src/test/java/org/elasticsearch/reservedstate/action/ReservedClusterSettingsActionTests.java index 08e8e46d4b95a..8e7c70ebef896 100644 --- a/server/src/test/java/org/elasticsearch/reservedstate/action/ReservedClusterSettingsActionTests.java +++ b/server/src/test/java/org/elasticsearch/reservedstate/action/ReservedClusterSettingsActionTests.java @@ -26,7 +26,8 @@ import static org.elasticsearch.common.settings.Setting.Property.Dynamic; import static org.elasticsearch.common.settings.Setting.Property.NodeScope; import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; public class ReservedClusterSettingsActionTests extends ESTestCase { @@ -53,9 +54,9 @@ public void testValidation() throws Exception { "indices.recovery.min_bytes_per_sec": "50mb" }"""; - assertEquals( - "persistent setting [indices.recovery.min_bytes_per_sec], not recognized", - 
expectThrows(IllegalArgumentException.class, () -> processJSON(action, prevState, badPolicyJSON)).getMessage() + assertThat( + expectThrows(IllegalArgumentException.class, () -> processJSON(action, prevState, badPolicyJSON)).getMessage(), + is("persistent setting [indices.recovery.min_bytes_per_sec], not recognized") ); } @@ -69,7 +70,7 @@ public void testSetUnsetSettings() throws Exception { String emptyJSON = ""; TransformState updatedState = processJSON(action, prevState, emptyJSON); - assertEquals(0, updatedState.keys().size()); + assertThat(updatedState.keys(), empty()); assertEquals(prevState.state(), updatedState.state()); String settingsJSON = """ @@ -89,8 +90,8 @@ public void testSetUnsetSettings() throws Exception { prevState = updatedState; updatedState = processJSON(action, prevState, settingsJSON); assertThat(updatedState.keys(), containsInAnyOrder("indices.recovery.max_bytes_per_sec", "cluster.remote.cluster_one.seeds")); - assertEquals("50mb", updatedState.state().metadata().persistentSettings().get("indices.recovery.max_bytes_per_sec")); - assertEquals("[127.0.0.1:9300]", updatedState.state().metadata().persistentSettings().get("cluster.remote.cluster_one.seeds")); + assertThat(updatedState.state().metadata().persistentSettings().get("indices.recovery.max_bytes_per_sec"), is("50mb")); + assertThat(updatedState.state().metadata().persistentSettings().get("cluster.remote.cluster_one.seeds"), is("[127.0.0.1:9300]")); String oneSettingJSON = """ { @@ -100,12 +101,12 @@ public void testSetUnsetSettings() throws Exception { prevState = updatedState; updatedState = processJSON(action, prevState, oneSettingJSON); assertThat(updatedState.keys(), containsInAnyOrder("indices.recovery.max_bytes_per_sec")); - assertEquals("25mb", updatedState.state().metadata().persistentSettings().get("indices.recovery.max_bytes_per_sec")); + assertThat(updatedState.state().metadata().persistentSettings().get("indices.recovery.max_bytes_per_sec"), is("25mb")); assertNull(updatedState.state().metadata().persistentSettings().get("cluster.remote.cluster_one.seeds")); prevState = updatedState; updatedState = processJSON(action, prevState, emptyJSON); - assertEquals(0, updatedState.keys().size()); + assertThat(updatedState.keys(), empty()); assertNull(updatedState.state().metadata().persistentSettings().get("indices.recovery.max_bytes_per_sec")); } @@ -130,8 +131,8 @@ public void testSettingNameNormalization() throws Exception { TransformState newState = processJSON(testAction, prevState, json); assertThat(newState.keys(), containsInAnyOrder("dummy.setting1", "dummy.setting2")); - assertThat(newState.state().metadata().persistentSettings().get("dummy.setting1"), equalTo("value1")); - assertThat(newState.state().metadata().persistentSettings().get("dummy.setting2"), equalTo("value2")); + assertThat(newState.state().metadata().persistentSettings().get("dummy.setting1"), is("value1")); + assertThat(newState.state().metadata().persistentSettings().get("dummy.setting2"), is("value2")); String jsonRemoval = """ { @@ -142,6 +143,6 @@ public void testSettingNameNormalization() throws Exception { """; TransformState newState2 = processJSON(testAction, prevState, jsonRemoval); assertThat(newState2.keys(), containsInAnyOrder("dummy.setting2")); - assertThat(newState2.state().metadata().persistentSettings().get("dummy.setting2"), equalTo("value2")); + assertThat(newState2.state().metadata().persistentSettings().get("dummy.setting2"), is("value2")); } } diff --git 
a/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java b/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java index aca5d2cbee2c9..cc5f0e22ad4ee 100644 --- a/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/reservedstate/service/FileSettingsServiceTests.java @@ -35,7 +35,6 @@ import org.mockito.stubbing.Answer; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; @@ -142,7 +141,7 @@ public void testInitialFileError() throws Exception { doAnswer((Answer<Void>) invocation -> { ((Consumer<Exception>) invocation.getArgument(2)).accept(new IllegalStateException("Some exception")); return null; - }).when(stateService).process(any(), (XContentParser) any(), any()); + }).when(stateService).process(any(), any(XContentParser.class), any()); AtomicBoolean settingsChanged = new AtomicBoolean(false); CountDownLatch latch = new CountDownLatch(1); @@ -186,7 +185,7 @@ public void testInitialFileWorks() throws Exception { doAnswer((Answer<Void>) invocation -> { ((Consumer<Exception>) invocation.getArgument(2)).accept(null); return null; - }).when(stateService).process(any(), (XContentParser) any(), any()); + }).when(stateService).process(any(), any(XContentParser.class), any()); CountDownLatch latch = new CountDownLatch(1); @@ -263,8 +262,7 @@ public void testStopWorksInMiddleOfProcessing() throws Exception { // helpers private void writeTestFile(Path path, String contents) throws IOException { Path tempFilePath = createTempFile(); - - Files.write(tempFilePath, contents.getBytes(StandardCharsets.UTF_8)); + Files.writeString(tempFilePath, contents); Files.move(tempFilePath, path, StandardCopyOption.ATOMIC_MOVE); } } diff --git a/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java b/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java index fe9401284b9f5..5d675b99ba9ab 100644 --- a/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java +++ b/server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionResponse; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateAckListener; @@ -34,11 +33,11 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.junit.Assert; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; @@ -50,14 +49,15 @@ import java.util.stream.Collectors; import static org.elasticsearch.reservedstate.service.ReservedStateUpdateTask.checkMetadataVersion; -import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; import static 
org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doReturn; @@ -106,9 +106,9 @@ public void testOperatorController() throws IOException { AtomicReference<Exception> x = new AtomicReference<>(); try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, testJSON)) { - controller.process("operator", parser, (e) -> x.set(e)); + controller.process("operator", parser, x::set); - assertTrue(x.get() instanceof IllegalStateException); + assertThat(x.get(), instanceOf(IllegalStateException.class)); assertThat(x.get().getMessage(), containsString("Error processing state change request for operator")); } @@ -136,16 +136,11 @@ """; try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, testJSON)) { - controller.process("operator", parser, (e) -> { - if (e != null) { - fail("Should not fail"); - } - }); + controller.process("operator", parser, Assert::assertNull); } } public void testUpdateStateTasks() throws Exception { - ClusterService clusterService = mock(ClusterService.class); RerouteService rerouteService = mock(RerouteService.class); ClusterState state = ClusterState.builder(new ClusterName("test")).build(); @@ -155,21 +150,7 @@ AtomicBoolean successCalled = new AtomicBoolean(false); ReservedStateUpdateTask task = spy( - new ReservedStateUpdateTask( - "test", - null, - List.of(), - Collections.emptyMap(), - Collections.emptySet(), - errorState -> {}, - new ActionListener<>() { - @Override - public void onResponse(ActionResponse.Empty empty) {} - - @Override - public void onFailure(Exception e) {} - } - ) + new ReservedStateUpdateTask("test", null, List.of(), Map.of(), Set.of(), errorState -> {}, ActionListener.noop()) ); doReturn(state).when(task).execute(any()); @@ -223,15 +204,7 @@ public void testErrorStateTask() throws Exception { ReservedStateErrorTask task = spy( new ReservedStateErrorTask( new ErrorState("test", 1L, List.of("some parse error", "some io error"), ReservedStateErrorMetadata.ErrorKind.PARSING), - new ActionListener<>() { - @Override - public void onResponse(ActionResponse.Empty empty) { - listenerCompleted.set(true); - } - - @Override - public void onFailure(Exception e) {} - } + ActionListener.running(() -> listenerCompleted.set(true)) ) ); @@ -276,8 +249,8 @@ public Releasable captureResponseHeaders() { ReservedStateMetadata operatorMetadata = newState.metadata().reservedStateMetadata().get("test"); assertNotNull(operatorMetadata); assertNotNull(operatorMetadata.errorMetadata()); - assertEquals(1L, (long) operatorMetadata.errorMetadata().version()); - assertEquals(ReservedStateErrorMetadata.ErrorKind.PARSING, operatorMetadata.errorMetadata().errorKind()); + assertThat(operatorMetadata.errorMetadata().version(), is(1L)); + assertThat(operatorMetadata.errorMetadata().errorKind(), is(ReservedStateErrorMetadata.ErrorKind.PARSING)); assertThat(operatorMetadata.errorMetadata().errors(), contains("some parse error", "some io error")); assertTrue(listenerCompleted.get()); } @@ -352,13 +325,7 @@ public Map<String, Object> fromXContent(XContentParser parser) throws IOExceptio Map.of(exceptionThrower.name(), exceptionThrower, newStateMaker.name(), 
newStateMaker), orderedHandlers, errorState -> assertFalse(ReservedStateErrorTask.isNewError(operatorMetadata, errorState.version())), - new ActionListener<>() { - @Override - public void onResponse(ActionResponse.Empty empty) {} - - @Override - public void onFailure(Exception e) {} - } + ActionListener.noop() ); ClusterService clusterService = mock(ClusterService.class); @@ -367,9 +334,8 @@ public void onFailure(Exception e) {} ); var trialRunResult = controller.trialRun("namespace_one", state, chunk, new LinkedHashSet<>(orderedHandlers)); - assertEquals(0, trialRunResult.nonStateTransforms().size()); - assertEquals(1, trialRunResult.errors().size()); - assertTrue(trialRunResult.errors().get(0).contains("Error processing one state change:")); + assertThat(trialRunResult.nonStateTransforms(), empty()); + assertThat(trialRunResult.errors(), contains(containsString("Error processing one state change:"))); // We exit on duplicate errors before we update the cluster state error metadata assertThat( @@ -438,7 +404,7 @@ public Map<String, Object> fromXContent(XContentParser parser) throws IOExceptio public void testHandlerOrdering() { ReservedClusterStateHandler<Map<String, Object>> oh1 = makeHandlerHelper("one", List.of("two", "three")); - ReservedClusterStateHandler<Map<String, Object>> oh2 = makeHandlerHelper("two", Collections.emptyList()); + ReservedClusterStateHandler<Map<String, Object>> oh2 = makeHandlerHelper("two", List.of()); ReservedClusterStateHandler<Map<String, Object>> oh3 = makeHandlerHelper("three", List.of("two")); ClusterService clusterService = mock(ClusterService.class); @@ -447,16 +413,16 @@ public void testHandlerOrdering() { assertThat(ordered, contains("two", "three", "one")); // assure that we bail on unknown handler - assertEquals( - "Unknown handler type: four", + assertThat( expectThrows(IllegalStateException.class, () -> controller.orderedStateHandlers(Set.of("one", "two", "three", "four"))) - .getMessage() + .getMessage(), + is("Unknown handler type: four") ); // assure that we bail on missing dependency link - assertEquals( - "Missing handler dependency definition: one -> three", - expectThrows(IllegalStateException.class, () -> controller.orderedStateHandlers(Set.of("one", "two"))).getMessage() + assertThat( + expectThrows(IllegalStateException.class, () -> controller.orderedStateHandlers(Set.of("one", "two"))).getMessage(), + is("Missing handler dependency definition: one -> three") ); // Change the second handler so that we create cycle @@ -481,7 +447,7 @@ public void testDuplicateHandlerNames() { ClusterState state = ClusterState.builder(clusterName).build(); when(clusterService.state()).thenReturn(state); - assertTrue( + assertThat( expectThrows( IllegalStateException.class, () -> new ReservedClusterStateService( @@ -489,7 +455,8 @@ public void testDuplicateHandlerNames() { mock(RerouteService.class), List.of(new ReservedClusterSettingsAction(clusterSettings), new TestHandler()) ) - ).getMessage().startsWith("Duplicate key cluster_settings") + ).getMessage(), + startsWith("Duplicate key cluster_settings") ); } @@ -506,8 +473,8 @@ public void testCheckAndReportError() { var version = new ReservedStateVersion(2L, Version.CURRENT); var error = controller.checkAndReportError("test", List.of("test error"), version); - assertThat(error, allOf(notNullValue(), instanceOf(IllegalStateException.class))); - assertEquals("Error processing state change request for test, errors: test error", error.getMessage()); + assertThat(error, instanceOf(IllegalStateException.class)); + assertThat(error.getMessage(), is("Error processing state change request for test, errors: 
test error")); verify(controller, times(1)).updateErrorState(any()); } @@ -538,7 +505,7 @@ public String name() { @Override public TransformState transform(Object source, TransformState prevState) { - return new TransformState(prevState.state(), prevState.keys(), (l) -> internalKeys(l)); + return new TransformState(prevState.state(), prevState.keys(), this::internalKeys); } private void internalKeys(ActionListener<NonStateTransformResult> listener) { @@ -577,13 +544,13 @@ public Map<String, Object> fromXContent(XContentParser parser) throws IOExceptio var trialRunResult = controller.trialRun("namespace_one", state, chunk, new LinkedHashSet<>(orderedHandlers)); - assertEquals(1, trialRunResult.nonStateTransforms().size()); - assertEquals(0, trialRunResult.errors().size()); + assertThat(trialRunResult.nonStateTransforms(), hasSize(1)); + assertThat(trialRunResult.errors(), empty()); trialRunResult.nonStateTransforms().get(0).accept(new ActionListener<>() { @Override public void onResponse(NonStateTransformResult nonStateTransformResult) { assertThat(nonStateTransformResult.updatedKeys(), containsInAnyOrder("key non-state")); - assertEquals("non-state", nonStateTransformResult.handlerName()); + assertThat(nonStateTransformResult.handlerName(), is("non-state")); } @Override @@ -610,7 +577,7 @@ public String name() { @Override public TransformState transform(Object source, TransformState prevState) { - return new TransformState(prevState.state(), prevState.keys(), (l) -> internalKeys(l)); + return new TransformState(prevState.state(), prevState.keys(), this::internalKeys); } private void internalKeys(ActionListener<NonStateTransformResult> listener) { @@ -642,10 +609,10 @@ public Map<String, Object> fromXContent(XContentParser parser) throws IOExceptio "namespace_one", state, chunk, - new LinkedHashSet<>(handlers.stream().map(h -> h.name()).toList()) + handlers.stream().map(ReservedClusterStateHandler::name).collect(Collectors.toCollection(LinkedHashSet::new)) ); - assertEquals(count, trialRunResult.nonStateTransforms().size()); + assertThat(trialRunResult.nonStateTransforms(), hasSize(count)); ReservedClusterStateService.executeNonStateTransformationSteps(trialRunResult.nonStateTransforms(), new ActionListener<>() { @Override public void onResponse(Collection<NonStateTransformResult> nonStateTransformResults) { @@ -657,12 +624,15 @@ public void onResponse(Collection<NonStateTransformResult> nonStateTransformResu expectedValues.add("key non-state:" + i); } assertThat( - nonStateTransformResults.stream().map(n -> n.handlerName()).collect(Collectors.toSet()), - containsInAnyOrder(expectedHandlers.toArray(new String[0])) + nonStateTransformResults.stream().map(NonStateTransformResult::handlerName).collect(Collectors.toSet()), + containsInAnyOrder(expectedHandlers.toArray()) ); assertThat( - nonStateTransformResults.stream().map(n -> n.updatedKeys()).flatMap(Set::stream).collect(Collectors.toSet()), - containsInAnyOrder(expectedValues.toArray(new String[0])) + nonStateTransformResults.stream() + .map(NonStateTransformResult::updatedKeys) + .flatMap(Set::stream) + .collect(Collectors.toSet()), + containsInAnyOrder(expectedValues.toArray()) ); } @@ -673,7 +643,7 @@ public void onFailure(Exception e) { }); } - class TestHandler implements ReservedClusterStateHandler<Map<String, Object>> { + static class TestHandler implements ReservedClusterStateHandler<Map<String, Object>> { @Override public String name() { diff --git a/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java b/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java index 80c93e05b8bd5..8bd53047b2dc7 100644 --- 
a/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java @@ -114,10 +114,10 @@ public void testFloatVectorClassBindings() throws IOException { ); e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName)); - assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors")); + assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors")); e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName)); - assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors")); + assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors")); // Check scripting infrastructure integration DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName); diff --git a/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/SearchResponseCountTelemetryTests.java b/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/SearchResponseCountTelemetryTests.java new file mode 100644 index 0000000000000..af2137c046235 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/TelemetryMetrics/SearchResponseCountTelemetryTests.java @@ -0,0 +1,241 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ + +package org.elasticsearch.search.TelemetryMetrics; + +import org.elasticsearch.action.search.ClearScrollResponse; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.PluginsService; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.rest.action.search.SearchResponseMetrics; +import org.elasticsearch.search.query.ThrowingQueryBuilder; +import org.elasticsearch.telemetry.Measurement; +import org.elasticsearch.telemetry.TestTelemetryPlugin; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.junit.After; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.elasticsearch.index.query.QueryBuilders.simpleQueryStringQuery; +import static org.elasticsearch.rest.action.search.SearchResponseMetrics.RESPONSE_COUNT_TOTAL_COUNTER_NAME; +import static org.elasticsearch.rest.action.search.SearchResponseMetrics.RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertScrollResponsesAndHitCount; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchHitsWithoutFailures; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +public class SearchResponseCountTelemetryTests extends ESSingleNodeTestCase { + + private static final String indexName = "test_search_response_count_metrics"; + + private TestTelemetryPlugin getTestTelemetryPlugin() { + return getInstanceFromNode(PluginsService.class).filterPlugins(TestTelemetryPlugin.class).toList().get(0); + } + + @After + private void resetMeter() { + getTestTelemetryPlugin().resetMeter(); + } + + @Override + protected boolean resetNodeAfterTest() { + return true; + } + + @Before + public void setUpIndex() throws Exception { + var numPrimaries = randomIntBetween(3, 5); + createIndex( + indexName, + Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numPrimaries) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .build() + ); + ensureGreen(indexName); + + prepareIndex(indexName).setId("1").setSource("body", "red").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("2").setSource("body", "green").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("3").setSource("body", "blue").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("4").setSource("body", "blue").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("5").setSource("body", "pink").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("6").setSource("body", "brown").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("7").setSource("body", "red").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("8").setSource("body", "purple").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("9").setSource("body", "black").setRefreshPolicy(IMMEDIATE).get(); + prepareIndex(indexName).setId("10").setSource("body", "green").setRefreshPolicy(IMMEDIATE).get(); + } 
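+
+ // A small helper sketch (hypothetical, not part of this change) showing how the tests in this
+ // class read back the response-count counter recorded through TestTelemetryPlugin; each
+ // Measurement carries the incremented value plus a status attribute.
+ private long totalResponseCount() {
+     return getTestTelemetryPlugin().getLongCounterMeasurement(RESPONSE_COUNT_TOTAL_COUNTER_NAME)
+         .stream()
+         .mapToLong(Measurement::getLong)
+         .sum();
+ }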
+ + @Override + protected Collection<Class<? extends Plugin>> getPlugins() { + return pluginList(TestTelemetryPlugin.class, TestQueryBuilderPlugin.class); + } + + public static class TestQueryBuilderPlugin extends Plugin implements SearchPlugin { + public TestQueryBuilderPlugin() {} + + @Override + public List<QuerySpec<?>> getQueries() { + QuerySpec<ThrowingQueryBuilder> throwingSpec = new QuerySpec<>(ThrowingQueryBuilder.NAME, ThrowingQueryBuilder::new, p -> { + throw new IllegalStateException("not implemented"); + }); + + return List.of(throwingSpec); + } + } + + public void testSimpleQuery() throws Exception { + assertSearchHitsWithoutFailures(client().prepareSearch(indexName).setQuery(simpleQueryStringQuery("green")), "2", "10"); + assertBusy(() -> { + List<Measurement> measurements = getTestTelemetryPlugin().getLongCounterMeasurement(RESPONSE_COUNT_TOTAL_COUNTER_NAME); + assertThat(measurements.size(), equalTo(1)); + assertThat(measurements.get(0).getLong(), equalTo(1L)); + assertThat( + measurements.get(0).attributes().get(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME), + equalTo(SearchResponseMetrics.ResponseCountTotalStatus.SUCCESS.getDisplayName()) + ); + }); + } + + public void testSearchWithSingleShardFailure() throws Exception { + ThrowingQueryBuilder queryBuilder = new ThrowingQueryBuilder(randomLong(), new IllegalStateException("something bad"), 0); + SearchResponse searchResponse = client().prepareSearch(indexName).setQuery(queryBuilder).get(); + try { + assertThat(searchResponse.getFailedShards(), equalTo(1)); + assertBusy(() -> { + List<Measurement> measurements = getTestTelemetryPlugin().getLongCounterMeasurement(RESPONSE_COUNT_TOTAL_COUNTER_NAME); + assertThat(measurements.size(), equalTo(1)); + assertThat(measurements.get(0).getLong(), equalTo(1L)); + assertThat( + measurements.get(0).attributes().get(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME), + equalTo(SearchResponseMetrics.ResponseCountTotalStatus.PARTIAL_FAILURE.getDisplayName()) + ); + }); + } finally { + searchResponse.decRef(); + } + } + + public void testSearchWithAllShardsFail() throws Exception { + ThrowingQueryBuilder queryBuilder = new ThrowingQueryBuilder(randomLong(), new IllegalStateException("something bad"), indexName); + SearchPhaseExecutionException exception = expectThrows( + SearchPhaseExecutionException.class, + client().prepareSearch(indexName).setQuery(queryBuilder) + ); + assertThat(exception.getCause().getMessage(), containsString("something bad")); + assertBusy(() -> { + List<Measurement> measurements = getTestTelemetryPlugin().getLongCounterMeasurement(RESPONSE_COUNT_TOTAL_COUNTER_NAME); + assertThat(measurements.size(), equalTo(1)); + assertThat(measurements.get(0).getLong(), equalTo(1L)); + assertThat( + measurements.get(0).attributes().get(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME), + equalTo(SearchResponseMetrics.ResponseCountTotalStatus.FAILURE.getDisplayName()) + ); + }); + } + + public void testScroll() { + assertScrollResponsesAndHitCount( + client(), + TimeValue.timeValueSeconds(60), + client().prepareSearch(indexName).setSize(1).setQuery(simpleQueryStringQuery("green")), + 2, + (respNum, response) -> { + if (respNum <= 2) { + try { + assertBusy(() -> { + List<Measurement> measurements = getTestTelemetryPlugin().getLongCounterMeasurement( + RESPONSE_COUNT_TOTAL_COUNTER_NAME + ); + assertThat(measurements.size(), equalTo(1)); + assertThat(measurements.get(0).getLong(), equalTo(1L)); + assertThat( + measurements.get(0).attributes().get(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME), + equalTo(SearchResponseMetrics.ResponseCountTotalStatus.SUCCESS.getDisplayName()) + ); + }); + } catch (Exception 
e) { + throw new RuntimeException(e); + } + } + resetMeter(); + } + ); + } + + public void testScrollWithSingleShardFailure() throws Exception { + ThrowingQueryBuilder queryBuilder = new ThrowingQueryBuilder(randomLong(), new IllegalStateException("something bad"), 0); + SearchRequestBuilder searchRequestBuilder = client().prepareSearch(indexName).setSize(1).setQuery(queryBuilder); + TimeValue keepAlive = TimeValue.timeValueSeconds(60); + searchRequestBuilder.setScroll(keepAlive); + List<SearchResponse> responses = new ArrayList<>(); + var scrollResponse = searchRequestBuilder.get(); + responses.add(scrollResponse); + try { + assertBusy(() -> { + List<Measurement> measurements = getTestTelemetryPlugin().getLongCounterMeasurement(RESPONSE_COUNT_TOTAL_COUNTER_NAME); + assertThat(measurements.size(), equalTo(1)); + assertThat(measurements.get(0).getLong(), equalTo(1L)); + assertThat( + measurements.get(0).attributes().get(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME), + equalTo(SearchResponseMetrics.ResponseCountTotalStatus.PARTIAL_FAILURE.getDisplayName()) + ); + }); + int numResponses = 1; + while (scrollResponse.getHits().getHits().length > 0) { + scrollResponse = client().prepareSearchScroll(scrollResponse.getScrollId()).setScroll(keepAlive).get(); + int expectedNumMeasurements = ++numResponses; + responses.add(scrollResponse); + assertBusy(() -> { + List<Measurement> measurements = getTestTelemetryPlugin().getLongCounterMeasurement(RESPONSE_COUNT_TOTAL_COUNTER_NAME); + // verify that one additional measurement was recorded (in TransportScrollSearchAction) + assertThat(measurements.size(), equalTo(expectedNumMeasurements)); + // verify that zero shards failed in secondary scroll search rounds + assertThat(measurements.get(expectedNumMeasurements - 1).getLong(), equalTo(1L)); + assertThat( + measurements.get(expectedNumMeasurements - 1).attributes().get(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME), + equalTo(SearchResponseMetrics.ResponseCountTotalStatus.SUCCESS.getDisplayName()) + ); + }); + } + } finally { + ClearScrollResponse clear = client().prepareClearScroll().setScrollIds(Arrays.asList(scrollResponse.getScrollId())).get(); + responses.forEach(SearchResponse::decRef); + assertThat(clear.isSucceeded(), equalTo(true)); + } + } + + public void testScrollWithAllShardsFail() throws Exception { + ThrowingQueryBuilder queryBuilder = new ThrowingQueryBuilder(randomLong(), new IllegalStateException("something bad"), indexName); + SearchPhaseExecutionException exception = expectThrows( + SearchPhaseExecutionException.class, + client().prepareSearch(indexName).setSize(1).setQuery(queryBuilder).setScroll(TimeValue.timeValueSeconds(60)) + ); + assertThat(exception.getCause().getMessage(), containsString("something bad")); + assertBusy(() -> { + List<Measurement> measurements = getTestTelemetryPlugin().getLongCounterMeasurement(RESPONSE_COUNT_TOTAL_COUNTER_NAME); + assertThat(measurements.size(), equalTo(1)); + assertThat(measurements.get(0).getLong(), equalTo(1L)); + assertThat( + measurements.get(0).attributes().get(RESPONSE_COUNT_TOTAL_STATUS_ATTRIBUTE_NAME), + equalTo(SearchResponseMetrics.ResponseCountTotalStatus.FAILURE.getDisplayName()) + ); + }); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index b760262cd1ea6..fdd9b94cb5050 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ 
b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -122,7 +122,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery; // The field should always be resolved to the concrete field Query knnVectorQueryBuilt = switch (elementType()) { - case BYTE -> new ESKnnByteVectorQuery( + case BYTE, BIT -> new ESKnnByteVectorQuery( VECTOR_FIELD, queryBuilder.queryVector().asByteVector(), queryBuilder.numCands(), @@ -145,7 +145,10 @@ public void testWrongDimension() { SearchExecutionContext context = createSearchExecutionContext(); KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); - assertThat(e.getMessage(), containsString("the query vector has a different dimension [2] than the index vectors [3]")); + assertThat( + e.getMessage(), + containsString("The query vector has a different number of dimensions [2] than the document vectors [3]") + ); } public void testNonexistentField() { diff --git a/test/external-modules/jvm-crash/src/javaRestTest/java/org/elasticsearch/test/jvm_crash/JvmCrashIT.java b/test/external-modules/jvm-crash/src/javaRestTest/java/org/elasticsearch/test/jvm_crash/JvmCrashIT.java index 3e73310ee824f..517cb5b65a529 100644 --- a/test/external-modules/jvm-crash/src/javaRestTest/java/org/elasticsearch/test/jvm_crash/JvmCrashIT.java +++ b/test/external-modules/jvm-crash/src/javaRestTest/java/org/elasticsearch/test/jvm_crash/JvmCrashIT.java @@ -22,9 +22,11 @@ import org.elasticsearch.test.cluster.local.distribution.LocalDistributionResolver; import org.elasticsearch.test.cluster.local.distribution.ReleasedDistributionResolver; import org.elasticsearch.test.cluster.local.distribution.SnapshotDistributionResolver; +import org.elasticsearch.test.cluster.util.OS; import org.elasticsearch.test.rest.ESRestTestCase; import org.hamcrest.Matcher; import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.ClassRule; import java.io.BufferedReader; @@ -46,6 +48,11 @@ public class JvmCrashIT extends ESRestTestCase { + @BeforeClass + public static void dontRunWindows() { + assumeFalse("JVM crash log doesn't go to stdout on windows", OS.current() == OS.WINDOWS); + } + private static class StdOutCatchingClusterBuilder extends AbstractLocalClusterSpecBuilder { private StdOutCatchingClusterBuilder() { diff --git a/test/framework/src/main/java/org/elasticsearch/action/support/replication/ClusterStateCreationUtils.java b/test/framework/src/main/java/org/elasticsearch/action/support/replication/ClusterStateCreationUtils.java index 0f60ba9731966..5b656598451a3 100644 --- a/test/framework/src/main/java/org/elasticsearch/action/support/replication/ClusterStateCreationUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/action/support/replication/ClusterStateCreationUtils.java @@ -9,6 +9,7 @@ package org.elasticsearch.action.support.replication; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; @@ -141,13 +142,17 @@ public static ClusterState state( discoBuilder.localNodeId(newNode(0).getId()); discoBuilder.masterNodeId(newNode(1).getId()); // we need a non-local master to test shard 
failures final int primaryTerm = 1 + randomInt(200); + IndexLongFieldRange timeFieldRange = primaryState == ShardRoutingState.STARTED || primaryState == ShardRoutingState.RELOCATING + ? IndexLongFieldRange.UNKNOWN + : IndexLongFieldRange.NO_SHARDS; + IndexMetadata indexMetadata = IndexMetadata.builder(index) .settings(indexSettings(IndexVersion.current(), 1, numberOfReplicas).put(SETTING_CREATION_DATE, System.currentTimeMillis())) .primaryTerm(0, primaryTerm) - .timestampRange( - primaryState == ShardRoutingState.STARTED || primaryState == ShardRoutingState.RELOCATING - ? IndexLongFieldRange.UNKNOWN - : IndexLongFieldRange.NO_SHARDS + .timestampRange(timeFieldRange) + .eventIngestedRange( + timeFieldRange, + timeFieldRange == IndexLongFieldRange.UNKNOWN ? null : TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE ) .build(); @@ -281,6 +286,10 @@ public static ClusterState state(final int numberOfNodes, final String[] indices .settings( indexSettings(IndexVersion.current(), numberOfPrimaries, 0).put(SETTING_CREATION_DATE, System.currentTimeMillis()) ) + .eventIngestedRange( + IndexLongFieldRange.UNKNOWN, + randomFrom(TransportVersions.V_8_0_0, TransportVersions.EVENT_INGESTED_RANGE_IN_CLUSTER_STATE) + ) .build(); IndexRoutingTable.Builder indexRoutingTable = IndexRoutingTable.builder(indexMetadata.getIndex()); @@ -386,6 +395,7 @@ public static ClusterState stateWithAssignedPrimariesAndReplicas( ) ) .timestampRange(IndexLongFieldRange.UNKNOWN) + .eventIngestedRange(IndexLongFieldRange.UNKNOWN, null) .build(); metadataBuilder.put(indexMetadata, false).generateClusterUuidIfNeeded(); IndexRoutingTable.Builder indexRoutingTableBuilder = IndexRoutingTable.builder(indexMetadata.getIndex()); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/KeywordFieldSyntheticSourceSupport.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/KeywordFieldSyntheticSourceSupport.java index 53ecb75c18d9a..6abe923851318 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/KeywordFieldSyntheticSourceSupport.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/KeywordFieldSyntheticSourceSupport.java @@ -17,7 +17,9 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -27,15 +29,20 @@ public class KeywordFieldSyntheticSourceSupport implements MapperTestCase.Synthe private final boolean store; private final boolean docValues; private final String nullValue; - private final boolean exampleSortsUsingIgnoreAbove; - KeywordFieldSyntheticSourceSupport(Integer ignoreAbove, boolean store, String nullValue, boolean exampleSortsUsingIgnoreAbove) { + KeywordFieldSyntheticSourceSupport(Integer ignoreAbove, boolean store, String nullValue, boolean useFallbackSyntheticSource) { this.ignoreAbove = ignoreAbove; this.allIgnored = ignoreAbove != null && LuceneTestCase.rarely(); this.store = store; this.nullValue = nullValue; - this.exampleSortsUsingIgnoreAbove = exampleSortsUsingIgnoreAbove; - this.docValues = store ? ESTestCase.randomBoolean() : true; + this.docValues = useFallbackSyntheticSource == false || ESTestCase.randomBoolean(); + } + + @Override + public boolean preservesExactSource() { + // We opt into the fallback synthetic source implementation + // if there is nothing else to use, and it preserves exact source data. 
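+ // For example, a keyword field mapped with "store": false and "doc_values": false leaves
+ // nothing else to rebuild the value from, so the fallback keeps the raw input verbatim.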
+ return store == false && docValues == false; } @Override @@ -46,36 +53,48 @@ public MapperTestCase.SyntheticSourceExample example(int maxValues) { public MapperTestCase.SyntheticSourceExample example(int maxValues, boolean loadBlockFromSource) { if (ESTestCase.randomBoolean()) { Tuple v = generateValue(); + Object sourceValue = preservesExactSource() ? v.v1() : v.v2(); Object loadBlock = v.v2(); if (loadBlockFromSource == false && ignoreAbove != null && v.v2().length() > ignoreAbove) { loadBlock = null; } - return new MapperTestCase.SyntheticSourceExample(v.v1(), v.v2(), loadBlock, this::mapping); + return new MapperTestCase.SyntheticSourceExample(v.v1(), sourceValue, loadBlock, this::mapping); } List> values = ESTestCase.randomList(1, maxValues, this::generateValue); List in = values.stream().map(Tuple::v1).toList(); - List outPrimary = new ArrayList<>(); - List outExtraValues = new ArrayList<>(); + + List validValues = new ArrayList<>(); + List ignoredValues = new ArrayList<>(); values.stream().map(Tuple::v2).forEach(v -> { - if (exampleSortsUsingIgnoreAbove && ignoreAbove != null && v.length() > ignoreAbove) { - outExtraValues.add(v); + if (ignoreAbove != null && v.length() > ignoreAbove) { + ignoredValues.add(v); } else { - outPrimary.add(v); + validValues.add(v); } }); - List outList = store ? outPrimary : new HashSet<>(outPrimary).stream().sorted().collect(Collectors.toList()); + List outputFromDocValues = new HashSet<>(validValues).stream().sorted().collect(Collectors.toList()); + + Object out; + if (preservesExactSource()) { + out = in; + } else { + var validValuesInCorrectOrder = store ? validValues : outputFromDocValues; + var syntheticSourceOutputList = Stream.concat(validValuesInCorrectOrder.stream(), ignoredValues.stream()).toList(); + out = syntheticSourceOutputList.size() == 1 ? syntheticSourceOutputList.get(0) : syntheticSourceOutputList; + } + List loadBlock; if (loadBlockFromSource) { // The block loader infrastructure will never return nulls. Just zap them all. - loadBlock = in.stream().filter(m -> m != null).toList(); + loadBlock = in.stream().filter(Objects::nonNull).toList(); } else if (docValues) { - loadBlock = new HashSet<>(outPrimary).stream().sorted().collect(Collectors.toList()); + loadBlock = List.copyOf(outputFromDocValues); } else { - loadBlock = List.copyOf(outList); + // Meaning loading from terms. + loadBlock = List.copyOf(validValues); } + Object loadBlockResult = loadBlock.size() == 1 ? loadBlock.get(0) : loadBlock; - outList.addAll(outExtraValues); - Object out = outList.size() == 1 ? 
outList.get(0) : outList; return new MapperTestCase.SyntheticSourceExample(in, out, loadBlockResult, this::mapping); } @@ -110,13 +129,6 @@ private void mapping(XContentBuilder b) throws IOException { @Override public List<MapperTestCase.SyntheticSourceInvalidExample> invalidExample() throws IOException { return List.of( - new MapperTestCase.SyntheticSourceInvalidExample( - equalTo( - "field [field] of type [keyword] doesn't support synthetic source because " - + "it doesn't have doc values and isn't stored" - ), - b -> b.field("type", "keyword").field("doc_values", false) - ), new MapperTestCase.SyntheticSourceInvalidExample( equalTo("field [field] of type [keyword] doesn't support synthetic source because it declares a normalizer"), b -> b.field("type", "keyword").field("normalizer", "lowercase") diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapperTests.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapperTests.java index c60a913a63b33..ec2bbc35a68b1 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapperTests.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapperTests.java @@ -36,7 +36,6 @@ import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.matchesPattern; import static org.hamcrest.Matchers.notANumber; public abstract class NumberFieldMapperTests extends MapperTestCase { @@ -376,6 +375,14 @@ public void testAllowMultipleValuesField() throws IOException { assertThat(e.getCause().getMessage(), containsString("Only one field can be stored per key")); } + @Override + protected BlockReaderSupport getSupportedReaders(MapperService mapper, String loaderFieldName) { + MappedFieldType ft = mapper.fieldType(loaderFieldName); + // Block loader can either use doc values or source. + // So with synthetic source it only works when doc values are enabled. + return new BlockReaderSupport(ft.hasDocValues(), ft.hasDocValues(), mapper, loaderFieldName); + } + @Override protected Function<Object, Object> loadBlockExpected() { return n -> ((Number) n); // Just assert it's a number @@ -391,6 +398,7 @@ protected Matcher<?> blockItemMatcher(Object expected) { protected final class NumberSyntheticSourceSupport implements SyntheticSourceSupport { private final Long nullValue = usually() ? null : randomNumber().longValue(); private final boolean coerce = rarely(); + private final boolean docValues = randomBoolean(); private final Function<Number, Number> round; private final boolean ignoreMalformed; @@ -400,10 +408,26 @@ protected NumberSyntheticSourceSupport(Function<Number, Number> round, boolean i this.ignoreMalformed = ignoreMalformed; } + @Override + public boolean preservesExactSource() { + // We opt into fallback synthetic source if there are no doc values, + // which preserves the exact source. + return docValues == false; + } + @Override public SyntheticSourceExample example(int maxVals) { if (randomBoolean()) { Tuple<Object, Object> v = generateValue(); + if (preservesExactSource()) { + var rawInput = v.v1(); + + // This code actually runs with synthetic source disabled + // to test block loader loading from source. + // That's why we need to set expected block loader value here. + var blockLoaderResult = v.v2() instanceof Number n ?
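// At this exact step the expected block-loader value is derived from the raw input:
// numeric values pick up the same rounding the mapper applies, while non-numeric
// (malformed) values are expected to load as null.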
round.apply(n) : null; + return new SyntheticSourceExample(rawInput, rawInput, blockLoaderResult, this::mapping); + } if (v.v2() instanceof Number n) { Number result = round.apply(n); return new SyntheticSourceExample(v.v1(), result, result, this::mapping); } @@ -413,19 +437,33 @@ public SyntheticSourceExample example(int maxVals) { } List<Tuple<Object, Object>> values = randomList(1, maxVals, this::generateValue); List<Object> in = values.stream().map(Tuple::v1).toList(); - List<Object> outList = values.stream() .filter(v -> v.v2() instanceof Number) .map(t -> round.apply((Number) t.v2())) .sorted() .collect(Collectors.toCollection(ArrayList::new)); - values.stream().filter(v -> false == v.v2() instanceof Number).map(v -> v.v2()).forEach(outList::add); - Object out = outList.size() == 1 ? outList.get(0) : outList; - - List<Object> outBlockList = values.stream() .filter(v -> v.v2() instanceof Number) .map(t -> round.apply((Number) t.v2())) .sorted() .collect(Collectors.toCollection(ArrayList::new)); + Object out; + List<Object> outBlockList; + if (preservesExactSource()) { + // This code actually runs with synthetic source disabled + // to test block loader loading from source. + // That's why we need to set expected block loader value here. + out = in; + outBlockList = values.stream() + .filter(v -> v.v2() instanceof Number) + .map(t -> round.apply((Number) t.v2())) + .collect(Collectors.toCollection(ArrayList::new)); + } else { + List<Object> outList = values.stream() + .filter(v -> v.v2() instanceof Number) + .map(t -> round.apply((Number) t.v2())) + .sorted() + .collect(Collectors.toCollection(ArrayList::new)); + values.stream().filter(v -> false == v.v2() instanceof Number).map(Tuple::v2).forEach(outList::add); + out = outList.size() == 1 ? outList.get(0) : outList; + + outBlockList = values.stream() + .filter(v -> v.v2() instanceof Number) + .map(t -> round.apply((Number) t.v2())) + .sorted() + .collect(Collectors.toCollection(ArrayList::new)); + } + Object outBlock = outBlockList.size() == 1 ?
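// Note the asymmetry built above: when the example is loaded from _source
// (preservesExactSource()), the expected block keeps the documents' original value order
// (no sorted() call), while the doc-values branch expects the values back in sorted order.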
outBlockList.get(0) : outBlockList; return new SyntheticSourceExample(in, out, outBlock, this::mapping); } @@ -459,19 +497,14 @@ private void mapping(XContentBuilder b) throws IOException { if (ignoreMalformed) { b.field("ignore_malformed", true); } + if (docValues == false) { + b.field("doc_values", "false"); + } } @Override public List invalidExample() throws IOException { - return List.of( - new SyntheticSourceInvalidExample( - matchesPattern("field \\[field] of type \\[.+] doesn't support synthetic source because it doesn't have doc values"), - b -> { - minimalMapping(b); - b.field("doc_values", false); - } - ) - ); + return List.of(); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/TextFieldFamilySyntheticSourceTestSetup.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/TextFieldFamilySyntheticSourceTestSetup.java index df4377adc3e35..953d71b9a791b 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/TextFieldFamilySyntheticSourceTestSetup.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/TextFieldFamilySyntheticSourceTestSetup.java @@ -10,6 +10,9 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.core.CheckedConsumer; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; import org.hamcrest.Matcher; import java.io.IOException; @@ -19,6 +22,7 @@ import static org.elasticsearch.test.ESTestCase.between; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; +import static org.elasticsearch.test.ESTestCase.randomAlphaOfLengthBetween; import static org.elasticsearch.test.ESTestCase.randomBoolean; import static org.hamcrest.Matchers.equalTo; @@ -78,48 +82,50 @@ public static void validateRoundTripReader(String syntheticSource, DirectoryRead private static class TextFieldFamilySyntheticSourceSupport implements MapperTestCase.SyntheticSourceSupport { private final String fieldType; - private final boolean storeTextField; - private final boolean storedKeywordField; - private final boolean indexText; + private final boolean store; + private final boolean index; private final Integer ignoreAbove; - private final KeywordFieldSyntheticSourceSupport keywordSupport; + private final KeywordFieldSyntheticSourceSupport keywordMultiFieldSyntheticSourceSupport; TextFieldFamilySyntheticSourceSupport(String fieldType, boolean supportsCustomIndexConfiguration) { this.fieldType = fieldType; - this.storeTextField = randomBoolean(); - this.storedKeywordField = storeTextField || randomBoolean(); - this.indexText = supportsCustomIndexConfiguration ? randomBoolean() : true; + this.store = randomBoolean(); + this.index = supportsCustomIndexConfiguration == false || randomBoolean(); this.ignoreAbove = randomBoolean() ? 
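// The now-empty invalidExample() above reflects the behavior change: with the fallback
// synthetic source implementation, a number field without doc values is no longer an
// invalid synthetic-source mapping, so the old matchesPattern expectation is gone.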
null : between(10, 100); - this.keywordSupport = new KeywordFieldSyntheticSourceSupport(ignoreAbove, storedKeywordField, null, false == storeTextField); + this.keywordMultiFieldSyntheticSourceSupport = new KeywordFieldSyntheticSourceSupport( + ignoreAbove, + randomBoolean(), + null, + false + ); } @Override public MapperTestCase.SyntheticSourceExample example(int maxValues) { - if (storeTextField) { - MapperTestCase.SyntheticSourceExample delegate = keywordSupport.example(maxValues, true); - return new MapperTestCase.SyntheticSourceExample( - delegate.inputValue(), - delegate.expectedForSyntheticSource(), - delegate.expectedForBlockLoader(), - b -> { - b.field("type", fieldType); - b.field("store", true); - if (indexText == false) { - b.field("index", false); - } + if (store) { + CheckedConsumer mapping = b -> { + b.field("type", fieldType); + b.field("store", true); + if (index == false) { + b.field("index", false); } - ); + }; + + return storedFieldExample(maxValues, mapping); } - // We'll load from _source if ignore_above is defined, otherwise we load from the keyword field. + + // Block loader will not use keyword multi-field if it has ignore_above configured. + // And in this case it will use values from source. boolean loadingFromSource = ignoreAbove != null; - MapperTestCase.SyntheticSourceExample delegate = keywordSupport.example(maxValues, loadingFromSource); + MapperTestCase.SyntheticSourceExample delegate = keywordMultiFieldSyntheticSourceSupport.example(maxValues, loadingFromSource); + return new MapperTestCase.SyntheticSourceExample( delegate.inputValue(), delegate.expectedForSyntheticSource(), delegate.expectedForBlockLoader(), b -> { b.field("type", fieldType); - if (indexText == false) { + if (index == false) { b.field("index", false); } b.startObject("fields"); @@ -133,6 +139,25 @@ public MapperTestCase.SyntheticSourceExample example(int maxValues) { ); } + private MapperTestCase.SyntheticSourceExample storedFieldExample( + int maxValues, + CheckedConsumer mapping + ) { + if (randomBoolean()) { + var randomString = randomString(); + return new MapperTestCase.SyntheticSourceExample(randomString, randomString, randomString, mapping); + } + + var list = ESTestCase.randomList(1, maxValues, this::randomString); + var output = list.size() == 1 ? 
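// The ternary split across this hunk boundary is the singleton-unwrapping convention used
// throughout these supports: synthetic source renders a one-element list as a bare scalar.
// An equivalent standalone helper, with illustrative naming:
static Object unwrapSingleton(java.util.List<?> values) {
    return values.size() == 1 ? values.get(0) : values;
}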
list.get(0) : list; + + return new MapperTestCase.SyntheticSourceExample(list, output, output, mapping); + } + + private String randomString() { + return randomAlphaOfLengthBetween(0, 10); + } + @Override public List invalidExample() throws IOException { Matcher err = equalTo( diff --git a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java index 442a8c3b82dc6..0488614f04dfb 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java @@ -135,7 +135,11 @@ public abstract class IndexShardTestCase extends ESTestCase { protected static final PeerRecoveryTargetService.RecoveryListener recoveryListener = new PeerRecoveryTargetService.RecoveryListener() { @Override - public void onRecoveryDone(RecoveryState state, ShardLongFieldRange timestampMillisFieldRange) { + public void onRecoveryDone( + RecoveryState state, + ShardLongFieldRange timestampMillisFieldRange, + ShardLongFieldRange eventIngestedMillisFieldRange + ) { } diff --git a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java index 50e723ebd49d2..a80996359c52f 100644 --- a/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java @@ -424,6 +424,10 @@ public ShardLongFieldRange getTimestampRange() { return ShardLongFieldRange.EMPTY; } + @Override + public ShardLongFieldRange getEventIngestedRange() { + return ShardLongFieldRange.EMPTY; + } } public static void awaitIndexShardCloseAsyncTasks(IndicesClusterStateService indicesClusterStateService) { diff --git a/test/framework/src/main/java/org/elasticsearch/readiness/MockReadinessService.java b/test/framework/src/main/java/org/elasticsearch/readiness/MockReadinessService.java index e5841071a787b..e4ec541256f31 100644 --- a/test/framework/src/main/java/org/elasticsearch/readiness/MockReadinessService.java +++ b/test/framework/src/main/java/org/elasticsearch/readiness/MockReadinessService.java @@ -29,9 +29,9 @@ public class MockReadinessService extends ReadinessService { */ public static class TestPlugin extends Plugin {} - private static final int RETRIES = 3; + private static final int RETRIES = 30; - private static final int RETRY_DELAY_IN_MILLIS = 10; + private static final int RETRY_DELAY_IN_MILLIS = 100; private static final String METHOD_NOT_MOCKED = "This method has not been mocked"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/xcontent/WatcherXContentParser.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/xcontent/WatcherXContentParser.java index 96fa4de6c0d9b..3270a839778fe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/xcontent/WatcherXContentParser.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/xcontent/WatcherXContentParser.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.time.Clock; import java.time.ZonedDateTime; +import java.util.Arrays; /** * A xcontent parser that is used by watcher. 
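// On the MockReadinessService constants above: the maximum retry budget grows from
// 3 x 10 ms = 30 ms to 30 x 100 ms = 3 s (assuming the delay applies between each
// attempt), giving slow CI machines far longer to report readiness.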
This is a special parser that is @@ -50,7 +51,9 @@ public static Secret secretOrNull(XContentParser parser) throws IOException { throw new ElasticsearchParseException("found redacted password in field [{}]", parser.currentName()); } } else if (watcherParser.cryptoService != null) { - return new Secret(watcherParser.cryptoService.encrypt(chars)); + char[] encryptedChars = watcherParser.cryptoService.encrypt(chars); + Arrays.fill(chars, '\0'); // Clear chars from unencrypted buffer + return new Secret(encryptedChars); } } diff --git a/x-pack/plugin/core/template-resources/src/main/resources/profiling/component-template/profiling-hosts.json b/x-pack/plugin/core/template-resources/src/main/resources/profiling/component-template/profiling-hosts.json index 3d5e5d0fdc9b7..d9b92f5cd4f0c 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/profiling/component-template/profiling-hosts.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/profiling/component-template/profiling-hosts.json @@ -76,6 +76,9 @@ "type": "date", "format": "epoch_millis" }, + "protocol": { + "type": "keyword" + }, "config.bpf_log_level": { "type": "long" }, diff --git a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DimensionFieldValueFetcher.java b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DimensionFieldValueFetcher.java index c6ef43cfdacfa..342b6e57c9e51 100644 --- a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DimensionFieldValueFetcher.java +++ b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/DimensionFieldValueFetcher.java @@ -9,6 +9,7 @@ import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.flattened.FlattenedFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; import java.util.ArrayList; @@ -19,13 +20,12 @@ public class DimensionFieldValueFetcher extends FieldValueFetcher { private final DimensionFieldProducer dimensionFieldProducer = createFieldProducer(); - protected DimensionFieldValueFetcher(final MappedFieldType fieldType, final IndexFieldData fieldData) { - super(fieldType.name(), fieldType, fieldData); + protected DimensionFieldValueFetcher(final String fieldName, final MappedFieldType fieldType, final IndexFieldData fieldData) { + super(fieldName, fieldType, fieldData); } private DimensionFieldProducer createFieldProducer() { - final String filedName = fieldType.name(); - return new DimensionFieldProducer(filedName, new DimensionFieldProducer.Dimension(filedName)); + return new DimensionFieldProducer(name, new DimensionFieldProducer.Dimension(name)); } @Override @@ -42,12 +42,18 @@ static List create(final SearchExecutionContext context, fina MappedFieldType fieldType = context.getFieldType(dimension); assert fieldType != null : "Unknown dimension field type for dimension field: [" + dimension + "]"; - if (context.fieldExistsInIndex(dimension)) { + if (context.fieldExistsInIndex(fieldType.name())) { final IndexFieldData fieldData = context.getForField(fieldType, MappedFieldType.FielddataOperation.SEARCH); - final String fieldName = context.isMultiField(dimension) - ? 
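// The scrubbing pattern introduced in WatcherXContentParser above, shown in isolation: once
// the plaintext has been encrypted, the mutable char[] buffer is zeroed so the secret does
// not linger on the heap until garbage collection. CryptoService is a stand-in interface
// here, not the production class.
import java.util.Arrays;

class SecretScrubbingSketch {
    interface CryptoService {
        char[] encrypt(char[] chars);
    }

    static char[] encryptAndScrub(CryptoService cryptoService, char[] plaintext) {
        char[] encrypted = cryptoService.encrypt(plaintext);
        Arrays.fill(plaintext, '\0'); // best effort: clears this buffer, not any earlier copies
        return encrypted;
    }
}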
fieldType.name().substring(0, fieldType.name().lastIndexOf('.')) - : fieldType.name(); - fetchers.add(new DimensionFieldValueFetcher(fieldType, fieldData)); + if (fieldType instanceof FlattenedFieldMapper.KeyedFlattenedFieldType flattenedFieldType) { + // Name of the field type and name of the dimension are different in this case. + var dimensionName = flattenedFieldType.rootName() + '.' + flattenedFieldType.key(); + fetchers.add(new DimensionFieldValueFetcher(dimensionName, fieldType, fieldData)); + } else { + final String fieldName = context.isMultiField(dimension) + ? fieldType.name().substring(0, fieldType.name().lastIndexOf('.')) + : fieldType.name(); + fetchers.add(new DimensionFieldValueFetcher(fieldName, fieldType, fieldData)); + } } } return Collections.unmodifiableList(fetchers); diff --git a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TimeseriesFieldTypeHelper.java b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TimeseriesFieldTypeHelper.java index 691279187e1a9..e539722481df8 100644 --- a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TimeseriesFieldTypeHelper.java +++ b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TimeseriesFieldTypeHelper.java @@ -11,6 +11,7 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.mapper.TimeSeriesParams; +import org.elasticsearch.index.mapper.flattened.FlattenedFieldMapper; import java.io.IOException; import java.util.List; @@ -49,6 +50,19 @@ public boolean isTimeSeriesDimension(final String unused, final Map f return Boolean.TRUE.equals(fieldMapping.get(TIME_SERIES_DIMENSION_PARAM)); } + public List extractFlattenedDimensions(final String field, final Map fieldMapping) { + var mapper = mapperService.mappingLookup().getMapper(field); + if (mapper instanceof FlattenedFieldMapper == false) { + return null; + } + Object dimensions = fieldMapping.get(FlattenedFieldMapper.TIME_SERIES_DIMENSIONS_ARRAY_PARAM); + if (dimensions instanceof List actualList) { + return actualList.stream().map(field_in_flattened -> field + '.' 
+ field_in_flattened).toList(); + } + + return null; + } + static class Builder { private final MapperService mapperService; diff --git a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TransportDownsampleAction.java b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TransportDownsampleAction.java index e370ab5383fd5..66511f2cc15f0 100644 --- a/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TransportDownsampleAction.java +++ b/x-pack/plugin/downsample/src/main/java/org/elasticsearch/xpack/downsample/TransportDownsampleAction.java @@ -310,7 +310,10 @@ protected void masterOperation( request.getDownsampleConfig().getTimestampField() ); MappingVisitor.visitMapping(sourceIndexMappings, (field, mapping) -> { - if (helper.isTimeSeriesDimension(field, mapping)) { + var flattenedDimensions = helper.extractFlattenedDimensions(field, mapping); + if (flattenedDimensions != null) { + dimensionFields.addAll(flattenedDimensions); + } else if (helper.isTimeSeriesDimension(field, mapping)) { dimensionFields.add(field); } else if (helper.isTimeSeriesMetric(field, mapping)) { metricFields.add(field); diff --git a/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleActionSingleNodeTests.java b/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleActionSingleNodeTests.java index 80bb0368a1afc..5012bacf319b6 100644 --- a/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleActionSingleNodeTests.java +++ b/x-pack/plugin/downsample/src/test/java/org/elasticsearch/xpack/downsample/DownsampleActionSingleNodeTests.java @@ -70,7 +70,6 @@ import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.InternalTopHits; -import org.elasticsearch.search.aggregations.metrics.Max; import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder; @@ -134,6 +133,8 @@ public class DownsampleActionSingleNodeTests extends ESSingleNodeTestCase { public static final String FIELD_TIMESTAMP = "@timestamp"; public static final String FIELD_DIMENSION_1 = "dimension_kw"; public static final String FIELD_DIMENSION_2 = "dimension_long"; + public static final String FIELD_DIMENSION_3 = "dimension_flattened"; + public static final String FIELD_DIMENSION_4 = "dimension_kw_multifield"; public static final String FIELD_NUMERIC_1 = "numeric_1"; public static final String FIELD_NUMERIC_2 = "numeric_2"; public static final String FIELD_AGG_METRIC = "agg_metric_1"; @@ -212,6 +213,19 @@ public void setup() throws IOException { // Dimensions mapping.startObject(FIELD_DIMENSION_1).field("type", "keyword").field("time_series_dimension", true).endObject(); mapping.startObject(FIELD_DIMENSION_2).field("type", "long").field("time_series_dimension", true).endObject(); + mapping.startObject(FIELD_DIMENSION_3) + .field("type", "flattened") + .array("time_series_dimensions", "level1_value", "level1_obj.level2_value") + .endObject(); + mapping.startObject(FIELD_DIMENSION_4) + .field("type", "text") + .startObject("fields") + .startObject("keyword") + .field("type", "keyword") + .field("time_series_dimension", true) + .endObject() + .endObject() + .endObject(); // Metrics 
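// For the flattened dimension mapped above, extractFlattenedDimensions expands the root
// field name with each declared key (field + '.' + key), as the diff shows. A minimal
// restatement with illustrative names:
import java.util.List;

class FlattenedDimensionNamesSketch {
    static List<String> expand(String field, List<String> timeSeriesDimensions) {
        return timeSeriesDimensions.stream().map(key -> field + '.' + key).toList();
    }

    public static void main(String[] args) {
        // Prints [dimension_flattened.level1_value, dimension_flattened.level1_obj.level2_value]
        System.out.println(expand("dimension_flattened", List.of("level1_value", "level1_obj.level2_value")));
    }
}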
mapping.startObject(FIELD_NUMERIC_1).field("type", "long").field("time_series_metric", "gauge").endObject(); @@ -307,6 +321,42 @@ public void testDownsampleIndex() throws Exception { assertDownsampleIndex(sourceIndex, downsampleIndex, config); } + public void testDownsampleIndexWithFlattenedAndMultiFieldDimensions() throws Exception { + DownsampleConfig config = new DownsampleConfig(randomInterval()); + SourceSupplier sourceSupplier = () -> { + String ts = randomDateForInterval(config.getInterval()); + double labelDoubleValue = DATE_FORMATTER.parseMillis(ts); + return XContentFactory.jsonBuilder() + .startObject() + .field(FIELD_TIMESTAMP, ts) + .field(FIELD_DIMENSION_1, "dim1") // not important for this test + .startObject(FIELD_DIMENSION_3) + .field("level1_value", randomFrom(dimensionValues)) + .field("level1_othervalue", randomFrom(dimensionValues)) + .startObject("level1_object") + .field("level2_value", randomFrom(dimensionValues)) + .field("level2_othervalue", randomFrom(dimensionValues)) + .endObject() + .endObject() + .field(FIELD_DIMENSION_4, randomFrom(dimensionValues)) + .field(FIELD_NUMERIC_1, randomInt()) + .field(FIELD_NUMERIC_2, DATE_FORMATTER.parseMillis(ts)) + .startObject(FIELD_AGG_METRIC) + .field("min", randomDoubleBetween(-2000, -1001, true)) + .field("max", randomDoubleBetween(-1000, 1000, true)) + .field("sum", randomIntBetween(100, 10000)) + .field("value_count", randomIntBetween(100, 1000)) + .endObject() + .field(FIELD_LABEL_DOUBLE, labelDoubleValue) + .field(FIELD_METRIC_LABEL_DOUBLE, labelDoubleValue) + .endObject(); + }; + bulkIndex(sourceSupplier); + prepareSourceIndex(sourceIndex, true); + downsample(sourceIndex, downsampleIndex, config); + assertDownsampleIndex(sourceIndex, downsampleIndex, config); + } + public void testDownsampleOfDownsample() throws Exception { int intervalMinutes = randomIntBetween(10, 120); DownsampleConfig config = new DownsampleConfig(DateHistogramInterval.minutes(intervalMinutes)); @@ -1103,7 +1153,7 @@ private RolloverResponse rollover(String dataStreamName) throws ExecutionExcepti } private InternalAggregations aggregate(final String index, AggregationBuilder aggregationBuilder) { - var resp = client().prepareSearch(index).addAggregation(aggregationBuilder).get(); + var resp = client().prepareSearch(index).setSize(0).addAggregation(aggregationBuilder).get(); try { return resp.getAggregations(); } finally { @@ -1304,14 +1354,17 @@ private void assertDownsampleIndexAggregations( originalFieldsList.contains(field) ) ); - Object originalLabelValue = originalHit.getDocumentFields().values().stream().toList().get(0).getValue(); - Object downsampleLabelValue = downsampleHit.getDocumentFields().values().stream().toList().get(0).getValue(); - Optional labelAsMetric = nonTopHitsOriginalAggregations.stream() + String labelName = originalHit.getDocumentFields().values().stream().findFirst().get().getName(); + Object originalLabelValue = originalHit.getDocumentFields().values().stream().findFirst().get().getValue(); + Object downsampleLabelValue = downsampleHit.getDocumentFields().values().stream().findFirst().get().getValue(); + Optional labelAsMetric = topHitsOriginalAggregations.stream() .filter(agg -> agg.getName().equals("metric_" + downsampleTopHits.getName())) .findFirst(); // NOTE: this check is possible only if the label can be indexed as a metric (the label is a numeric field) if (labelAsMetric.isPresent()) { - double metricValue = ((Max) labelAsMetric.get()).value(); + double metricValue = ((InternalTopHits) 
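// On the aggregate() helper change above: setSize(0) requests aggregation results without
// fetching any hits, which is cheaper and sufficient here since the helper only returns
// resp.getAggregations().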
labelAsMetric.get()).getHits().getHits()[0].field( + "metric_" + labelName + ).getValue(); assertEquals(metricValue, downsampleLabelValue); assertEquals(metricValue, originalLabelValue); } diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/sync_job/70_connector_sync_job_update_stats.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/sync_job/70_connector_sync_job_update_stats.yml index 85156bf800582..31dfea2e01d11 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/sync_job/70_connector_sync_job_update_stats.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/sync_job/70_connector_sync_job_update_stats.yml @@ -184,6 +184,77 @@ setup: - match: { indexed_document_volume: 1000 } - match: { last_seen: 2023-12-04T08:45:50.567149Z } +--- +"Update the ingestion stats for a connector sync job - with optional metadata": + - do: + connector.sync_job_post: + body: + id: test-connector + job_type: full + trigger_method: on_demand + - set: { id: id } + + - do: + connector.sync_job_update_stats: + connector_sync_job_id: $id + body: + deleted_document_count: 10 + indexed_document_count: 20 + indexed_document_volume: 1000 + metadata: { someKey1: test, someKey2: test2 } + + - match: { result: updated } + + - do: + connector.sync_job_get: + connector_sync_job_id: $id + + - match: { deleted_document_count: 10 } + - match: { indexed_document_count: 20 } + - match: { indexed_document_volume: 1000 } + - match: { metadata: { someKey1: test, someKey2: test2 } } + + +--- +"Update the ingestion stats for a connector sync job - metadata wrong type string": + - do: + connector.sync_job_post: + body: + id: test-connector + job_type: full + trigger_method: on_demand + - set: { id: id } + + - do: + catch: bad_request + connector.sync_job_update_stats: + connector_sync_job_id: $id + body: + deleted_document_count: 10 + indexed_document_count: 20 + indexed_document_volume: 1000 + metadata: "abc" + +--- +"Update the ingestion stats for a connector sync job - metadata wrong type number": + - do: + connector.sync_job_post: + body: + id: test-connector + job_type: full + trigger_method: on_demand + - set: { id: id } + + - do: + catch: bad_request + connector.sync_job_update_stats: + connector_sync_job_id: $id + body: + deleted_document_count: 10 + indexed_document_count: 20 + indexed_document_volume: 1000 + metadata: 123 + --- "Update the ingestion stats for a Connector Sync Job - Connector Sync Job does not exist": - do: diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachine.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachine.java index f722955cc0f9e..87b6c4c3da53f 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachine.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachine.java @@ -25,11 +25,11 @@ public class ConnectorStateMachine { ConnectorStatus.NEEDS_CONFIGURATION, EnumSet.of(ConnectorStatus.CONFIGURED, ConnectorStatus.ERROR), ConnectorStatus.CONFIGURED, - EnumSet.of(ConnectorStatus.NEEDS_CONFIGURATION, ConnectorStatus.CONNECTED, ConnectorStatus.ERROR), + EnumSet.of(ConnectorStatus.NEEDS_CONFIGURATION, ConnectorStatus.CONFIGURED, 
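// With the widened table here, self-transitions such as CONFIGURED -> CONFIGURED and
// ERROR -> ERROR are now legal, so repeated status updates no longer fail. Checking a
// transition against such a table reduces to set membership; a sketch assuming a lookup
// map shaped like the one being edited:
static boolean isValidTransition(
    ConnectorStatus current,
    ConnectorStatus next,
    Map<ConnectorStatus, EnumSet<ConnectorStatus>> validTransitions
) {
    return validTransitions.getOrDefault(current, EnumSet.noneOf(ConnectorStatus.class)).contains(next);
}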
ConnectorStatus.CONNECTED, ConnectorStatus.ERROR), ConnectorStatus.CONNECTED, - EnumSet.of(ConnectorStatus.CONFIGURED, ConnectorStatus.ERROR), + EnumSet.of(ConnectorStatus.CONNECTED, ConnectorStatus.CONFIGURED, ConnectorStatus.ERROR), ConnectorStatus.ERROR, - EnumSet.of(ConnectorStatus.CONNECTED, ConnectorStatus.CONFIGURED) + EnumSet.of(ConnectorStatus.CONNECTED, ConnectorStatus.CONFIGURED, ConnectorStatus.ERROR) ); /** diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PostConnectorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PostConnectorAction.java index 26371ffbed159..fad349cd31877 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PostConnectorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PostConnectorAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import java.io.IOException; import java.util.Objects; @@ -95,14 +90,6 @@ public Request(StreamInput in) throws IOException { PARSER.declareString(optionalConstructorArg(), new ParseField("service_type")); } - public static Request fromXContentBytes(BytesReference source, XContentType xContentType) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return Request.fromXContent(parser); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static Request fromXContent(XContentParser parser) throws IOException { return PARSER.parse(parser, null); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PutConnectorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PutConnectorAction.java index 96ef483236823..687a801ab8fd6 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PutConnectorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/PutConnectorAction.java @@ -7,23 +7,18 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import 
org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import java.io.IOException; import java.util.Objects; @@ -110,14 +105,6 @@ public Request(StreamInput in) throws IOException { PARSER.declareString(optionalConstructorArg(), new ParseField("service_type")); } - public static Request fromXContentBytes(String connectorId, BytesReference source, XContentType xContentType) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public boolean isConnectorIdNullOrEmpty() { return Strings.isNullOrEmpty(connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestDeleteConnectorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestDeleteConnectorAction.java index d945930d9ee32..8d4a6dccd95fe 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestDeleteConnectorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestDeleteConnectorAction.java @@ -23,6 +23,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestDeleteConnectorAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_delete_action"; @@ -30,13 +32,13 @@ public String getName() { @Override public List routes() { - return List.of(new Route(DELETE, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}")); + return List.of(new Route(DELETE, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}")); } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String connectorId = restRequest.param("connector_id"); + String connectorId = restRequest.param(CONNECTOR_ID_PARAM); boolean shouldDeleteSyncJobs = restRequest.paramAsBoolean("delete_sync_jobs", false); DeleteConnectorAction.Request request = new DeleteConnectorAction.Request(connectorId, shouldDeleteSyncJobs); diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestGetConnectorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestGetConnectorAction.java index 79922755e67ef..8d3d5914ca695 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestGetConnectorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestGetConnectorAction.java @@ -22,6 +22,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestGetConnectorAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return 
"connector_get_action"; @@ -29,12 +31,12 @@ public String getName() { @Override public List routes() { - return List.of(new Route(GET, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}")); + return List.of(new Route(GET, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}")); } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - GetConnectorAction.Request request = new GetConnectorAction.Request(restRequest.param("connector_id")); + GetConnectorAction.Request request = new GetConnectorAction.Request(restRequest.param(CONNECTOR_ID_PARAM)); return channel -> client.execute(GetConnectorAction.INSTANCE, request, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPostConnectorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPostConnectorAction.java index 51ddcac3cd58c..99bd2e7ed536d 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPostConnectorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPostConnectorAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.POST; @@ -33,11 +34,10 @@ public List routes() { } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { PostConnectorAction.Request request; - // Handle empty REST request body if (restRequest.hasContent()) { - request = PostConnectorAction.Request.fromXContentBytes(restRequest.content(), restRequest.getXContentType()); + request = PostConnectorAction.Request.fromXContent(restRequest.contentParser()); } else { request = new PostConnectorAction.Request(); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPutConnectorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPutConnectorAction.java index fcd292eefc531..feedad45dd890 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPutConnectorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestPutConnectorAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.rest.action.RestToXContentListener; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +23,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestPutConnectorAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_put_action"; @@ -30,18 +33,17 @@ public String getName() { @Override public List routes() { return List.of( - new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}"), + new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}"), new Route(PUT, "/" + 
EnterpriseSearch.CONNECTOR_API_ENDPOINT) ); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - String connectorId = restRequest.param("connector_id"); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String connectorId = restRequest.param(CONNECTOR_ID_PARAM); PutConnectorAction.Request request; - // Handle empty REST request body if (restRequest.hasContent()) { - request = PutConnectorAction.Request.fromXContentBytes(connectorId, restRequest.content(), restRequest.getXContentType()); + request = PutConnectorAction.Request.fromXContent(restRequest.contentParser(), connectorId); } else { request = new PutConnectorAction.Request(connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorActiveFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorActiveFilteringAction.java index fbf44487651cf..4bc58e3b5d52a 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorActiveFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorActiveFilteringAction.java @@ -22,6 +22,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorActiveFilteringAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_active_filtering_action"; @@ -29,13 +31,15 @@ public String getName() { @Override public List routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_filtering/_activate")); + return List.of( + new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_filtering/_activate") + ); } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { UpdateConnectorActiveFilteringAction.Request request = new UpdateConnectorActiveFilteringAction.Request( - restRequest.param("connector_id") + restRequest.param(CONNECTOR_ID_PARAM) ); return channel -> client.execute( UpdateConnectorActiveFilteringAction.INSTANCE, diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorApiKeyIdAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorApiKeyIdAction.java index 0cb42f6f448a2..093fa0936c817 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorApiKeyIdAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorApiKeyIdAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorApiKeyIdAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = 
"connector_id"; + @Override public String getName() { return "connector_update_api_key_id_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_api_key_id")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_api_key_id")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorApiKeyIdAction.Request request = UpdateConnectorApiKeyIdAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorApiKeyIdAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorApiKeyIdAction.Request request = UpdateConnectorApiKeyIdAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorApiKeyIdAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorConfigurationAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorConfigurationAction.java index f4cc47da2f109..7f2447abfdc34 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorConfigurationAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorConfigurationAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorConfigurationAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_configuration_action"; @@ -29,20 +33,22 @@ public String getName() { @Override public List routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_configuration")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_configuration")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorConfigurationAction.Request request = UpdateConnectorConfigurationAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorConfigurationAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer 
prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorConfigurationAction.Request request = UpdateConnectorConfigurationAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorConfigurationAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorErrorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorErrorAction.java index df56f5825f84e..85f94682f0825 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorErrorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorErrorAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorErrorAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_error_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_error")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_error")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorErrorAction.Request request = UpdateConnectorErrorAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorErrorAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorErrorAction.Request request = UpdateConnectorErrorAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorErrorAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFeaturesAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFeaturesAction.java index 48bf87b114548..c26dcba52b705 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFeaturesAction.java +++ 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFeaturesAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorFeaturesAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_features_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_features")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_features")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorFeaturesAction.Request request = UpdateConnectorFeaturesAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorFeaturesAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorFeaturesAction.Request request = UpdateConnectorFeaturesAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorFeaturesAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java index ae294dfebd111..0ee665561b888 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorFilteringAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_filtering_action"; @@ -29,20 +33,22 @@ public String getName() { @Override public List routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_filtering")); + 
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java index ae294dfebd111..0ee665561b888 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorFilteringAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_filtering_action"; @@ -29,20 +33,22 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_filtering")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_filtering")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorFilteringAction.Request request = UpdateConnectorFilteringAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorFilteringAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorFilteringAction.Request request = UpdateConnectorFilteringAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorFilteringAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } + } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringValidationAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringValidationAction.java index 32020eea4b8b9..697cf95b984ef 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringValidationAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringValidationAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorFilteringValidationAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_filtering_validation_action"; @@ -29,20 +33,23 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_filtering/_validation")); + return List.of( + new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_filtering/_validation") + ); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorFilteringValidationAction.Request request = UpdateConnectorFilteringValidationAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorFilteringValidationAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorFilteringValidationAction.Request request = UpdateConnectorFilteringValidationAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorFilteringValidationAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorIndexNameAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorIndexNameAction.java index ce6dd0a5ba24f..89870643901b9 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorIndexNameAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorIndexNameAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorIndexNameAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_index_name_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_index_name")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_index_name")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorIndexNameAction.Request request = UpdateConnectorIndexNameAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorIndexNameAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorIndexNameAction.Request request = UpdateConnectorIndexNameAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorIndexNameAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSeenAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSeenAction.java index bef6c357fdda3..6f76e70971a9f 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSeenAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSeenAction.java @@ -22,6 +22,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorLastSeenAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_last_seen_action"; @@ -29,12 +31,12 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_check_in")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_check_in")); } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorLastSeenAction.Request request = new UpdateConnectorLastSeenAction.Request(restRequest.param("connector_id")); + UpdateConnectorLastSeenAction.Request request = new UpdateConnectorLastSeenAction.Request(restRequest.param(CONNECTOR_ID_PARAM)); return channel -> client.execute( UpdateConnectorLastSeenAction.INSTANCE, request,
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSyncStatsAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSyncStatsAction.java index 6275e84a28952..804b792810ffd 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSyncStatsAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSyncStatsAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorLastSyncStatsAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_last_sync_stats_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_last_sync")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_last_sync")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorLastSyncStatsAction.Request request = UpdateConnectorLastSyncStatsAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorLastSyncStatsAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorLastSyncStatsAction.Request request = UpdateConnectorLastSyncStatsAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorLastSyncStatsAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNameAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNameAction.java index 7fbd42cbff272..21d7d74166b7a 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNameAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNameAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorNameAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_name_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_name")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_name")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorNameAction.Request request = UpdateConnectorNameAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorNameAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorNameAction.Request request = UpdateConnectorNameAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorNameAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNativeAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNativeAction.java index 464d682567043..e2f9730df723d 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNativeAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNativeAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorNativeAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_native_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_native")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_native")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorNativeAction.Request request = UpdateConnectorNativeAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorNativeAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorNativeAction.Request request = UpdateConnectorNativeAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorNativeAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineAction.java index 465414491bb95..24502d0def1df 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorPipelineAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_pipeline_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_pipeline")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_pipeline")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorPipelineAction.Request request = UpdateConnectorPipelineAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorPipelineAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorPipelineAction.Request request = UpdateConnectorPipelineAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorPipelineAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorSchedulingAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorSchedulingAction.java index dfc12659d394b..191def3a8af52 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorSchedulingAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorSchedulingAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorSchedulingAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_scheduling_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_scheduling")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_scheduling")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorSchedulingAction.Request request = UpdateConnectorSchedulingAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorSchedulingAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorSchedulingAction.Request request = UpdateConnectorSchedulingAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorSchedulingAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorServiceTypeAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorServiceTypeAction.java index 89c3303f8cc94..9375c338d64b4 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorServiceTypeAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorServiceTypeAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorServiceTypeAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { return "connector_update_service_type_action"; @@ -29,20 +33,21 @@ public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_service_type")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_service_type")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorServiceTypeAction.Request request = UpdateConnectorServiceTypeAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorServiceTypeAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorServiceTypeAction.Request request = UpdateConnectorServiceTypeAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorServiceTypeAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorStatusAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorStatusAction.java index 9770a051ce4fc..cb741fa8301a9 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorStatusAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorStatusAction.java @@ -13,8 +13,10 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; +import java.io.IOException; import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; @@ -22,6 +24,8 @@ @ServerlessScope(Scope.PUBLIC) public class RestUpdateConnectorStatusAction extends BaseRestHandler { + private static final String CONNECTOR_ID_PARAM = "connector_id"; + @Override public String getName() { @Override public List<Route> routes() { - return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{connector_id}/_status")); + return List.of(new Route(PUT, "/" + EnterpriseSearch.CONNECTOR_API_ENDPOINT + "/{" + CONNECTOR_ID_PARAM + "}/_status")); } @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - UpdateConnectorStatusAction.Request request = UpdateConnectorStatusAction.Request.fromXContentBytes( - restRequest.param("connector_id"), - restRequest.content(), - restRequest.getXContentType() - ); - return channel -> client.execute( - UpdateConnectorStatusAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorStatusAction.Request request = UpdateConnectorStatusAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_ID_PARAM) + ); + return channel -> client.execute( + UpdateConnectorStatusAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } }
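The handler-side change is paired with deleting each fromXContentBytes(connectorId, source, xContentType) helper from the Update*Action request classes that follow; fromXContent(XContentParser, String) remains as the single parsing entry point, backed by each class's ConstructingObjectParser. A sketch of the retained surface, with a hypothetical FooRequest and field name used purely for illustration:

import java.io.IOException;

import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class FooRequest {
    private final String connectorId;
    private final String name;

    FooRequest(String connectorId, String name) {
        this.connectorId = connectorId;
        this.name = name;
    }

    // The connector id travels as the parser context rather than in the body,
    // mirroring the PARSER.parse(parser, connectorId) calls in this patch.
    private static final ConstructingObjectParser<FooRequest, String> PARSER = new ConstructingObjectParser<>(
        "foo_request",
        false,
        (args, connectorId) -> new FooRequest(connectorId, (String) args[0])
    );

    static {
        PARSER.declareString(constructorArg(), new ParseField("name"));
    }

    public static FooRequest fromXContent(XContentParser parser, String connectorId) throws IOException {
        return PARSER.parse(parser, connectorId);
    }
}

With the byte-level helpers gone, a malformed body now surfaces from the REST layer's own parser handling rather than as an ElasticsearchParseException wrapping the raw body string.

diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorApiKeyIdAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorApiKeyIdAction.java index b7f07fc4a34cb..7f726f21ce225 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorApiKeyIdAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorApiKeyIdAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import java.io.IOException; @@ -100,18 +95,6 @@ public ActionRequestValidationException validate() { PARSER.declareStringOrNull(optionalConstructorArg(), Connector.API_KEY_SECRET_ID_FIELD); } - public static UpdateConnectorApiKeyIdAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorApiKeyIdAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorApiKeyIdAction.Request fromXContent(XContentParser parser, String 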
connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorConfigurationAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorConfigurationAction.java index 9069f832e1c44..5d36c5f886ea0 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorConfigurationAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorConfigurationAction.java @@ -7,22 +7,17 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorConfiguration; @@ -123,18 +118,6 @@ public ActionRequestValidationException validate() { PARSER.declareField(optionalConstructorArg(), (p, c) -> p.map(), VALUES_FIELD, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); } - public static UpdateConnectorConfigurationAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorConfigurationAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse connector configuration.", e); - } - } - public static UpdateConnectorConfigurationAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorErrorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorErrorAction.java index ae86c1fc98df1..3e506fc835f65 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorErrorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorErrorAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import 
org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import java.io.IOException; @@ -84,18 +79,6 @@ public ActionRequestValidationException validate() { PARSER.declareStringOrNull(constructorArg(), Connector.ERROR_FIELD); } - public static UpdateConnectorErrorAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorErrorAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorErrorAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFeaturesAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFeaturesAction.java index c1f62c0efe6e8..56656855583aa 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFeaturesAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFeaturesAction.java @@ -7,20 +7,15 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorFeatures; @@ -83,18 +78,6 @@ public ActionRequestValidationException validate() { PARSER.declareObject(optionalConstructorArg(), (p, c) -> ConnectorFeatures.fromXContent(p), Connector.FEATURES_FIELD); } - public static UpdateConnectorFeaturesAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorFeaturesAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new 
ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorFeaturesAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java index 54c9a6e6417dc..660956b2e9d7f 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorFiltering; import org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; @@ -138,18 +133,6 @@ public ActionRequestValidationException validate() { PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FilteringRule.fromXContent(p), FilteringRules.RULES_FIELD); } - public static UpdateConnectorFilteringAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorFilteringAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorFilteringAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringValidationAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringValidationAction.java index 2164019c62ba3..92291506d0719 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringValidationAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringValidationAction.java @@ -7,20 +7,15 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; 
import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.filtering.FilteringRules; import org.elasticsearch.xpack.application.connector.filtering.FilteringValidationInfo; @@ -91,18 +86,6 @@ public ActionRequestValidationException validate() { PARSER.declareObject(constructorArg(), (p, c) -> FilteringValidationInfo.fromXContent(p), FilteringRules.VALIDATION_FIELD); } - public static UpdateConnectorFilteringValidationAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorFilteringValidationAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorFilteringValidationAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorIndexNameAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorIndexNameAction.java index c6cb18089ad06..e7840e1f84fad 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorIndexNameAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorIndexNameAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import java.io.IOException; @@ -73,18 +68,6 @@ public String getIndexName() { PARSER.declareStringOrNull(constructorArg(), Connector.INDEX_NAME_FIELD); } - public static UpdateConnectorIndexNameAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - 
XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorIndexNameAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorIndexNameAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorLastSyncStatsAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorLastSyncStatsAction.java index 1628a493cbec5..ae3be3801786c 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorLastSyncStatsAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorLastSyncStatsAction.java @@ -7,22 +7,17 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorSyncInfo; import org.elasticsearch.xpack.application.connector.ConnectorSyncStatus; @@ -157,18 +152,6 @@ public ActionRequestValidationException validate() { PARSER.declareObjectOrNull(optionalConstructorArg(), (p, c) -> p.map(), null, Connector.SYNC_CURSOR_FIELD); } - public static UpdateConnectorLastSyncStatsAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorLastSyncStatsAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorLastSyncStatsAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNameAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNameAction.java index 1aa10f0b7dd45..bbc1f992b48e2 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNameAction.java +++ 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNameAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import java.io.IOException; @@ -101,18 +96,6 @@ public ActionRequestValidationException validate() { PARSER.declareStringOrNull(optionalConstructorArg(), Connector.DESCRIPTION_FIELD); } - public static UpdateConnectorNameAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorNameAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorNameAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNativeAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNativeAction.java index 9b539d055ef7e..7b3f2e4577f4e 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNativeAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorNativeAction.java @@ -7,20 +7,15 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import java.io.IOException; @@ -70,18 +65,6 @@ public boolean isNative() { PARSER.declareBoolean(constructorArg(), Connector.IS_NATIVE_FIELD); } - public static 
UpdateConnectorNativeAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorNativeAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorNativeAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineAction.java index ee1f24ea6d20d..e58d614f4ef21 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorPipelineAction.java @@ -7,20 +7,15 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorIngestPipeline; @@ -87,18 +82,6 @@ public ActionRequestValidationException validate() { PARSER.declareObject(constructorArg(), (p, c) -> ConnectorIngestPipeline.fromXContent(p), Connector.PIPELINE_FIELD); } - public static UpdateConnectorPipelineAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorPipelineAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorPipelineAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorServiceTypeAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorServiceTypeAction.java index 68aec9624d30f..de07a6db21bab 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorServiceTypeAction.java +++ 
b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorServiceTypeAction.java @@ -7,20 +7,15 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import java.io.IOException; @@ -71,18 +66,6 @@ public String getServiceType() { PARSER.declareString(constructorArg(), Connector.SERVICE_TYPE_FIELD); } - public static UpdateConnectorServiceTypeAction.Request fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorServiceTypeAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorServiceTypeAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorStatusAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorStatusAction.java index cd8b36df2e148..aebaa0afb9052 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorStatusAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorStatusAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorStatus; @@ -77,18 +72,6 @@ public ConnectorStatus getStatus() { ); } - public static UpdateConnectorStatusAction.Request 
fromXContentBytes( - String connectorId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorStatusAction.Request.fromXContent(parser, connectorId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static UpdateConnectorStatusAction.Request fromXContent(XContentParser parser, String connectorId) throws IOException { return PARSER.parse(parser, connectorId); }
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJob.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJob.java index b72bffab81e1f..4aabb9e1af663 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJob.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJob.java @@ -91,7 +91,7 @@ public class ConnectorSyncJob implements Writeable, ToXContentObject { public static final ParseField LAST_SEEN_FIELD = new ParseField("last_seen"); - static final ParseField METADATA_FIELD = new ParseField("metadata"); + public static final ParseField METADATA_FIELD = new ParseField("metadata"); static final ParseField STARTED_AT_FIELD = new ParseField("started_at");
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java index 72ca1f1d8499b..9ef895a3a5786 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexService.java @@ -464,6 +464,11 @@ public void updateConnectorSyncJobIngestionStats( Instant lastSeen = Objects.nonNull(request.getLastSeen()) ? request.getLastSeen() : Instant.now(); fieldsToUpdate.put(ConnectorSyncJob.LAST_SEEN_FIELD.getPreferredName(), lastSeen); + Map<String, Object> metadata = request.getMetadata(); + if (Objects.nonNull(metadata)) { + fieldsToUpdate.put(ConnectorSyncJob.METADATA_FIELD.getPreferredName(), metadata); + } + final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_SYNC_JOB_INDEX_NAME, syncJobId).setRefreshPolicy( WriteRequest.RefreshPolicy.IMMEDIATE ).doc(fieldsToUpdate);
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobStateMachine.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobStateMachine.java index dc624b5bf8ba1..952cd12d1ee7c 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobStateMachine.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobStateMachine.java @@ -45,7 +45,7 @@ public class ConnectorSyncJobStateMachine { * @param next The proposed next {@link ConnectorSyncStatus} of the {@link ConnectorSyncJob}. */ public static boolean isValidTransition(ConnectorSyncStatus current, ConnectorSyncStatus next) { - return VALID_TRANSITIONS.getOrDefault(current, Collections.emptySet()).contains(next); + return validNextStates(current).contains(next); } /** @@ -60,4 +60,8 @@ public static void assertValidStateTransition(ConnectorSyncStatus current, Conne if (isValidTransition(current, next)) return; throw new ConnectorSyncJobInvalidStatusTransitionException(current, next); } + + public static Set<ConnectorSyncStatus> validNextStates(ConnectorSyncStatus current) { + return VALID_TRANSITIONS.getOrDefault(current, Collections.emptySet()); + } }
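Extracting validNextStates makes the transition table that isValidTransition consults directly readable by callers, e.g. for building informative error messages. A brief usage sketch (the ensureTransition helper is hypothetical, written against the two methods shown above):

import java.util.Set;

static void ensureTransition(ConnectorSyncStatus current, ConnectorSyncStatus requested) {
    // isValidTransition and validNextStates share the same underlying VALID_TRANSITIONS map,
    // so the error message is guaranteed to agree with the check.
    if (ConnectorSyncJobStateMachine.isValidTransition(current, requested) == false) {
        Set<ConnectorSyncStatus> allowed = ConnectorSyncJobStateMachine.validNextStates(current);
        throw new IllegalArgumentException(
            "cannot transition sync job from [" + current + "] to [" + requested + "]; valid next states: " + allowed
        );
    }
}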
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/ClaimConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/ClaimConnectorSyncJobAction.java index 74a7e1bdd0282..b108116a5e68c 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/ClaimConnectorSyncJobAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/ClaimConnectorSyncJobAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.syncjob.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.action.ConnectorUpdateActionResponse; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJob; @@ -105,14 +100,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static Request fromXContentBytes(String connectorSyncJobId, BytesReference source, XContentType xContentType) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return fromXContent(parser, connectorSyncJobId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse request" + source.utf8ToString()); - } - } - @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/PostConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/PostConnectorSyncJobAction.java index 5e898d9524d0b..8c1d24e466daa 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/PostConnectorSyncJobAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/PostConnectorSyncJobAction.java @@ -7,21 +7,16 @@ package org.elasticsearch.xpack.application.connector.syncjob.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorTemplateRegistry; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJob; @@ -98,14 +93,6 @@ public ConnectorSyncJobTriggerMethod getTriggerMethod() { return triggerMethod; } - public static Request fromXContentBytes(BytesReference source, XContentType xContentType) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return Request.fromXContent(parser); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString(), e); - } - } - public static Request fromXContent(XContentParser parser) throws IOException { return PARSER.parse(parser, null); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestClaimConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestClaimConnectorSyncJobAction.java index c048f43b6baa6..bea26e77ca531 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestClaimConnectorSyncJobAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestClaimConnectorSyncJobAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; import java.io.IOException; @@ -41,12 +42,13 @@ public List<Route> routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - ClaimConnectorSyncJobAction.Request request = ClaimConnectorSyncJobAction.Request.fromXContentBytes( - restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM), - restRequest.content(), - restRequest.getXContentType() - ); - - return channel -> client.execute(ClaimConnectorSyncJobAction.INSTANCE, request, new RestToXContentListener<>(channel)); + try (XContentParser parser = restRequest.contentParser()) { + ClaimConnectorSyncJobAction.Request request = ClaimConnectorSyncJobAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM) + ); + + return channel ->
client.execute(ClaimConnectorSyncJobAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestPostConnectorSyncJobAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestPostConnectorSyncJobAction.java index eac645ab3dc77..66a620d22f753 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestPostConnectorSyncJobAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestPostConnectorSyncJobAction.java @@ -14,6 +14,7 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; import java.io.IOException; @@ -36,15 +37,13 @@ public List<Route> routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - PostConnectorSyncJobAction.Request request = PostConnectorSyncJobAction.Request.fromXContentBytes( - restRequest.content(), - restRequest.getXContentType() - ); - - return channel -> client.execute( - PostConnectorSyncJobAction.INSTANCE, - request, - new RestToXContentListener<>(channel, r -> RestStatus.CREATED, r -> null) - ); + try (XContentParser parser = restRequest.contentParser()) { + PostConnectorSyncJobAction.Request request = PostConnectorSyncJobAction.Request.fromXContent(parser); + return channel -> client.execute( + PostConnectorSyncJobAction.INSTANCE, + request, + new RestToXContentListener<>(channel, r -> RestStatus.CREATED, r -> null) + ); + } } }
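All of these handlers migrate the same way: instead of buffering the body as a BytesReference and re-parsing it through a fromXContentBytes helper, each one obtains a streaming parser from restRequest.contentParser() and releases it with try-with-resources; a missing body now fails fast inside contentParser() (the "request body is required" error asserted by the new tests below). A minimal sketch of the resulting shape, assuming a PARSER-backed Request.fromXContent as above (FooAction and the "id" parameter are illustrative placeholders, not part of this change):

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        // Parse eagerly so a malformed or absent body is rejected here, before a channel consumer is returned.
        try (XContentParser parser = restRequest.contentParser()) {
            FooAction.Request request = FooAction.Request.fromXContent(parser, restRequest.param("id"));
            return channel -> client.execute(FooAction.INSTANCE, request, new RestToXContentListener<>(channel));
        }
    }

Closing the parser before the consumer runs is safe because the request object is fully materialized inside prepareRequest.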
diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobErrorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobErrorAction.java index 720bfdf416827..a158191a705ef 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobErrorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobErrorAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; import org.elasticsearch.xpack.application.connector.action.ConnectorUpdateActionResponse; @@ -41,16 +42,16 @@ public List<Route> routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - UpdateConnectorSyncJobErrorAction.Request request = UpdateConnectorSyncJobErrorAction.Request.fromXContentBytes( - restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM), - restRequest.content(), - restRequest.getXContentType() - ); - - return restChannel -> client.execute( - UpdateConnectorSyncJobErrorAction.INSTANCE, - request, - new RestToXContentListener<>(restChannel, ConnectorUpdateActionResponse::status) - ); + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorSyncJobErrorAction.Request request = UpdateConnectorSyncJobErrorAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM) + ); + return restChannel -> client.execute( + UpdateConnectorSyncJobErrorAction.INSTANCE, + request, + new RestToXContentListener<>(restChannel, ConnectorUpdateActionResponse::status) + ); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobIngestionStatsAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobIngestionStatsAction.java index d55d3ba87d1df..500da2a216b1e 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobIngestionStatsAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobIngestionStatsAction.java @@ -13,6 +13,7 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.application.EnterpriseSearch; import org.elasticsearch.xpack.application.connector.action.ConnectorUpdateActionResponse; @@ -40,16 +41,17 @@ public List<Route> routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - UpdateConnectorSyncJobIngestionStatsAction.Request request = UpdateConnectorSyncJobIngestionStatsAction.Request.fromXContentBytes( - restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM), - restRequest.content(), - restRequest.getXContentType() - ); - - return channel -> client.execute( - UpdateConnectorSyncJobIngestionStatsAction.INSTANCE, - request, - new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) - ); + try (XContentParser parser = restRequest.contentParser()) { + UpdateConnectorSyncJobIngestionStatsAction.Request request = UpdateConnectorSyncJobIngestionStatsAction.Request.fromXContent( + parser, + restRequest.param(CONNECTOR_SYNC_JOB_ID_PARAM) + ); + + return channel -> client.execute( + UpdateConnectorSyncJobIngestionStatsAction.INSTANCE, + request, + new RestToXContentListener<>(channel, ConnectorUpdateActionResponse::status) + ); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobErrorAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobErrorAction.java index 3ce5d61e95fdb..2235ba7cfe720 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobErrorAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobErrorAction.java @@ -7,20 +7,15 @@ package org.elasticsearch.xpack.application.connector.syncjob.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.action.ConnectorUpdateActionResponse; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJob; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants; @@ -66,14 +61,6 @@ public Request(String connectorSyncJobId, String error) { this.error = error; } - public static Request fromXContentBytes(String connectorSyncJobId, BytesReference source, XContentType xContentType) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorSyncJobErrorAction.Request.fromXContent(parser, connectorSyncJobId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString()); - } - } - public static UpdateConnectorSyncJobErrorAction.Request fromXContent(XContentParser parser, String connectorSyncJobId) throws IOException { return PARSER.parse(parser, connectorSyncJobId); diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsAction.java index d76f2c3b788fc..0fd9b6dec8184 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsAction.java @@ -7,22 +7,17 @@ package org.elasticsearch.xpack.application.connector.syncjob.action; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorUtils; import org.elasticsearch.xpack.application.connector.action.ConnectorUpdateActionResponse; @@ -30,6 +25,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Map; import java.util.Objects; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -57,6 +53,7 @@ public static class Request extends ConnectorSyncJobActionRequest implements ToX private final Long indexedDocumentVolume; private final Long totalDocumentCount; private final Instant 
lastSeen; + private final Map<String, Object> metadata; public Request(StreamInput in) throws IOException { super(in); @@ -66,6 +63,7 @@ public Request(StreamInput in) throws IOException { this.indexedDocumentVolume = in.readLong(); this.totalDocumentCount = in.readOptionalLong(); this.lastSeen = in.readOptionalInstant(); + this.metadata = in.readGenericMap(); } public Request( @@ -74,7 +72,8 @@ public Request( Long indexedDocumentCount, Long indexedDocumentVolume, Long totalDocumentCount, - Instant lastSeen + Instant lastSeen, + Map<String, Object> metadata ) { this.connectorSyncJobId = connectorSyncJobId; this.deletedDocumentCount = deletedDocumentCount; @@ -82,6 +81,7 @@ public Request( this.indexedDocumentVolume = indexedDocumentVolume; this.totalDocumentCount = totalDocumentCount; this.lastSeen = lastSeen; + this.metadata = metadata; } public String getConnectorSyncJobId() { @@ -108,6 +108,10 @@ public Instant getLastSeen() { return lastSeen; } + public Map<String, Object> getMetadata() { + return metadata; + } + @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; @@ -135,6 +139,7 @@ public ActionRequestValidationException validate() { return validationException; } + @SuppressWarnings("unchecked") private static final ConstructingObjectParser<UpdateConnectorSyncJobIngestionStatsAction.Request, String> PARSER = new ConstructingObjectParser<>("connector_sync_job_update_ingestion_stats", false, (args, connectorSyncJobId) -> { Long deletedDocumentCount = (Long) args[0]; @@ -143,6 +148,7 @@ public ActionRequestValidationException validate() { Long totalDocumentVolume = args[3] != null ? (Long) args[3] : null; Instant lastSeen = args[4] != null ? (Instant) args[4] : null; + Map<String, Object> metadata = (Map<String, Object>) args[5]; return new Request( connectorSyncJobId, @@ -150,7 +156,8 @@ indexedDocumentCount, indexedDocumentVolume, totalDocumentVolume, - lastSeen + lastSeen, + metadata ); }); @@ -165,18 +172,7 @@ public ActionRequestValidationException validate() { ConnectorSyncJob.LAST_SEEN_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING ); - } - - public static UpdateConnectorSyncJobIngestionStatsAction.Request fromXContentBytes( - String connectorSyncJobId, - BytesReference source, - XContentType xContentType - ) { - try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { - return UpdateConnectorSyncJobIngestionStatsAction.Request.fromXContent(parser, connectorSyncJobId); - } catch (IOException e) { - throw new ElasticsearchParseException("Failed to parse: " + source.utf8ToString()); - } + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(), ConnectorSyncJob.METADATA_FIELD); } public static Request fromXContent(XContentParser parser, String connectorSyncJobId) throws IOException { @@ -192,6 +188,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(ConnectorSyncJob.INDEXED_DOCUMENT_VOLUME_FIELD.getPreferredName(), indexedDocumentVolume); builder.field(ConnectorSyncJob.TOTAL_DOCUMENT_COUNT_FIELD.getPreferredName(), totalDocumentCount); builder.field(ConnectorSyncJob.LAST_SEEN_FIELD.getPreferredName(), lastSeen); + builder.field(ConnectorSyncJob.METADATA_FIELD.getPreferredName(), metadata); } builder.endObject(); return builder; @@ -206,6 +203,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(indexedDocumentVolume); out.writeOptionalLong(totalDocumentCount); out.writeOptionalInstant(lastSeen); + out.writeGenericMap(metadata); }
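With metadata carried over the wire via readGenericMap/writeGenericMap and declared on the parser as an optional object, a caller can attach an arbitrary map to an ingestion-stats update. A hypothetical construction against the new seven-argument signature (the id and map contents are made up for illustration):

    // Hypothetical caller: report counters plus an opaque metadata payload.
    UpdateConnectorSyncJobIngestionStatsAction.Request request = new UpdateConnectorSyncJobIngestionStatsAction.Request(
        "my-sync-job-id",            // connectorSyncJobId
        0L,                          // deletedDocumentCount
        42L,                         // indexedDocumentCount
        1024L,                       // indexedDocumentVolume
        42L,                         // totalDocumentCount
        Instant.now(),               // lastSeen
        Map.of("cursor", "page-7")   // metadata
    );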
@Override @@ -218,7 +216,8 @@ public boolean equals(Object o) { && Objects.equals(indexedDocumentCount, request.indexedDocumentCount) && Objects.equals(indexedDocumentVolume, request.indexedDocumentVolume) && Objects.equals(totalDocumentCount, request.totalDocumentCount) - && Objects.equals(lastSeen, request.lastSeen); + && Objects.equals(lastSeen, request.lastSeen) + && Objects.equals(metadata, request.metadata); } @Override @@ -229,7 +228,8 @@ public int hashCode() { indexedDocumentCount, indexedDocumentVolume, totalDocumentCount, - lastSeen + lastSeen, + metadata ); } } diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java index a696c6e6dde54..e7de5b073b114 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java @@ -802,7 +802,7 @@ public void testUpdateConnectorStatus_WithInvalidStatus() throws Exception { Connector connector = ConnectorTestUtils.getRandomConnector(); String connectorId = randomUUID(); - ConnectorCreateActionResponse resp = awaitCreateConnector(connectorId, connector); + awaitCreateConnector(connectorId, connector); Connector indexedConnector = awaitGetConnector(connectorId); ConnectorStatus newInvalidStatus = ConnectorTestUtils.getRandomInvalidConnectorNextStatus(indexedConnector.getStatus()); diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachineTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachineTests.java index d1f08f80d02f2..739ad44fd6c4c 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachineTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorStateMachineTests.java @@ -17,6 +17,7 @@ public void testValidTransitionFromCreated() { } public void testInvalidTransitionFromCreated() { + assertFalse(ConnectorStateMachine.isValidTransition(ConnectorStatus.CREATED, ConnectorStatus.CREATED)); assertFalse(ConnectorStateMachine.isValidTransition(ConnectorStatus.CREATED, ConnectorStatus.CONFIGURED)); assertFalse(ConnectorStateMachine.isValidTransition(ConnectorStatus.CREATED, ConnectorStatus.CONNECTED)); } @@ -28,12 +29,14 @@ public void testValidTransitionFromNeedsConfiguration() { public void testInvalidTransitionFromNeedsConfiguration() { assertFalse(ConnectorStateMachine.isValidTransition(ConnectorStatus.NEEDS_CONFIGURATION, ConnectorStatus.CREATED)); assertFalse(ConnectorStateMachine.isValidTransition(ConnectorStatus.NEEDS_CONFIGURATION, ConnectorStatus.CONNECTED)); + assertFalse(ConnectorStateMachine.isValidTransition(ConnectorStatus.NEEDS_CONFIGURATION, ConnectorStatus.NEEDS_CONFIGURATION)); } public void testValidTransitionFromConfigured() { assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.CONFIGURED, ConnectorStatus.NEEDS_CONFIGURATION)); assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.CONFIGURED, ConnectorStatus.CONNECTED)); assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.CONFIGURED, ConnectorStatus.ERROR)); + assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.CONFIGURED,
ConnectorStatus.CONFIGURED)); } public void testInvalidTransitionFromConfigured() { @@ -43,6 +46,7 @@ public void testInvalidTransitionFromConfigured() { public void testValidTransitionFromConnected() { assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.CONNECTED, ConnectorStatus.CONFIGURED)); assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.CONNECTED, ConnectorStatus.ERROR)); + assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.CONNECTED, ConnectorStatus.CONNECTED)); } public void testInvalidTransitionFromConnected() { @@ -53,6 +57,7 @@ public void testInvalidTransitionFromConnected() { public void testValidTransitionFromError() { assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.ERROR, ConnectorStatus.CONNECTED)); assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.ERROR, ConnectorStatus.CONFIGURED)); + assertTrue(ConnectorStateMachine.isValidTransition(ConnectorStatus.ERROR, ConnectorStatus.ERROR)); } public void testInvalidTransitionFromError() { @@ -60,12 +65,6 @@ public void testInvalidTransitionFromError() { assertFalse(ConnectorStateMachine.isValidTransition(ConnectorStatus.ERROR, ConnectorStatus.NEEDS_CONFIGURATION)); } - public void testTransitionToSameState() { - for (ConnectorStatus state : ConnectorStatus.values()) { - assertFalse("Transition from " + state + " to itself should be invalid", ConnectorStateMachine.isValidTransition(state, state)); - } - } - public void testAssertValidStateTransition_ExpectExceptionOnInvalidTransition() { assertThrows( ConnectorInvalidStatusTransitionException.class, diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorApiKeyIdActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorApiKeyIdActionTests.java new file mode 100644 index 0000000000000..53d5de565f9f8 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorApiKeyIdActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorApiKeyIdActionTests extends ESTestCase { + + private RestUpdateConnectorApiKeyIdAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorApiKeyIdAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_api_key_id") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorConfigurationActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorConfigurationActionTests.java new file mode 100644 index 0000000000000..6712d98dc1c69 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorConfigurationActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorConfigurationActionTests extends ESTestCase { + + private RestUpdateConnectorConfigurationAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorConfigurationAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_configuration") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorErrorActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorErrorActionTests.java new file mode 100644 index 0000000000000..71a49488d5d21 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorErrorActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorErrorActionTests extends ESTestCase { + + private RestUpdateConnectorErrorAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorErrorAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_error") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFeaturesActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFeaturesActionTests.java new file mode 100644 index 0000000000000..8728efbc23fea --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFeaturesActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorFeaturesActionTests extends ESTestCase { + + private RestUpdateConnectorFeaturesAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorFeaturesAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_features") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringActionTests.java new file mode 100644 index 0000000000000..3a2e009758625 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorFilteringActionTests extends ESTestCase { + + private RestUpdateConnectorFilteringAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorFilteringAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_filtering") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringValidationActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringValidationActionTests.java new file mode 100644 index 0000000000000..812f2a7c5ab5e --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorFilteringValidationActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorFilteringValidationActionTests extends ESTestCase { + + private RestUpdateConnectorFilteringValidationAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorFilteringValidationAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_filtering/_validation") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorIndexNameActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorIndexNameActionTests.java new file mode 100644 index 0000000000000..18839e02dd7a3 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorIndexNameActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorIndexNameActionTests extends ESTestCase { + + private RestUpdateConnectorIndexNameAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorIndexNameAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_index_name") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSyncStatsActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSyncStatsActionTests.java new file mode 100644 index 0000000000000..49f0c239debab --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorLastSyncStatsActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorLastSyncStatsActionTests extends ESTestCase { + + private RestUpdateConnectorLastSyncStatsAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorLastSyncStatsAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_last_sync") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNameActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNameActionTests.java new file mode 100644 index 0000000000000..10b0a7cebf94e --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNameActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorNameActionTests extends ESTestCase { + + private RestUpdateConnectorNameAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorNameAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_name") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNativeActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNativeActionTests.java new file mode 100644 index 0000000000000..8b65aad215efe --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorNativeActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorNativeActionTests extends ESTestCase { + + private RestUpdateConnectorNativeAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorNativeAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_native") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineActionTests.java new file mode 100644 index 0000000000000..a5bb4e0689696 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorPipelineActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorPipelineActionTests extends ESTestCase { + + private RestUpdateConnectorPipelineAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorPipelineAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_pipeline") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorSchedulingActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorSchedulingActionTests.java new file mode 100644 index 0000000000000..91b26f1d9ada8 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorSchedulingActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorSchedulingActionTests extends ESTestCase { + + private RestUpdateConnectorSchedulingAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorSchedulingAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_scheduling") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorServiceTypeActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorServiceTypeActionTests.java new file mode 100644 index 0000000000000..16657f17e5d27 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorServiceTypeActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorServiceTypeActionTests extends ESTestCase { + + private RestUpdateConnectorServiceTypeAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorServiceTypeAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_service_type") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorStatusActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorStatusActionTests.java new file mode 100644 index 0000000000000..fc083ede81395 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/RestUpdateConnectorStatusActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorStatusActionTests extends ESTestCase { + + private RestUpdateConnectorStatusAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorStatusAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/123/_status") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java index b9a77adc12a3c..f6c0a54f107b4 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobIndexServiceTests.java @@ -784,6 +784,7 @@ public void testUpdateConnectorSyncJobError_WithStatusPending_ExpectStatusExcept assertThrows(ElasticsearchStatusException.class, () -> awaitUpdateConnectorSyncJob(syncJobId, "some error")); } + @SuppressWarnings("unchecked") public void testUpdateConnectorSyncJobIngestionStats() throws Exception { PostConnectorSyncJobAction.Request syncJobRequest = ConnectorSyncJobTestUtils.getRandomPostConnectorSyncJobActionRequest( connectorOneId @@ -802,6 +803,7 @@ public void testUpdateConnectorSyncJobIngestionStats() throws Exception { Long requestIndexedDocumentVolume = request.getIndexedDocumentVolume(); Long requestTotalDocumentCount = request.getTotalDocumentCount(); Instant requestLastSeen = request.getLastSeen(); + Map<String, Object> metadata = request.getMetadata(); Long deletedDocumentCountAfterUpdate = (Long) syncJobSourceAfterUpdate.get( ConnectorSyncJob.DELETED_DOCUMENT_COUNT_FIELD.getPreferredName() @@ -818,6 +820,9 @@ public void testUpdateConnectorSyncJobIngestionStats() throws Exception { Instant lastSeenAfterUpdate = Instant.parse( (String) syncJobSourceAfterUpdate.get(ConnectorSyncJob.LAST_SEEN_FIELD.getPreferredName()) ); + Map<String, Object> metadataAfterUpdate = (Map<String, Object>) syncJobSourceAfterUpdate.get( + ConnectorSyncJob.METADATA_FIELD.getPreferredName() + ); assertThat(updateResponse.status(), equalTo(RestStatus.OK)); assertThat(deletedDocumentCountAfterUpdate, equalTo(requestDeletedDocumentCount)); @@ -825,6 +830,7 @@ public void testUpdateConnectorSyncJobIngestionStats() throws Exception { assertThat(indexedDocumentVolumeAfterUpdate, equalTo(requestIndexedDocumentVolume)); assertThat(totalDocumentCountAfterUpdate, equalTo(requestTotalDocumentCount)); assertThat(lastSeenAfterUpdate, equalTo(requestLastSeen)); +
assertThat(metadataAfterUpdate, equalTo(metadata)); assertFieldsExceptAllIngestionStatsDidNotUpdate(syncJobSourceBeforeUpdate, syncJobSourceAfterUpdate); } @@ -838,12 +844,14 @@ public void testUpdateConnectorSyncJobIngestionStats_WithoutLastSeen_ExpectUpdat Instant lastSeenBeforeUpdate = Instant.parse( (String) syncJobSourceBeforeUpdate.get(ConnectorSyncJob.LAST_SEEN_FIELD.getPreferredName()) ); + UpdateConnectorSyncJobIngestionStatsAction.Request request = new UpdateConnectorSyncJobIngestionStatsAction.Request( syncJobId, 10L, 20L, 100L, 10L, + null, null ); @@ -866,7 +874,7 @@ public void testUpdateConnectorSyncJobIngestionStats_WithMissingSyncJobId_Expect expectThrows( ResourceNotFoundException.class, () -> awaitUpdateConnectorSyncJobIngestionStats( - new UpdateConnectorSyncJobIngestionStatsAction.Request(NON_EXISTING_SYNC_JOB_ID, 0L, 0L, 0L, 0L, Instant.now()) + new UpdateConnectorSyncJobIngestionStatsAction.Request(NON_EXISTING_SYNC_JOB_ID, 0L, 0L, 0L, 0L, Instant.now(), null) ) ); } @@ -1067,7 +1075,8 @@ private static void assertFieldsExceptAllIngestionStatsDidNotUpdate( ConnectorSyncJob.INDEXED_DOCUMENT_COUNT_FIELD, ConnectorSyncJob.INDEXED_DOCUMENT_VOLUME_FIELD, ConnectorSyncJob.TOTAL_DOCUMENT_COUNT_FIELD, - ConnectorSyncJob.LAST_SEEN_FIELD + ConnectorSyncJob.LAST_SEEN_FIELD, + ConnectorSyncJob.METADATA_FIELD ) ); } diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java index a4ff76e6f2cf9..e72bf04fb7e55 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/ConnectorSyncJobTestUtils.java @@ -160,7 +160,8 @@ public static UpdateConnectorSyncJobIngestionStatsAction.Request getRandomUpdate randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), - randomInstantBetween(lowerBoundInstant, upperBoundInstant) + randomInstantBetween(lowerBoundInstant, upperBoundInstant), + randomMap(2, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))) ); } @@ -176,7 +177,8 @@ public static UpdateConnectorSyncJobIngestionStatsAction.Request getRandomUpdate randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), - randomInstantBetween(lowerBoundInstant, upperBoundInstant) + randomInstantBetween(lowerBoundInstant, upperBoundInstant), + randomMap(2, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))) ); } diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestClaimConnectorSyncJobActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestClaimConnectorSyncJobActionTests.java new file mode 100644 index 0000000000000..567fe803a250e --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestClaimConnectorSyncJobActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestClaimConnectorSyncJobActionTests extends ESTestCase { + + private RestClaimConnectorSyncJobAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestClaimConnectorSyncJobAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/_sync_job/456/_claim") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestPostConnectorSyncJobActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestPostConnectorSyncJobActionTests.java new file mode 100644 index 0000000000000..231e0d2b14144 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestPostConnectorSyncJobActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.application.connector.syncjob.action;
+
+import org.elasticsearch.ElasticsearchParseException;
+import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.rest.FakeRestRequest;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasToString;
+import static org.mockito.Mockito.mock;
+
+public class RestPostConnectorSyncJobActionTests extends ESTestCase {
+
+    private RestPostConnectorSyncJobAction action;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        action = new RestPostConnectorSyncJobAction();
+    }
+
+    public void testPrepareRequest_emptyPayload_badRequestError() {
+        RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
+            .withPath("/_connector/_sync_job")
+            .build();
+
+        final ElasticsearchParseException e = expectThrows(
+            ElasticsearchParseException.class,
+            () -> action.prepareRequest(request, mock(NodeClient.class))
+        );
+        assertThat(e, hasToString(containsString("request body is required")));
+    }
+}
diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobErrorActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobErrorActionTests.java
new file mode 100644
index 0000000000000..19fd2df9dc3b6
--- /dev/null
+++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobErrorActionTests.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorSyncJobErrorActionTests extends ESTestCase { + + private RestUpdateConnectorSyncJobErrorAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorSyncJobErrorAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/_sync_job/456/_error") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobIngestionStatsActionTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobIngestionStatsActionTests.java new file mode 100644 index 0000000000000..edb0250524049 --- /dev/null +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/RestUpdateConnectorSyncJobIngestionStatsActionTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.application.connector.syncjob.action; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasToString; +import static org.mockito.Mockito.mock; + +public class RestUpdateConnectorSyncJobIngestionStatsActionTests extends ESTestCase { + + private RestUpdateConnectorSyncJobIngestionStatsAction action; + + @Override + public void setUp() throws Exception { + super.setUp(); + action = new RestUpdateConnectorSyncJobIngestionStatsAction(); + } + + public void testPrepareRequest_emptyPayload_badRequestError() { + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.PUT) + .withPath("/_connector/_sync_job/456/_stats") + .build(); + + final ElasticsearchParseException e = expectThrows( + ElasticsearchParseException.class, + () -> action.prepareRequest(request, mock(NodeClient.class)) + ); + assertThat(e, hasToString(containsString("request body is required"))); + } +} diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestBWCSerializingTests.java index 6e2178d8341cf..ff586ae28109a 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestBWCSerializingTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestBWCSerializingTests.java @@ -55,7 +55,8 @@ protected UpdateConnectorSyncJobIngestionStatsAction.Request mutateInstanceForVe instance.getIndexedDocumentCount(), instance.getIndexedDocumentVolume(), instance.getTotalDocumentCount(), - instance.getLastSeen() + instance.getLastSeen(), + instance.getMetadata() ); } } diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestTests.java index 4f78ad3ffa7e7..1f3ca480ee1c8 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/syncjob/action/UpdateConnectorSyncJobIngestionStatsActionRequestTests.java @@ -8,22 +8,46 @@ package org.elasticsearch.xpack.application.connector.syncjob.action; import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentParseException; +import 
org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobTestUtils; +import org.junit.Before; +import java.io.IOException; import java.time.Instant; +import java.util.List; +import java.util.Map; +import static java.util.Collections.emptyList; import static org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobConstants.EMPTY_CONNECTOR_SYNC_JOB_ID_ERROR_MESSAGE; import static org.elasticsearch.xpack.application.connector.syncjob.action.UpdateConnectorSyncJobIngestionStatsAction.Request.DELETED_DOCUMENT_COUNT_NEGATIVE_ERROR_MESSAGE; import static org.elasticsearch.xpack.application.connector.syncjob.action.UpdateConnectorSyncJobIngestionStatsAction.Request.INDEXED_DOCUMENT_COUNT_NEGATIVE_ERROR_MESSAGE; import static org.elasticsearch.xpack.application.connector.syncjob.action.UpdateConnectorSyncJobIngestionStatsAction.Request.INDEXED_DOCUMENT_VOLUME_NEGATIVE_ERROR_MESSAGE; import static org.elasticsearch.xpack.application.connector.syncjob.action.UpdateConnectorSyncJobIngestionStatsAction.Request.TOTAL_DOCUMENT_COUNT_NEGATIVE_ERROR_MESSAGE; +import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class UpdateConnectorSyncJobIngestionStatsActionRequestTests extends ESTestCase { + private NamedWriteableRegistry namedWriteableRegistry; + + @Before + public void registerNamedObjects() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, emptyList()); + + List namedWriteables = searchModule.getNamedWriteables(); + namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); + } + public void testValidate_WhenRequestIsValid_ExpectNoValidationError() { UpdateConnectorSyncJobIngestionStatsAction.Request request = ConnectorSyncJobTestUtils .getRandomUpdateConnectorSyncJobIngestionStatsActionRequest(); @@ -39,7 +63,8 @@ public void testValidate_WhenConnectorSyncJobIdIsEmpty_ExpectValidationError() { 0L, 0L, 0L, - Instant.now() + Instant.now(), + null ); ActionRequestValidationException exception = request.validate(); @@ -54,7 +79,8 @@ public void testValidate_WhenConnectorSyncJobIdIsNull_ExpectValidationError() { 0L, 0L, 0L, - Instant.now() + Instant.now(), + null ); ActionRequestValidationException exception = request.validate(); @@ -69,7 +95,8 @@ public void testValidate_WhenDeletedDocumentCountIsNegative_ExpectValidationErro 0L, 0L, 0L, - Instant.now() + Instant.now(), + null ); ActionRequestValidationException exception = request.validate(); @@ -84,7 +111,8 @@ public void testValidate_WhenIndexedDocumentCountIsNegative_ExpectValidationErro -10L, 0L, 0L, - Instant.now() + Instant.now(), + null ); ActionRequestValidationException exception = request.validate(); @@ -99,7 +127,8 @@ public void testValidate_WhenIndexedDocumentVolumeIsNegative_ExpectValidationErr 0L, -10L, 0L, - Instant.now() + Instant.now(), + null ); ActionRequestValidationException exception = request.validate(); @@ -114,11 +143,92 @@ public void testValidate_WhenTotalDocumentCountIsNegative_ExpectValidationError( 0L, 0L, -10L, - Instant.now() + Instant.now(), + null ); ActionRequestValidationException exception = request.validate(); assertThat(exception, notNullValue()); assertThat(exception.getMessage(), containsString(TOTAL_DOCUMENT_COUNT_NEGATIVE_ERROR_MESSAGE)); } + + public void 
testParseRequest_requiredFields_validRequest() throws IOException { + String requestPayload = XContentHelper.stripWhitespace(""" + { + "deleted_document_count": 10, + "indexed_document_count": 20, + "indexed_document_volume": 1000 + } + """); + + UpdateConnectorSyncJobIngestionStatsAction.Request request = UpdateConnectorSyncJobIngestionStatsAction.Request.fromXContent( + XContentHelper.createParser(XContentParserConfiguration.EMPTY, new BytesArray(requestPayload), XContentType.JSON), + randomUUID() + ); + + assertThat(request.getDeletedDocumentCount(), equalTo(10L)); + assertThat(request.getIndexedDocumentCount(), equalTo(20L)); + assertThat(request.getIndexedDocumentVolume(), equalTo(1000L)); + } + + public void testParseRequest_allFieldsWithoutLastSeen_validRequest() throws IOException { + String requestPayload = XContentHelper.stripWhitespace(""" + { + "deleted_document_count": 10, + "indexed_document_count": 20, + "indexed_document_volume": 1000, + "total_document_count": 55, + "metadata": {"key1": 1, "key2": 2} + } + """); + + UpdateConnectorSyncJobIngestionStatsAction.Request request = UpdateConnectorSyncJobIngestionStatsAction.Request.fromXContent( + XContentHelper.createParser(XContentParserConfiguration.EMPTY, new BytesArray(requestPayload), XContentType.JSON), + randomUUID() + ); + + assertThat(request.getDeletedDocumentCount(), equalTo(10L)); + assertThat(request.getIndexedDocumentCount(), equalTo(20L)); + assertThat(request.getIndexedDocumentVolume(), equalTo(1000L)); + assertThat(request.getTotalDocumentCount(), equalTo(55L)); + assertThat(request.getMetadata(), equalTo(Map.of("key1", 1, "key2", 2))); + } + + public void testParseRequest_metadataTypeInt_invalidRequest() throws IOException { + String requestPayload = XContentHelper.stripWhitespace(""" + { + "deleted_document_count": 10, + "indexed_document_count": 20, + "indexed_document_volume": 1000, + "metadata": 42 + } + """); + + expectThrows( + XContentParseException.class, + () -> UpdateConnectorSyncJobIngestionStatsAction.Request.fromXContent( + XContentHelper.createParser(XContentParserConfiguration.EMPTY, new BytesArray(requestPayload), XContentType.JSON), + randomUUID() + ) + ); + } + + public void testParseRequest_metadataTypeString_invalidRequest() throws IOException { + String requestPayload = XContentHelper.stripWhitespace(""" + { + "deleted_document_count": 10, + "indexed_document_count": 20, + "indexed_document_volume": 1000, + "metadata": "I'm a wrong metadata type" + } + """); + + expectThrows( + XContentParseException.class, + () -> UpdateConnectorSyncJobIngestionStatsAction.Request.fromXContent( + XContentHelper.createParser(XContentParserConfiguration.EMPTY, new BytesArray(requestPayload), XContentType.JSON), + randomUUID() + ) + ); + } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/AggRef.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/AggRef.java deleted file mode 100644 index 54e44f55c96ab..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/AggRef.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
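
For context on the parsing tests above: the ingestion stats request now takes an optional metadata object as its final constructor argument, and fromXContent rejects metadata values that are not JSON objects. A minimal sketch of building the widened request directly in Java (the sync job id and metadata entries below are illustrative, not taken from the change):

    // Counters must be non-negative; metadata may be null when unset.
    UpdateConnectorSyncJobIngestionStatsAction.Request request = new UpdateConnectorSyncJobIngestionStatsAction.Request(
        "my-sync-job-id",                 // connector sync job id (illustrative)
        10L,                              // deleted document count
        20L,                              // indexed document count
        1000L,                            // indexed document volume
        55L,                              // total document count
        Instant.now(),                    // last seen
        Map.of("crawl_phase", "content")  // metadata: must parse as an object, not an int or string
    );
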
- */
-package org.elasticsearch.xpack.esql.core.execution.search;
-
-/**
- * Reference to an ES aggregation (which can be either a GROUP BY or Metric agg).
- */
-public abstract class AggRef implements FieldExtraction {
-
-    @Override
-    public void collectFields(QlSourceBuilder sourceBuilder) {
-        // Aggregations do not need any special fields
-    }
-
-    @Override
-    public boolean supportedByAggsOnlyQuery() {
-        return true;
-    }
-}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/FieldExtraction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/FieldExtraction.java
deleted file mode 100644
index 6751a8412153b..0000000000000
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/FieldExtraction.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-package org.elasticsearch.xpack.esql.core.execution.search;
-
-import org.elasticsearch.search.builder.SearchSourceBuilder;
-
-/**
- * An interface for something that needs to extract field(s) from a result.
- */
-public interface FieldExtraction {
-
-    /**
-     * Add whatever is necessary to the {@link SearchSourceBuilder}
-     * in order to fetch the field. This can include tracking the score,
-     * {@code _source} fields, doc values fields, and script fields.
-     */
-    void collectFields(QlSourceBuilder sourceBuilder);
-
-    /**
-     * Is this aggregation supported in an "aggregation only" query
-     * ({@code true}) or should it force a scroll query ({@code false})?
-     */
-    boolean supportedByAggsOnlyQuery();
-}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/QlSourceBuilder.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/QlSourceBuilder.java
deleted file mode 100644
index a8a0198400027..0000000000000
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/QlSourceBuilder.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-package org.elasticsearch.xpack.esql.core.execution.search;
-
-import org.elasticsearch.script.Script;
-import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
-
-import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * A {@code QlSourceBuilder} is a builder object passed to objects implementing
- * {@link FieldExtraction} that can "build" whatever needs to be extracted from
- * the resulting ES document as a field.
- */
-public class QlSourceBuilder {
-    // The LinkedHashMaps preserve the order of the fields in the response
-    private final Set<FieldAndFormat> fetchFields = new LinkedHashSet<>();
-    private final Map<String, Script> scriptFields = new LinkedHashMap<>();
-
-    boolean trackScores = false;
-
-    public QlSourceBuilder() {}
-
-    /**
-     * Turns on returning the {@code _score} for documents.
- */ - public void trackScores() { - this.trackScores = true; - } - - /** - * Retrieve the requested field using the "fields" API - */ - public void addFetchField(String field, String format) { - fetchFields.add(new FieldAndFormat(field, format)); - } - - /** - * Return the given field as a script field with the supplied script - */ - public void addScriptField(String name, Script script) { - scriptFields.put(name, script); - } - - /** - * Collect the necessary fields, modifying the {@code SearchSourceBuilder} - * to retrieve them from the document. - */ - public void build(SearchSourceBuilder sourceBuilder) { - sourceBuilder.trackScores(this.trackScores); - fetchFields.forEach(field -> sourceBuilder.fetchField(new FieldAndFormat(field.field, field.format, null))); - scriptFields.forEach(sourceBuilder::scriptField); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/AbstractFieldHitExtractor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/AbstractFieldHitExtractor.java deleted file mode 100644 index 9f7155a78e66f..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/AbstractFieldHitExtractor.java +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.execution.search.extractor; - -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.document.DocumentField; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.esql.core.InvalidArgumentException; -import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; -import org.elasticsearch.xpack.esql.core.type.DataType; - -import java.io.IOException; -import java.time.ZoneId; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -/** - * Extractor for ES fields. Works for both 'normal' fields but also nested ones (which require hitName to be set). - * The latter is used as metadata in assembling the results in the tabular response. 
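
For readers tracking the removal of QlSourceBuilder above: a minimal sketch of how the deleted builder was typically driven, assuming a caller holding a SearchSourceBuilder (the field name and format are illustrative):

    // Collect the requested extractions, then apply them to the search source in one step.
    QlSourceBuilder ql = new QlSourceBuilder();
    ql.addFetchField("@timestamp", "epoch_millis"); // retrieve via the "fields" API
    ql.trackScores();                               // also return _score
    SearchSourceBuilder source = new SearchSourceBuilder();
    ql.build(source);                               // sets trackScores plus fetch and script fields
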
- */ -public abstract class AbstractFieldHitExtractor implements HitExtractor { - - private final String fieldName, hitName; - private final DataType dataType; - private final ZoneId zoneId; - - protected MultiValueSupport multiValueSupport; - - public enum MultiValueSupport { - NONE, - LENIENT, - FULL - } - - protected AbstractFieldHitExtractor(String name, DataType dataType, ZoneId zoneId) { - this(name, dataType, zoneId, null, MultiValueSupport.NONE); - } - - protected AbstractFieldHitExtractor(String name, DataType dataType, ZoneId zoneId, MultiValueSupport multiValueSupport) { - this(name, dataType, zoneId, null, multiValueSupport); - } - - protected AbstractFieldHitExtractor( - String name, - DataType dataType, - ZoneId zoneId, - String hitName, - MultiValueSupport multiValueSupport - ) { - this.fieldName = name; - this.dataType = dataType; - this.zoneId = zoneId; - this.multiValueSupport = multiValueSupport; - this.hitName = hitName; - - if (hitName != null) { - if (name.contains(hitName) == false) { - throw new QlIllegalArgumentException("Hitname [{}] specified but not part of the name [{}]", hitName, name); - } - } - } - - @SuppressWarnings("this-escape") - protected AbstractFieldHitExtractor(StreamInput in) throws IOException { - fieldName = in.readString(); - String typeName = in.readOptionalString(); - dataType = typeName != null ? loadTypeFromName(typeName) : null; - hitName = in.readOptionalString(); - if (in.getTransportVersion().before(TransportVersions.V_8_6_0)) { - this.multiValueSupport = in.readBoolean() ? MultiValueSupport.LENIENT : MultiValueSupport.NONE; - } else { - this.multiValueSupport = in.readEnum(MultiValueSupport.class); - } - zoneId = readZoneId(in); - } - - protected DataType loadTypeFromName(String typeName) { - return DataType.fromTypeName(typeName); - } - - protected abstract ZoneId readZoneId(StreamInput in) throws IOException; - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(fieldName); - out.writeOptionalString(dataType == null ? null : dataType.typeName()); - out.writeOptionalString(hitName); - if (out.getTransportVersion().before(TransportVersions.V_8_6_0)) { - out.writeBoolean(multiValueSupport != MultiValueSupport.NONE); - } else { - out.writeEnum(multiValueSupport); - } - - } - - @Override - public Object extract(SearchHit hit) { - Object value = null; - DocumentField field = null; - if (hitName != null) { - value = unwrapFieldsMultiValue(extractNestedField(hit)); - } else { - field = hit.field(fieldName); - if (field != null) { - value = unwrapFieldsMultiValue(field.getValues()); - } - } - return value; - } - - /* - * For a path of fields like root.nested1.nested2.leaf where nested1 and nested2 are nested field types, - * fieldName is root.nested1.nested2.leaf, while hitName is root.nested1.nested2 - * We first look for root.nested1.nested2 or root.nested1 or root in the SearchHit until we find something. - * If the DocumentField lives under "root.nested1" the remaining path to search for (in the DocumentField itself) is nested2. - * After this step is done, what remains to be done is just getting the leaf values. 
- */
-    @SuppressWarnings("unchecked")
-    private Object extractNestedField(SearchHit hit) {
-        Object value;
-        DocumentField field;
-        String tempHitname = hitName;
-        List<String> remainingPath = new ArrayList<>();
-        // first, search for the "root" DocumentField under which the remaining path of nested document values is
-        while ((field = hit.field(tempHitname)) == null) {
-            int indexOfDot = tempHitname.lastIndexOf('.');
-            if (indexOfDot < 0) {// there is no such field in the hit
-                return null;
-            }
-            remainingPath.add(0, tempHitname.substring(indexOfDot + 1));
-            tempHitname = tempHitname.substring(0, indexOfDot);
-        }
-        // then dig into DocumentField's structure until we reach the field we are interested in
-        if (remainingPath.size() > 0) {
-            List<Object> values = field.getValues();
-            Iterator<String> pathIterator = remainingPath.iterator();
-            while (pathIterator.hasNext()) {
-                String pathElement = pathIterator.next();
-                Map<String, List<Object>> elements = (Map<String, List<Object>>) values.get(0);
-                values = elements.get(pathElement);
-                /*
-                 * if this path is not found it means we hit another nested document (inner_root_1.inner_root_2.nested_field_2)
-                 * something like this
-                 * "root_field_1.root_field_2.nested_field_1" : [
-                 *     {
-                 *         "inner_root_1.inner_root_2.nested_field_2" : [
-                 *             {
-                 *                 "leaf_field" : [
-                 *                     "abc2"
-                 *                 ]
-                 * So, start re-building the path until the right one is found, ie inner_root_1.inner_root_2......
-                 */
-                while (values == null) {
-                    pathElement += "." + pathIterator.next();
-                    values = elements.get(pathElement);
-                }
-            }
-            value = ((Map<?, ?>) values.get(0)).get(fieldName.substring(hitName.length() + 1));
-        } else {
-            value = field.getValues();
-        }
-        return value;
-    }
-
-    protected Object unwrapFieldsMultiValue(Object values) {
-        if (values == null) {
-            return null;
-        }
-        if (values instanceof Map && hitName != null) {
-            // extract the sub-field from a nested field (dep.dep_name -> dep_name)
-            return unwrapFieldsMultiValue(((Map<?, ?>) values).get(fieldName.substring(hitName.length() + 1)));
-        }
-        if (values instanceof List<?> list) {
-            if (list.isEmpty()) {
-                return null;
-            } else {
-                if (isPrimitive(list) == false) {
-                    if (list.size() == 1 || multiValueSupport == MultiValueSupport.LENIENT) {
-                        return unwrapFieldsMultiValue(list.get(0));
-                    } else if (multiValueSupport == MultiValueSupport.FULL) {
-                        List<Object> unwrappedValues = new ArrayList<>();
-                        for (Object value : list) {
-                            unwrappedValues.add(unwrapFieldsMultiValue(value));
-                        }
-                        values = unwrappedValues;
-                    } else {
-                        // missing `field_multi_value_leniency` setting
-                        throw new InvalidArgumentException("Arrays (returned by [{}]) are not supported", fieldName);
-                    }
-                }
-            }
-        }
-
-        Object unwrapped = unwrapCustomValue(values);
-        if (unwrapped != null && isListOfNulls(unwrapped) == false) {
-            return unwrapped;
-        }
-
-        return values;
-    }
-
-    private static boolean isListOfNulls(Object unwrapped) {
-        if (unwrapped instanceof List<?> list) {
-            if (list.size() == 0) {
-                return false;
-            }
-            for (Object o : list) {
-                if (o != null) {
-                    return false;
-                }
-            }
-            return true;
-        }
-        return false;
-    }
-
-    protected abstract Object unwrapCustomValue(Object values);
-
-    protected abstract boolean isPrimitive(List<?> list);
-
-    @Override
-    public String hitName() {
-        return hitName;
-    }
-
-    public String fieldName() {
-        return fieldName;
-    }
-
-    public ZoneId zoneId() {
-        return zoneId;
-    }
-
-    public DataType dataType() {
-        return dataType;
-    }
-
-    public MultiValueSupport multiValueSupport() {
-        return multiValueSupport;
-    }
-
-    @Override
-    public String toString() {
-        return fieldName + "@" + hitName + "@" + zoneId;
-    }
-
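
The three MultiValueSupport modes implemented above can be condensed into a small, self-contained sketch; the class below is illustrative and not part of the removed code:

    import java.util.List;

    // For a field that returned ["a", "b"]:
    //   NONE    -> error (arrays rejected without the leniency setting)
    //   LENIENT -> first value only ("a")
    //   FULL    -> the whole list (["a", "b"])
    public class MultiValueDemo {
        enum MultiValueSupport { NONE, LENIENT, FULL }

        static Object unwrap(List<Object> values, MultiValueSupport mode) {
            if (values.size() == 1 || mode == MultiValueSupport.LENIENT) {
                return values.get(0);
            } else if (mode == MultiValueSupport.FULL) {
                return values;
            }
            throw new IllegalArgumentException("Arrays are not supported");
        }

        public static void main(String[] args) {
            List<Object> values = List.of("a", "b");
            System.out.println(unwrap(values, MultiValueSupport.LENIENT)); // a
            System.out.println(unwrap(values, MultiValueSupport.FULL));    // [a, b]
        }
    }
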
- @Override - public boolean equals(Object obj) { - if (obj == null || obj.getClass() != getClass()) { - return false; - } - AbstractFieldHitExtractor other = (AbstractFieldHitExtractor) obj; - return fieldName.equals(other.fieldName) && hitName.equals(other.hitName) && multiValueSupport == other.multiValueSupport; - } - - @Override - public int hashCode() { - return Objects.hash(fieldName, hitName, multiValueSupport); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/BucketExtractor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/BucketExtractor.java deleted file mode 100644 index a25482d92ecce..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/BucketExtractor.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.execution.search.extractor; - -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket; - -/** - * Extracts an aggregation value from a {@link Bucket}. - */ -public interface BucketExtractor extends NamedWriteable { - - Object extract(Bucket bucket); -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/BucketExtractors.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/BucketExtractors.java deleted file mode 100644 index fa7443e190d31..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/BucketExtractors.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.execution.search.extractor; - -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry.Entry; - -import java.util.ArrayList; -import java.util.List; - -public final class BucketExtractors { - - private BucketExtractors() {} - - /** - * All of the named writeables needed to deserialize the instances of - * {@linkplain BucketExtractor}s. 
- */
-    public static List<Entry> getNamedWriteables() {
-        List<Entry> entries = new ArrayList<>();
-        entries.add(new Entry(BucketExtractor.class, ComputingExtractor.NAME, ComputingExtractor::new));
-        entries.add(new Entry(BucketExtractor.class, ConstantExtractor.NAME, ConstantExtractor::new));
-        return entries;
-    }
-}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ComputingExtractor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ComputingExtractor.java
deleted file mode 100644
index 1116a43022da2..0000000000000
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ComputingExtractor.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-package org.elasticsearch.xpack.esql.core.execution.search.extractor;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.search.SearchHit;
-import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket;
-import org.elasticsearch.xpack.esql.core.expression.gen.processor.HitExtractorProcessor;
-import org.elasticsearch.xpack.esql.core.expression.gen.processor.Processor;
-
-import java.io.IOException;
-import java.util.Objects;
-
-/**
- * Hit/BucketExtractor that delegates to a processor. The difference between this class
- * and {@link HitExtractorProcessor} is that the latter is used inside a
- * {@link Processor} tree as a leaf (and thus can effectively parse the
- * {@link SearchHit}) while this class is used when scrolling and passing down
- * the results.
- *
- * In the future, the processor might be used across the board for all columns
- * to reduce API complexity (and keep the {@link HitExtractor} only as an
- * internal implementation detail).
- */
-public class ComputingExtractor implements HitExtractor, BucketExtractor {
-    /**
-     * Stands for {@code comPuting}. We try to use short names for {@link HitExtractor}s
-     * to save a few bytes when we send them back to the user.
- */
-    static final String NAME = "p";
-    private final Processor processor;
-    private final String hitName;
-
-    public ComputingExtractor(Processor processor) {
-        this(processor, null);
-    }
-
-    public ComputingExtractor(Processor processor, String hitName) {
-        this.processor = processor;
-        this.hitName = hitName;
-    }
-
-    // Visibility required for tests
-    public ComputingExtractor(StreamInput in) throws IOException {
-        processor = in.readNamedWriteable(Processor.class);
-        hitName = in.readOptionalString();
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeNamedWriteable(processor);
-        out.writeOptionalString(hitName);
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME;
-    }
-
-    public Processor processor() {
-        return processor;
-    }
-
-    public Object extract(Object input) {
-        return processor.process(input);
-    }
-
-    @Override
-    public Object extract(Bucket bucket) {
-        return processor.process(bucket);
-    }
-
-    @Override
-    public Object extract(SearchHit hit) {
-        return processor.process(hit);
-    }
-
-    @Override
-    public String hitName() {
-        return hitName;
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-        if (obj == null || obj.getClass() != getClass()) {
-            return false;
-        }
-        ComputingExtractor other = (ComputingExtractor) obj;
-        return Objects.equals(processor, other.processor) && Objects.equals(hitName, other.hitName);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(processor, hitName);
-    }
-
-    @Override
-    public String toString() {
-        return processor.toString();
-    }
-}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ConstantExtractor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ConstantExtractor.java
deleted file mode 100644
index bba311a085ed2..0000000000000
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ConstantExtractor.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-package org.elasticsearch.xpack.esql.core.execution.search.extractor;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.search.SearchHit;
-import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket;
-
-import java.io.IOException;
-import java.util.Objects;
-
-/**
- * Returns a constant for every search hit against which it is run.
- */
-public class ConstantExtractor implements HitExtractor, BucketExtractor {
-    /**
-     * Stands for {@code constant}. We try to use short names for {@link HitExtractor}s
-     * to save a few bytes when we send them back to the user.
- */ - static final String NAME = "c"; - private final Object constant; - - public ConstantExtractor(Object constant) { - this.constant = constant; - } - - ConstantExtractor(StreamInput in) throws IOException { - constant = in.readGenericValue(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeGenericValue(constant); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public Object extract(SearchHit hit) { - return constant; - } - - @Override - public Object extract(Bucket bucket) { - return constant; - } - - @Override - public String hitName() { - return null; - } - - @Override - public boolean equals(Object obj) { - if (obj == null || obj.getClass() != getClass()) { - return false; - } - ConstantExtractor other = (ConstantExtractor) obj; - return Objects.equals(constant, other.constant); - } - - @Override - public int hashCode() { - return Objects.hashCode(constant); - } - - @Override - public String toString() { - return "^" + constant; - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/HitExtractor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/HitExtractor.java deleted file mode 100644 index 38b72c5e8cd7e..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/HitExtractor.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.execution.search.extractor; - -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.search.SearchHit; - -/** - * Extracts a column value from a {@link SearchHit}. - */ -public interface HitExtractor extends NamedWriteable { - /** - * Extract the value from a hit. - */ - Object extract(SearchHit hit); - - /** - * Name of the inner hit needed by this extractor if it needs one, {@code null} otherwise. - */ - @Nullable - String hitName(); -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/HitExtractors.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/HitExtractors.java deleted file mode 100644 index 743856d41f8d5..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/HitExtractors.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.execution.search.extractor; - -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry.Entry; - -import java.util.ArrayList; -import java.util.List; - -public final class HitExtractors { - - private HitExtractors() {} - - /** - * All of the named writeables needed to deserialize the instances of - * {@linkplain HitExtractor}. 
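
The single-letter writeable names matter because extractors are serialized between nodes and re-resolved by name. A plausible round trip, assuming the registry entries listed just below, would look like:

    // Serialize an extractor, then read it back through the registry by its short name.
    NamedWriteableRegistry registry = new NamedWriteableRegistry(HitExtractors.getNamedWriteables());
    BytesStreamOutput out = new BytesStreamOutput();
    out.writeNamedWriteable(new ConstantExtractor(42)); // written as "c" plus the constant
    StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry);
    HitExtractor back = in.readNamedWriteable(HitExtractor.class); // dispatched on "c"
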
- */
-    public static List<Entry> getNamedWriteables() {
-        List<Entry> entries = new ArrayList<>();
-        entries.add(new Entry(HitExtractor.class, ConstantExtractor.NAME, ConstantExtractor::new));
-        entries.add(new Entry(HitExtractor.class, ComputingExtractor.NAME, ComputingExtractor::new));
-        return entries;
-    }
-}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/TotalHitsExtractor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/TotalHitsExtractor.java
deleted file mode 100644
index 52a9116619024..0000000000000
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/TotalHitsExtractor.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.esql.core.execution.search.extractor;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.search.SearchHit;
-import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
-import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException;
-
-import java.io.IOException;
-
-public class TotalHitsExtractor extends ConstantExtractor {
-
-    public TotalHitsExtractor(Long constant) {
-        super(constant);
-    }
-
-    TotalHitsExtractor(StreamInput in) throws IOException {
-        super(in);
-    }
-
-    @Override
-    public Object extract(MultiBucketsAggregation.Bucket bucket) {
-        return validate(super.extract(bucket));
-    }
-
-    @Override
-    public Object extract(SearchHit hit) {
-        return validate(super.extract(hit));
-    }
-
-    private static Object validate(Object value) {
-        if (Number.class.isInstance(value) == false) {
-            throw new QlIllegalArgumentException(
-                "Inconsistent total hits count handling, expected a numeric value but found a {}: {}",
-                value == null ? 
null : value.getClass().getSimpleName(), - value - ); - } - if (((Number) value).longValue() < 0) { - throw new QlIllegalArgumentException( - "Inconsistent total hits count handling, expected a non-negative value but found {}", - value - ); - } - return value; - } - -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java index eac3586cf139d..b6704f0569b27 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java @@ -26,6 +26,9 @@ import static org.elasticsearch.core.Tuple.tuple; public class MetadataAttribute extends TypedAttribute { + public static final String TIMESTAMP_FIELD = "@timestamp"; + public static final String TSID_FIELD = "_tsid"; + static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Attribute.class, "MetadataAttribute", diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistry.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistry.java index 6fa78348f328f..d3210ad6c2e6a 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistry.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistry.java @@ -390,7 +390,7 @@ protected interface ConfigurationAwareBuilder { * Build a {@linkplain FunctionDefinition} for a one-argument function that is configuration aware. */ @SuppressWarnings("overloads") - protected static FunctionDefinition def( + public static FunctionDefinition def( Class function, UnaryConfigurationAwareBuilder ctorRef, String... names @@ -405,7 +405,7 @@ protected static FunctionDefinition def( return def(function, builder, names); } - protected interface UnaryConfigurationAwareBuilder { + public interface UnaryConfigurationAwareBuilder { T build(Source source, Expression exp, Configuration configuration); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionTypeRegistry.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionTypeRegistry.java deleted file mode 100644 index 8ba40d5b167ff..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionTypeRegistry.java +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
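
Widening def(...) and UnaryConfigurationAwareBuilder from protected to public lets function registries outside this package register configuration-aware functions. A hypothetical registration (MyFunc is illustrative, not a real class) would read:

    // One-argument, configuration-aware function definition.
    FunctionDefinition definition = FunctionRegistry.def(
        MyFunc.class,
        (source, expression, configuration) -> new MyFunc(source, expression, configuration),
        "my_func"
    );
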
- */ - -package org.elasticsearch.xpack.esql.core.expression.function; - -public interface FunctionTypeRegistry { - - String type(Class clazz); -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/BaseSurrogateFunction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/BaseSurrogateFunction.java deleted file mode 100644 index efbcc4f869620..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/BaseSurrogateFunction.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.core.expression.function.scalar; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.Source; - -import java.util.List; - -public abstract class BaseSurrogateFunction extends ScalarFunction implements SurrogateFunction { - - private ScalarFunction lazySubstitute; - - public BaseSurrogateFunction(Source source) { - super(source); - } - - public BaseSurrogateFunction(Source source, List fields) { - super(source, fields); - } - - @Override - public ScalarFunction substitute() { - if (lazySubstitute == null) { - lazySubstitute = makeSubstitute(); - } - return lazySubstitute; - } - - protected abstract ScalarFunction makeSubstitute(); - - @Override - public boolean foldable() { - return substitute().foldable(); - } - - @Override - public Object fold() { - return substitute().fold(); - } - -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/ConfigurationFunction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/ConfigurationFunction.java deleted file mode 100644 index fe2e527b57417..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/ConfigurationFunction.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.esql.core.expression.function.scalar; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.session.Configuration; -import org.elasticsearch.xpack.esql.core.tree.Source; - -import java.util.List; - -public abstract class ConfigurationFunction extends ScalarFunction { - - private final Configuration configuration; - - protected ConfigurationFunction(Source source, List fields, Configuration configuration) { - super(source, fields); - this.configuration = configuration; - } - - public Configuration configuration() { - return configuration; - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/IntervalScripting.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/IntervalScripting.java deleted file mode 100644 index 121696f1df4f9..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/IntervalScripting.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.core.expression.function.scalar; - -// FIXME: accessor interface until making script generation pluggable -public interface IntervalScripting { - - String script(); - - String value(); - - String typeName(); - -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java index e5c2cedfd087b..1efda1e54185b 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java @@ -33,7 +33,7 @@ protected UnaryScalarFunction(StreamInput in) throws IOException { } @Override - public final void writeTo(StreamOutput out) throws IOException { + public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); ((PlanStreamOutput) out).writeExpression(field); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/BinaryComparisonCaseInsensitiveFunction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/BinaryComparisonCaseInsensitiveFunction.java deleted file mode 100644 index 4739fe910b769..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/BinaryComparisonCaseInsensitiveFunction.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.esql.core.expression.function.scalar.string; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; - -import java.util.Objects; - -import static java.util.Arrays.asList; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; -import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isStringAndExact; - -public abstract class BinaryComparisonCaseInsensitiveFunction extends CaseInsensitiveScalarFunction { - - private final Expression left, right; - - protected BinaryComparisonCaseInsensitiveFunction(Source source, Expression left, Expression right, boolean caseInsensitive) { - super(source, asList(left, right), caseInsensitive); - this.left = left; - this.right = right; - } - - @Override - protected TypeResolution resolveType() { - if (childrenResolved() == false) { - return new TypeResolution("Unresolved children"); - } - - TypeResolution sourceResolution = isStringAndExact(left, sourceText(), FIRST); - if (sourceResolution.unresolved()) { - return sourceResolution; - } - - return isStringAndExact(right, sourceText(), SECOND); - } - - public Expression left() { - return left; - } - - public Expression right() { - return right; - } - - @Override - public DataType dataType() { - return DataType.BOOLEAN; - } - - @Override - public boolean foldable() { - return left.foldable() && right.foldable(); - } - - @Override - public int hashCode() { - return Objects.hash(left, right, isCaseInsensitive()); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - BinaryComparisonCaseInsensitiveFunction other = (BinaryComparisonCaseInsensitiveFunction) obj; - return Objects.equals(left, other.left) - && Objects.equals(right, other.right) - && Objects.equals(isCaseInsensitive(), other.isCaseInsensitive()); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/CaseInsensitiveScalarFunction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/CaseInsensitiveScalarFunction.java deleted file mode 100644 index bd3b1aed73390..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/CaseInsensitiveScalarFunction.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.esql.core.expression.function.scalar.string; - -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; -import org.elasticsearch.xpack.esql.core.tree.Source; - -import java.util.List; -import java.util.Objects; - -public abstract class CaseInsensitiveScalarFunction extends ScalarFunction { - - private final boolean caseInsensitive; - - protected CaseInsensitiveScalarFunction(Source source, List fields, boolean caseInsensitive) { - super(source, fields); - this.caseInsensitive = caseInsensitive; - } - - public boolean isCaseInsensitive() { - return caseInsensitive; - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), isCaseInsensitive()); - } - - @Override - public boolean equals(Object other) { - return super.equals(other) && Objects.equals(((CaseInsensitiveScalarFunction) other).caseInsensitive, caseInsensitive); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/StartsWithFunctionProcessor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/StartsWithFunctionProcessor.java deleted file mode 100644 index 8172971fc39f0..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/string/StartsWithFunctionProcessor.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ -package org.elasticsearch.xpack.esql.core.expression.function.scalar.string; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; -import org.elasticsearch.xpack.esql.core.expression.gen.processor.Processor; - -import java.io.IOException; -import java.util.Locale; -import java.util.Objects; - -public class StartsWithFunctionProcessor implements Processor { - - public static final String NAME = "sstw"; - - private final Processor source; - private final Processor pattern; - private final boolean caseInsensitive; - - public StartsWithFunctionProcessor(Processor source, Processor pattern, boolean caseInsensitive) { - this.source = source; - this.pattern = pattern; - this.caseInsensitive = caseInsensitive; - } - - public StartsWithFunctionProcessor(StreamInput in) throws IOException { - source = in.readNamedWriteable(Processor.class); - pattern = in.readNamedWriteable(Processor.class); - caseInsensitive = in.readBoolean(); - } - - @Override - public final void writeTo(StreamOutput out) throws IOException { - out.writeNamedWriteable(source); - out.writeNamedWriteable(pattern); - out.writeBoolean(caseInsensitive); - } - - @Override - public Object process(Object input) { - return doProcess(source.process(input), pattern.process(input), isCaseInsensitive()); - } - - public static Object doProcess(Object source, Object pattern, boolean caseInsensitive) { - if (source == null) { - return null; - } - if (source instanceof String == false && source instanceof Character == false) { - throw new QlIllegalArgumentException("A string/char is required; received [{}]", source); - } - if (pattern == null) { - return null; - } - if (pattern instanceof String == false && pattern instanceof Character == false) { - throw new QlIllegalArgumentException("A string/char is required; received [{}]", pattern); - } - - if (caseInsensitive == false) { - return source.toString().startsWith(pattern.toString()); - } else { - return source.toString().toLowerCase(Locale.ROOT).startsWith(pattern.toString().toLowerCase(Locale.ROOT)); - } - } - - protected Processor source() { - return source; - } - - protected Processor pattern() { - return pattern; - } - - protected boolean isCaseInsensitive() { - return caseInsensitive; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - StartsWithFunctionProcessor other = (StartsWithFunctionProcessor) obj; - return Objects.equals(source(), other.source()) - && Objects.equals(pattern(), other.pattern()) - && Objects.equals(isCaseInsensitive(), other.isCaseInsensitive()); - } - - @Override - public int hashCode() { - return Objects.hash(source(), pattern(), isCaseInsensitive()); - } - - @Override - public String getWriteableName() { - return NAME; - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/whitelist/InternalQlScriptUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/whitelist/InternalQlScriptUtils.java deleted file mode 100644 index e361d2465a1c5..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/whitelist/InternalQlScriptUtils.java +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.core.expression.function.scalar.whitelist; - -import org.elasticsearch.index.fielddata.ScriptDocValues; -import org.elasticsearch.xpack.esql.core.expression.function.scalar.string.StartsWithFunctionProcessor; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogicProcessor.BinaryLogicOperation; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.NotProcessor; -import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.CheckNullProcessor.CheckNullOperation; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.UnaryArithmeticProcessor.UnaryArithmeticOperation; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparisonProcessor.BinaryComparisonOperation; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.InProcessor; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RegexProcessor.RegexOperation; -import org.elasticsearch.xpack.esql.core.util.StringUtils; - -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.xpack.esql.core.type.DataType.fromTypeName; -import static org.elasticsearch.xpack.esql.core.type.DataTypeConverter.convert; -import static org.elasticsearch.xpack.esql.core.type.DataTypeConverter.toUnsignedLong; - -public class InternalQlScriptUtils { - - // - // Utilities - // - - // safe missing mapping/value extractor - public static Object docValue(Map<String, ScriptDocValues<?>> doc, String fieldName) { - if (doc.containsKey(fieldName)) { - ScriptDocValues<?> docValues = doc.get(fieldName); - if (docValues.isEmpty() == false) { - return docValues.get(0); - } - } - return null; - } - - public static boolean nullSafeFilter(Boolean filter) { - return filter == null ? false : filter.booleanValue(); - } - - public static double nullSafeSortNumeric(Number sort) { - return sort == null ? 0.0d : sort.doubleValue(); - } - - public static String nullSafeSortString(Object sort) { - return sort == null ? StringUtils.EMPTY : sort.toString(); - } - - public static Number nullSafeCastNumeric(Number number, String typeName) { - return number == null || Double.isNaN(number.doubleValue()) ? null : (Number) convert(number, fromTypeName(typeName)); - } - - public static Number nullSafeCastToUnsignedLong(Number number) { - return number == null || Double.isNaN(number.doubleValue()) ?
null : toUnsignedLong(number); - } - - // - // Operators - // - - // - // Logical - // - public static Boolean eq(Object left, Object right) { - return BinaryComparisonOperation.EQ.apply(left, right); - } - - public static Boolean nulleq(Object left, Object right) { - return BinaryComparisonOperation.NULLEQ.apply(left, right); - } - - public static Boolean neq(Object left, Object right) { - return BinaryComparisonOperation.NEQ.apply(left, right); - } - - public static Boolean lt(Object left, Object right) { - return BinaryComparisonOperation.LT.apply(left, right); - } - - public static Boolean lte(Object left, Object right) { - return BinaryComparisonOperation.LTE.apply(left, right); - } - - public static Boolean gt(Object left, Object right) { - return BinaryComparisonOperation.GT.apply(left, right); - } - - public static Boolean gte(Object left, Object right) { - return BinaryComparisonOperation.GTE.apply(left, right); - } - - public static Boolean in(Object value, List<Object> values) { - return InProcessor.apply(value, values); - } - - public static Boolean and(Boolean left, Boolean right) { - return BinaryLogicOperation.AND.apply(left, right); - } - - public static Boolean or(Boolean left, Boolean right) { - return BinaryLogicOperation.OR.apply(left, right); - } - - public static Boolean not(Boolean expression) { - return NotProcessor.apply(expression); - } - - public static Boolean isNull(Object expression) { - return CheckNullOperation.IS_NULL.test(expression); - } - - public static Boolean isNotNull(Object expression) { - return CheckNullOperation.IS_NOT_NULL.test(expression); - } - - // - // Regex - // - public static Boolean regex(String value, String pattern) { - return regex(value, pattern, Boolean.FALSE); - } - - public static Boolean regex(String value, String pattern, Boolean caseInsensitive) { - // TODO: this needs to be improved to avoid creating the pattern on every call - return RegexOperation.match(value, pattern, caseInsensitive); - } - - // - // Math - // - public static Number add(Number left, Number right) { - return (Number) DefaultBinaryArithmeticOperation.ADD.apply(left, right); - } - - public static Number div(Number left, Number right) { - return (Number) DefaultBinaryArithmeticOperation.DIV.apply(left, right); - } - - public static Number mod(Number left, Number right) { - return (Number) DefaultBinaryArithmeticOperation.MOD.apply(left, right); - } - - public static Number mul(Number left, Number right) { - return (Number) DefaultBinaryArithmeticOperation.MUL.apply(left, right); - } - - public static Number neg(Number value) { - return UnaryArithmeticOperation.NEGATE.apply(value); - } - - public static Number sub(Number left, Number right) { - return (Number) DefaultBinaryArithmeticOperation.SUB.apply(left, right); - } - - // - // String - // - public static Boolean startsWith(String s, String pattern, Boolean caseInsensitive) { - return (Boolean) StartsWithFunctionProcessor.doProcess(s, pattern, caseInsensitive); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/gen/processor/BucketExtractorProcessor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/gen/processor/BucketExtractorProcessor.java deleted file mode 100644 index afd4efc0e88e7..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/gen/processor/BucketExtractorProcessor.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright Elasticsearch B.V.
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression.gen.processor; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket; -import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; -import org.elasticsearch.xpack.esql.core.execution.search.extractor.BucketExtractor; - -import java.io.IOException; -import java.util.Objects; - -/** - * Processor wrapping an {@link BucketExtractor}, essentially being a source/leaf of a - * Processor tree. - */ -public class BucketExtractorProcessor implements Processor { - - public static final String NAME = "a"; - - private final BucketExtractor extractor; - - public BucketExtractorProcessor(BucketExtractor extractor) { - this.extractor = extractor; - } - - public BucketExtractorProcessor(StreamInput in) throws IOException { - extractor = in.readNamedWriteable(BucketExtractor.class); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeNamedWriteable(extractor); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public Object process(Object input) { - if ((input instanceof Bucket) == false) { - throw new QlIllegalArgumentException("Expected an agg bucket but received {}", input); - } - return extractor.extract((Bucket) input); - } - - @Override - public int hashCode() { - return Objects.hash(extractor); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - BucketExtractorProcessor other = (BucketExtractorProcessor) obj; - return Objects.equals(extractor, other.extractor); - } - - @Override - public String toString() { - return extractor.toString(); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/gen/processor/HitExtractorProcessor.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/gen/processor/HitExtractorProcessor.java deleted file mode 100644 index 1662a8192acf9..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/gen/processor/HitExtractorProcessor.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.expression.gen.processor; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; -import org.elasticsearch.xpack.esql.core.execution.search.extractor.HitExtractor; - -import java.io.IOException; -import java.util.Objects; - -/** - * Processor wrapping a {@link HitExtractor}, essentially being a source/leaf of a - * Processor tree. 
- */ -public class HitExtractorProcessor implements Processor { - - public static final String NAME = "h"; - - private final HitExtractor extractor; - - public HitExtractorProcessor(HitExtractor extractor) { - this.extractor = extractor; - } - - public HitExtractorProcessor(StreamInput in) throws IOException { - extractor = in.readNamedWriteable(HitExtractor.class); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeNamedWriteable(extractor); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public Object process(Object input) { - if ((input instanceof SearchHit) == false) { - throw new QlIllegalArgumentException("Expected a SearchHit but received {}", input); - } - return extractor.extract((SearchHit) input); - } - - @Override - public int hashCode() { - return Objects.hash(extractor); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - - if (obj == null || getClass() != obj.getClass()) { - return false; - } - - HitExtractorProcessor other = (HitExtractorProcessor) obj; - return Objects.equals(extractor, other.extractor); - } - - @Override - public String toString() { - return extractor.toString(); - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/And.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/And.java index 81418aa78ce57..e5ab86605657d 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/And.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/And.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.core.expression.predicate.logical; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; @@ -13,12 +15,24 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import java.io.IOException; + public class And extends BinaryLogic implements Negatable<BinaryLogic> { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "And", And::new); public And(Source source, Expression left, Expression right) { super(source, left, right, BinaryLogicOperation.AND); } + private And(StreamInput in) throws IOException { + super(in, BinaryLogicOperation.AND); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo<And> info() { return NodeInfo.create(this, And::new, left(), right()); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/BinaryLogic.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/BinaryLogic.java index 39de0e0643c13..b52cd728dd773 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/BinaryLogic.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/BinaryLogic.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.core.expression.predicate.logical;
+import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal; @@ -13,6 +15,10 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogicProcessor.BinaryLogicOperation; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; + +import java.io.IOException; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isBoolean; @@ -22,6 +28,22 @@ protected BinaryLogic(Source source, Expression left, Expression right, BinaryLo super(source, left, right, operation); } + protected BinaryLogic(StreamInput in, BinaryLogicOperation op) throws IOException { + this( + Source.readFrom((StreamInput & PlanStreamInput) in), + ((StreamInput & PlanStreamInput) in).readExpression(), + ((StreamInput & PlanStreamInput) in).readExpression(), + op + ); + } + + @Override + public final void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + ((StreamOutput & PlanStreamOutput) out).writeExpression(left()); + ((StreamOutput & PlanStreamOutput) out).writeExpression(right()); + } + @Override public DataType dataType() { return DataType.BOOLEAN; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Or.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Or.java index 16781426d2323..b3afb662a009d 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Or.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Or.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.core.expression.predicate.logical; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; @@ -13,12 +15,24 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import java.io.IOException; + public class Or extends BinaryLogic implements Negatable<BinaryLogic> { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Or", Or::new); public Or(Source source, Expression left, Expression right) { super(source, left, right, BinaryLogicOperation.OR); } + private Or(StreamInput in) throws IOException { + super(in, BinaryLogicOperation.OR); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo<Or> info() { return NodeInfo.create(this, Or::new, left(), right()); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/processor/Processors.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/processor/Processors.java index f72fdb7e43fb6..e47b80ee0ab59 100644 ---
a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/processor/Processors.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/processor/Processors.java @@ -8,10 +8,8 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry.Entry; -import org.elasticsearch.xpack.esql.core.expression.gen.processor.BucketExtractorProcessor; import org.elasticsearch.xpack.esql.core.expression.gen.processor.ChainingProcessor; import org.elasticsearch.xpack.esql.core.expression.gen.processor.ConstantProcessor; -import org.elasticsearch.xpack.esql.core.expression.gen.processor.HitExtractorProcessor; import org.elasticsearch.xpack.esql.core.expression.gen.processor.Processor; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogicProcessor; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.NotProcessor; @@ -42,8 +40,6 @@ public static List getNamedWriteables() { entries.add(new Entry(Converter.class, DefaultConverter.NAME, DefaultConverter::read)); entries.add(new Entry(Processor.class, ConstantProcessor.NAME, ConstantProcessor::new)); - entries.add(new Entry(Processor.class, HitExtractorProcessor.NAME, HitExtractorProcessor::new)); - entries.add(new Entry(Processor.class, BucketExtractorProcessor.NAME, BucketExtractorProcessor::new)); entries.add(new Entry(Processor.class, ChainingProcessor.NAME, ChainingProcessor::new)); // logical diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java index b0cdaa7ff0021..a8ef1ea689878 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java @@ -14,6 +14,7 @@ import java.io.IOException; import java.math.BigInteger; import java.time.ZonedDateTime; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Comparator; @@ -21,15 +22,14 @@ import java.util.Map; import java.util.Set; import java.util.function.Function; -import java.util.stream.Stream; import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toUnmodifiableMap; public enum DataType { - UNSUPPORTED(builder(null).typeName("UNSUPPORTED")), - NULL(builder("null")), - BOOLEAN(builder("boolean").size(1)), + UNSUPPORTED(builder().typeName("UNSUPPORTED")), + NULL(builder().esType("null")), + BOOLEAN(builder().esType("boolean").size(1)), /** * These are numeric fields labeled as metric counters in time-series indices. Although stored @@ -38,37 +38,37 @@ public enum DataType { * These fields are strictly for use in retrieval from indices, rate aggregation, and casting to their * parent numeric type. 
*/ - COUNTER_LONG(builder("counter_long").size(Long.BYTES).docValues().counter()), - COUNTER_INTEGER(builder("counter_integer").size(Integer.BYTES).docValues().counter()), - COUNTER_DOUBLE(builder("counter_double").size(Double.BYTES).docValues().counter()), - - LONG(builder("long").size(Long.BYTES).integer().docValues().counter(COUNTER_LONG)), - INTEGER(builder("integer").size(Integer.BYTES).integer().docValues().counter(COUNTER_INTEGER)), - SHORT(builder("short").size(Short.BYTES).integer().docValues().widenSmallNumeric(INTEGER)), - BYTE(builder("byte").size(Byte.BYTES).integer().docValues().widenSmallNumeric(INTEGER)), - UNSIGNED_LONG(builder("unsigned_long").size(Long.BYTES).integer().docValues()), - DOUBLE(builder("double").size(Double.BYTES).rational().docValues().counter(COUNTER_DOUBLE)), - FLOAT(builder("float").size(Float.BYTES).rational().docValues().widenSmallNumeric(DOUBLE)), - HALF_FLOAT(builder("half_float").size(Float.BYTES).rational().docValues().widenSmallNumeric(DOUBLE)), - SCALED_FLOAT(builder("scaled_float").size(Long.BYTES).rational().docValues().widenSmallNumeric(DOUBLE)), - - KEYWORD(builder("keyword").unknownSize().docValues()), - TEXT(builder("text").unknownSize()), - DATETIME(builder("date").typeName("DATETIME").size(Long.BYTES).docValues()), - IP(builder("ip").size(45).docValues()), - VERSION(builder("version").unknownSize().docValues()), - OBJECT(builder("object")), - NESTED(builder("nested")), - SOURCE(builder(SourceFieldMapper.NAME).unknownSize()), - DATE_PERIOD(builder(null).typeName("DATE_PERIOD").size(3 * Integer.BYTES)), - TIME_DURATION(builder(null).typeName("TIME_DURATION").size(Integer.BYTES + Long.BYTES)), - GEO_POINT(builder("geo_point").size(Double.BYTES * 2).docValues()), - CARTESIAN_POINT(builder("cartesian_point").size(Double.BYTES * 2).docValues()), - CARTESIAN_SHAPE(builder("cartesian_shape").unknownSize().docValues()), - GEO_SHAPE(builder("geo_shape").unknownSize().docValues()), - - DOC_DATA_TYPE(builder("_doc").size(Integer.BYTES * 3)), - TSID_DATA_TYPE(builder("_tsid").unknownSize().docValues()); + COUNTER_LONG(builder().esType("counter_long").size(Long.BYTES).docValues().counter()), + COUNTER_INTEGER(builder().esType("counter_integer").size(Integer.BYTES).docValues().counter()), + COUNTER_DOUBLE(builder().esType("counter_double").size(Double.BYTES).docValues().counter()), + + LONG(builder().esType("long").size(Long.BYTES).integer().docValues().counter(COUNTER_LONG)), + INTEGER(builder().esType("integer").size(Integer.BYTES).integer().docValues().counter(COUNTER_INTEGER)), + SHORT(builder().esType("short").size(Short.BYTES).integer().docValues().widenSmallNumeric(INTEGER)), + BYTE(builder().esType("byte").size(Byte.BYTES).integer().docValues().widenSmallNumeric(INTEGER)), + UNSIGNED_LONG(builder().esType("unsigned_long").size(Long.BYTES).integer().docValues()), + DOUBLE(builder().esType("double").size(Double.BYTES).rational().docValues().counter(COUNTER_DOUBLE)), + FLOAT(builder().esType("float").size(Float.BYTES).rational().docValues().widenSmallNumeric(DOUBLE)), + HALF_FLOAT(builder().esType("half_float").size(Float.BYTES).rational().docValues().widenSmallNumeric(DOUBLE)), + SCALED_FLOAT(builder().esType("scaled_float").size(Long.BYTES).rational().docValues().widenSmallNumeric(DOUBLE)), + + KEYWORD(builder().esType("keyword").unknownSize().docValues()), + TEXT(builder().esType("text").unknownSize()), + DATETIME(builder().esType("date").typeName("DATETIME").size(Long.BYTES).docValues()), + IP(builder().esType("ip").size(45).docValues()), + 
VERSION(builder().esType("version").unknownSize().docValues()), + OBJECT(builder().esType("object")), + NESTED(builder().esType("nested")), + SOURCE(builder().esType(SourceFieldMapper.NAME).unknownSize()), + DATE_PERIOD(builder().typeName("DATE_PERIOD").size(3 * Integer.BYTES)), + TIME_DURATION(builder().typeName("TIME_DURATION").size(Integer.BYTES + Long.BYTES)), + GEO_POINT(builder().esType("geo_point").size(Double.BYTES * 2).docValues()), + CARTESIAN_POINT(builder().esType("cartesian_point").size(Double.BYTES * 2).docValues()), + CARTESIAN_SHAPE(builder().esType("cartesian_shape").unknownSize().docValues()), + GEO_SHAPE(builder().esType("geo_shape").unknownSize().docValues()), + + DOC_DATA_TYPE(builder().esType("_doc").size(Integer.BYTES * 3)), + TSID_DATA_TYPE(builder().esType("_tsid").unknownSize().docValues()); private final String typeName; @@ -124,37 +124,10 @@ public enum DataType { this.counter = builder.counter; } - private static final Collection<DataType> TYPES = Stream.of( - UNSUPPORTED, - NULL, - BOOLEAN, - BYTE, - SHORT, - INTEGER, - LONG, - UNSIGNED_LONG, - DOUBLE, - FLOAT, - HALF_FLOAT, - SCALED_FLOAT, - KEYWORD, - TEXT, - DATETIME, - IP, - VERSION, - OBJECT, - NESTED, - SOURCE, - DATE_PERIOD, - TIME_DURATION, - GEO_POINT, - CARTESIAN_POINT, - CARTESIAN_SHAPE, - GEO_SHAPE, - COUNTER_LONG, - COUNTER_INTEGER, - COUNTER_DOUBLE - ).sorted(Comparator.comparing(DataType::typeName)).toList(); + private static final Collection<DataType> TYPES = Arrays.stream(values()) + .filter(d -> d != DOC_DATA_TYPE && d != TSID_DATA_TYPE) + .sorted(Comparator.comparing(DataType::typeName)) + .toList(); private static final Map<String, DataType> NAME_TO_TYPE = TYPES.stream().collect(toUnmodifiableMap(DataType::typeName, t -> t)); @@ -162,7 +135,10 @@ public enum DataType { static { Map<String, DataType> map = TYPES.stream().filter(e -> e.esType() != null).collect(toMap(DataType::esType, t -> t)); - map.put("date_nanos", DATETIME); + // TODO: Why don't we use the names ES uses as the esType field for these? + // ES calls this 'point', but ESQL calls it 'cartesian_point' + map.put("point", DataType.CARTESIAN_POINT); + map.put("shape", DataType.CARTESIAN_SHAPE); ES_TO_TYPE = Collections.unmodifiableMap(map); } @@ -277,6 +253,10 @@ public String esType() { return esType; } + public String outputType() { + return esType == null ? "unsupported" : esType; + } + public boolean isInteger() { return isInteger; } @@ -346,8 +326,8 @@ public static DataType fromNameOrAlias(String typeName) { return type != null ? type : UNSUPPORTED; } - static Builder builder(String esType) { - return new Builder(esType); + static Builder builder() { + return new Builder(); } /** @@ -355,7 +335,7 @@ static Builder builder(String esType) { * a builder in java....
*/ private static class Builder { - private final String esType; + private String esType; private String typeName; @@ -393,8 +373,11 @@ private static class Builder { */ private DataType counter; - Builder(String esType) { + Builder() {} + + Builder esType(String esType) { this.esType = esType; + return this; } Builder typeName(String typeName) { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java index bdc60ebab55ef..4ef20a724ab3c 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java @@ -55,7 +55,7 @@ public EsField(String name, DataType esDataType, Map properties public EsField(StreamInput in) throws IOException { this.name = in.readString(); this.esDataType = DataType.readFrom(in); - this.properties = in.readImmutableMap(StreamInput::readString, i -> i.readNamedWriteable(EsField.class)); + this.properties = in.readImmutableMap(i -> i.readNamedWriteable(EsField.class)); this.aggregatable = in.readBoolean(); this.isAlias = in.readBoolean(); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/Schema.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/Schema.java deleted file mode 100644 index fa7c1d7e1e3e6..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/Schema.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ -package org.elasticsearch.xpack.esql.core.type; - -import org.elasticsearch.xpack.esql.core.util.Check; - -import java.util.Iterator; -import java.util.List; -import java.util.NoSuchElementException; -import java.util.Spliterator; -import java.util.Spliterators; -import java.util.stream.Stream; -import java.util.stream.StreamSupport; - -import static java.util.Collections.emptyList; - -public class Schema implements Iterable<Schema.Entry> { - - public interface Entry { - String name(); - - DataType type(); - } - - static class DefaultEntry implements Entry { - private final String name; - private final DataType type; - - DefaultEntry(String name, DataType type) { - this.name = name; - this.type = type; - } - - @Override - public String name() { - return name; - } - - @Override - public DataType type() { - return type; - } - } - - public static final Schema EMPTY = new Schema(emptyList(), emptyList()); - - private final List<String> names; - private final List<DataType> types; - - public Schema(List<String> names, List<DataType> types) { - Check.isTrue(names.size() == types.size(), "Different # of names {} vs types {}", names, types); - this.types = types; - this.names = names; - } - - public List<String> names() { - return names; - } - - public List<DataType> types() { - return types; - } - - public int size() { - return names.size(); - } - - public Entry get(int i) { - return new DefaultEntry(names.get(i), types.get(i)); - } - - public DataType type(String name) { - int indexOf = names.indexOf(name); - if (indexOf < 0) { - return null; - } - return types.get(indexOf); - } - - @Override - public Iterator<Entry> iterator() { - return new Iterator<>() { - private final int size = size(); - private int pos = -1; - - @Override - public boolean hasNext() { - return pos < size - 1; - } - - @Override - public Entry next() { - if (pos++ >= size) { - throw new NoSuchElementException(); - } - return get(pos); - } - }; - } - - public Stream<Entry> stream() { - return StreamSupport.stream(spliterator(), false); - } - - @Override - public Spliterator<Entry> spliterator() { - return Spliterators.spliterator(iterator(), size(), 0); - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("["); - for (int i = 0; i < names.size(); i++) { - if (i > 0) { - sb.append(","); - } - sb.append(names.get(i)); - sb.append(":"); - sb.append(types.get(i).typeName()); - } - sb.append("]"); - return sb.toString(); - } -} diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ConstantExtractorTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ConstantExtractorTests.java deleted file mode 100644 index a7b55ba38be12..0000000000000 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/execution/search/extractor/ConstantExtractorTests.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0.
- */ -package org.elasticsearch.xpack.esql.core.execution.search.extractor; - -import org.elasticsearch.common.io.stream.Writeable.Reader; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.test.AbstractWireSerializingTestCase; - -import java.util.function.Supplier; - -public class ConstantExtractorTests extends AbstractWireSerializingTestCase<ConstantExtractor> { - public static ConstantExtractor randomConstantExtractor() { - return new ConstantExtractor(randomValidConstant()); - } - - private static Object randomValidConstant() { - @SuppressWarnings("unchecked") - Supplier<Object> valueSupplier = randomFrom(() -> randomInt(), () -> randomDouble(), () -> randomAlphaOfLengthBetween(1, 140)); - return valueSupplier.get(); - } - - @Override - protected ConstantExtractor createTestInstance() { - return randomConstantExtractor(); - } - - @Override - protected Reader<ConstantExtractor> instanceReader() { - return ConstantExtractor::new; - } - - @Override - protected ConstantExtractor mutateInstance(ConstantExtractor instance) { - return new ConstantExtractor(instance.extract((SearchHit) null) + "mutated"); - } - - public void testGet() { - Object expected = randomValidConstant(); - int times = between(1, 1000); - for (int i = 0; i < times; i++) { - assertSame(expected, new ConstantExtractor(expected).extract((SearchHit) null)); - } - } - - public void testToString() { - assertEquals("^foo", new ConstantExtractor("foo").toString()); - assertEquals("^42", new ConstantExtractor("42").toString()); - } -} diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistryTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistryTests.java index 47f5befcf325e..8691b5e9153fb 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistryTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/function/FunctionRegistryTests.java @@ -10,9 +10,7 @@ import org.elasticsearch.xpack.esql.core.ParsingException; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.function.scalar.ConfigurationFunction; import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; -import org.elasticsearch.xpack.esql.core.session.Configuration; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.tree.SourceTests; @@ -166,18 +164,6 @@ public void testFunctionResolving() { assertThat(e.getMessage(), is("Cannot find function dummyfunction; this should have been caught during analysis")); } - public void testConfigurationOptionalFunction() { - UnresolvedFunction ur = uf(DEFAULT, mock(Expression.class)); - FunctionRegistry r = new FunctionRegistry( - def(DummyConfigurationOptionalArgumentFunction.class, (Source l, Expression e, Configuration c) -> { - assertSame(e, ur.children().get(0)); - return new DummyConfigurationOptionalArgumentFunction(l, List.of(ur), c); - }, "dummy") - ); - FunctionDefinition def = r.resolveFunction(r.resolveAlias("DUMMY")); - assertEquals(ur.source(), ur.buildResolved(randomConfiguration(), def).source()); - } - public static UnresolvedFunction uf(FunctionResolutionStrategy resolutionStrategy, Expression...
children) { return new UnresolvedFunction(SourceTests.randomSource(), "dummy_function", resolutionStrategy, Arrays.asList(children)); } @@ -208,26 +194,4 @@ public DummyFunction2(Source source) { super(source); } } - - public static class DummyConfigurationOptionalArgumentFunction extends ConfigurationFunction implements OptionalArgument { - - public DummyConfigurationOptionalArgumentFunction(Source source, List<Expression> fields, Configuration configuration) { - super(source, fields, configuration); - } - - @Override - public DataType dataType() { - return null; - } - - @Override - public Expression replaceChildren(List<Expression> newChildren) { - return new DummyConfigurationOptionalArgumentFunction(source(), newChildren, configuration()); - } - - @Override - protected NodeInfo<DummyConfigurationOptionalArgumentFunction> info() { - return NodeInfo.create(this, DummyConfigurationOptionalArgumentFunction::new, children(), configuration()); - } - } } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TypesTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TypesTests.java index 489666976b592..1974eb3669f4b 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TypesTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TypesTests.java @@ -103,16 +103,6 @@ public void testDateMulti() { assertThat(field, is(instanceOf(DateEsField.class))); } - public void testDateNanosField() { - Map<String, EsField> mapping = loadMapping("mapping-date_nanos.json"); - - assertThat(mapping.size(), is(1)); - EsField field = mapping.get("date_nanos"); - assertThat(field.getDataType(), is(DATETIME)); - assertThat(field.isAggregatable(), is(true)); - assertThat(field, is(instanceOf(DateEsField.class))); - } - public void testDocValueField() { Map<String, EsField> mapping = loadMapping("mapping-docvalues.json"); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec index c0572e7bbcd49..adbf24cee10b0 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/boolean.csv-spec @@ -67,10 +67,9 @@ required_capability: mv_warn from employees | keep emp_no, is_rehired, still_hired | where is_rehired in (still_hired, true) | where is_rehired != still_hired; ignoreOrder:true -warning:Line 1:63: evaluation of [is_rehired in (still_hired, true)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:63: java.lang.IllegalArgumentException: single-value function encountered multi-value -warning:Line 1:105: evaluation of [is_rehired != still_hired] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:105: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[is_rehired in \(still_hired, true\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[is_rehired != still_hired\] failed, treating result as null. Only first 20 failures recorded.
emp_no:integer |is_rehired:boolean |still_hired:boolean 10021 |true |false 10029 |true |false diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec index 225ea37688689..f8a49c3a59f98 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dissect.csv-spec @@ -14,6 +14,43 @@ foo bar | null | null ; +shadowing +FROM employees +| KEEP first_name, last_name +| WHERE last_name == "Facello" +| EVAL left = "left", full_name = concat(first_name, " ", last_name) , last_name = "last_name", right = "right" +| DISSECT full_name "%{?} %{last_name}" +; + +first_name:keyword | left:keyword | full_name:keyword | right:keyword | last_name:keyword +Georgi | left | Georgi Facello | right | Facello +; + +shadowingSelf +FROM employees +| KEEP first_name, last_name +| WHERE last_name == "Facello" +| EVAL left = "left", name = concat(first_name, "1 ", last_name), right = "right" +| DISSECT name "%{name} %{?}" +; + +first_name:keyword | last_name:keyword | left:keyword | right:keyword | name:keyword +Georgi | Facello | left | right | Georgi1 +; + +shadowingMulti +FROM employees +| KEEP first_name, last_name +| WHERE last_name == "Facello" +| EVAL left = "left", foo = concat(first_name, "1 ", first_name, "2 ", last_name) , middle = "middle", bar = "bar", right = "right" +| DISSECT foo "%{bar} %{first_name} %{last_name_again}" +; + +last_name:keyword | left:keyword | foo:keyword | middle:keyword | right:keyword | bar:keyword | first_name:keyword | last_name_again:keyword +Facello | left | Georgi1 Georgi2 Facello | middle | right | Georgi1 | Georgi2 | Facello +; + + complexPattern ROW a = "1953-01-23T12:15:00Z - some text - 127.0.0.1;" | DISSECT a "%{Y}-%{M}-%{D}T%{h}:%{m}:%{s}Z - %{msg} - %{ip};" diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec index bd384886f0dd7..bc79d1c62bd67 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/enrich.csv-spec @@ -31,6 +31,82 @@ FROM sample_data median_duration:double | env:keyword ; +shadowing +required_capability: enrich_load +ROW left = "left", client_ip = "172.21.0.5", env = "env", right = "right" +| ENRICH clientip_policy ON client_ip +; + +left:keyword | client_ip:keyword | right:keyword | env:keyword +left | 172.21.0.5 | right | Development +; + +shadowingLimit0 +ROW left = "left", client_ip = "172.21.0.5", env = "env", right = "right" +| ENRICH clientip_policy ON client_ip +| LIMIT 0 +; + +left:keyword | client_ip:keyword | right:keyword | env:keyword +; + +shadowingWithAlias +required_capability: enrich_load +ROW left = "left", foo = "foo", client_ip = "172.21.0.5", env = "env", right = "right" +| ENRICH clientip_policy ON client_ip WITH foo = env +; + +left:keyword | client_ip:keyword | env:keyword | right:keyword | foo:keyword +left | 172.21.0.5 | env | right | Development +; + +shadowingWithAliasLimit0 +ROW left = "left", foo = "foo", client_ip = "172.21.0.5", env = "env", right = "right" +| ENRICH clientip_policy ON client_ip WITH foo = env +| LIMIT 0 +; + +left:keyword | client_ip:keyword | env:keyword | right:keyword | foo:keyword +; + +shadowingSelf +required_capability: enrich_load +ROW left = "left", client_ip = "172.21.0.5", env = "env", right = "right" +| ENRICH clientip_policy 
ON client_ip WITH client_ip = env +; + +left:keyword | env:keyword | right:keyword | client_ip:keyword +left | env | right | Development +; + +shadowingSelfLimit0 +ROW left = "left", client_ip = "172.21.0.5", env = "env", right = "right" +| ENRICH clientip_policy ON client_ip WITH client_ip = env +| LIMIT 0 +; + +left:keyword | env:keyword | right:keyword | client_ip:keyword +; + +shadowingMulti +required_capability: enrich_load +ROW left = "left", airport = "Zurich Airport ZRH", city = "Zürich", middle = "middle", region = "North-East Switzerland", right = "right" +| ENRICH city_names ON city WITH airport, region, city_boundary +; + +left:keyword | city:keyword | middle:keyword | right:keyword | airport:text | region:text | city_boundary:geo_shape +left | Zürich | middle | right | Zurich Int'l | Bezirk Zürich | "POLYGON((8.448 47.3802,8.4977 47.3452,8.5032 47.3202,8.6254 47.3547,8.5832 47.3883,8.5973 47.4063,8.5431 47.4329,8.4858 47.431,8.4691 47.4169,8.473 47.3951,8.448 47.3802))" +; + +shadowingMultiLimit0 +ROW left = "left", airport = "Zurich Airport ZRH", city = "Zurich", middle = "middle", region = "North-East Switzerland", right = "right" +| ENRICH city_names ON city WITH airport, region, city_boundary +| LIMIT 0 +; + +left:keyword | city:keyword | middle:keyword | right:keyword | airport:text | region:text | city_boundary:geo_shape +; + simple required_capability: enrich_load @@ -428,8 +504,8 @@ FROM airports | EVAL boundary_wkt_length = LENGTH(TO_STRING(city_boundary)) | STATS city_centroid = ST_CENTROID_AGG(city_location), count = COUNT(city_location), min_wkt = MIN(boundary_wkt_length), max_wkt = MAX(boundary_wkt_length) ; -warning:Line 3:30: evaluation of [LENGTH(TO_STRING(city_boundary))] failed, treating result as null. Only first 20 failures recorded. -warning:Line 3:30: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[LENGTH\(TO_STRING\(city_boundary\)\)\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value city_centroid:geo_point | count:long | min_wkt:integer | max_wkt:integer POINT(1.396561 24.127649) | 872 | 88 | 1044 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec index 571d7835451c3..3df3b85e5e3af 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/eval.csv-spec @@ -5,6 +5,34 @@ a:integer | b:integer 1 | 2 ; + +shadowing +ROW left = "left", x = 10000 , right = "right" +| EVAL x = 1 +; + +left:keyword | right:keyword | x:integer +left | right | 1 +; + +shadowingSelf +ROW left = "left", x = 10000 , right = "right" +| EVAL x = x + 1 +; + +left:keyword | right:keyword | x:integer +left | right | 10001 +; + +shadowingMulti +ROW left = "left", x = 0, middle = "middle", y = -1, right = "right" +| EVAL x = 9, y = 10 +; + +left:keyword | middle:keyword | right:keyword | x:integer | y:integer +left | middle | right | 9 | 10 +; + withMath row a = 1 | eval b = 2 + 3; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/floats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/floats.csv-spec index 66f4e9a33ceff..2ee7f783b7e97 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/floats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/floats.csv-spec @@ -95,8 +95,8 @@ lessThanMultivalue required_capability: mv_warn from employees | where salary_change < 1 | keep emp_no, salary_change | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [salary_change < 1] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[salary_change < 1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't less than 1 - they are null - so they aren't included emp_no:integer |salary_change:double @@ -111,8 +111,8 @@ greaterThanMultivalue required_capability: mv_warn from employees | where salary_change > 1 | keep emp_no, salary_change | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [salary_change > 1] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[salary_change > 1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't greater than 1 - they are null - so they aren't included emp_no:integer |salary_change:double @@ -165,8 +165,8 @@ notLessThanMultivalue required_capability: mv_warn from employees | where not(salary_change < 1) | keep emp_no, salary_change | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [not(salary_change < 1)] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [salary_change < 1] failed, treating result as null. Only first 20 failures recorded.] 
-warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*salary_change < 1.*\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't less than 1 - they are null - so they aren't included emp_no:integer |salary_change:double @@ -181,8 +181,8 @@ notGreaterThanMultivalue required_capability: mv_warn from employees | where not(salary_change > 1) | keep emp_no, salary_change | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [not(salary_change > 1)] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [salary_change > 1] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*salary_change > 1.*\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't less than 1 - they are null - so they aren't included emp_no:integer |salary_change:double @@ -197,8 +197,8 @@ notEqualToMultivalue required_capability: mv_warn from employees | where not(salary_change == 1.19) | keep emp_no, salary_change | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [not(salary_change == 1.19)] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [salary_change == 1.19] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*salary_change == 1.19.*\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't greater than 1 - they are null - so they aren't included emp_no:integer |salary_change:double diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec index fbe31deeb0f97..49a8085e0c186 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/grok.csv-spec @@ -14,6 +14,43 @@ foo bar | null ; +shadowing +FROM employees +| KEEP first_name, last_name +| WHERE last_name == "Facello" +| EVAL left = "left", full_name = concat(first_name, " ", last_name) , last_name = "last_name", right = "right" +| GROK full_name "%{WORD} %{WORD:last_name}" +; + +first_name:keyword | left:keyword | full_name:keyword | right:keyword | last_name:keyword +Georgi | left | Georgi Facello | right | Facello +; + +shadowingSelf +FROM employees +| KEEP first_name, last_name +| WHERE last_name == "Facello" +| EVAL left = "left", name = concat(first_name, "1 ", last_name), right = "right" +| GROK name "%{WORD:name} %{WORD}" +; + +first_name:keyword | last_name:keyword | left:keyword | right:keyword | name:keyword +Georgi | Facello | left | right | Georgi1 +; + +shadowingMulti +FROM employees +| KEEP first_name, last_name +| WHERE last_name == "Facello" +| EVAL left = "left", foo = concat(first_name, "1 ", first_name, "2 ", last_name) , middle = "middle", bar = "bar", right = "right" +| GROK foo "%{WORD:bar} %{WORD:first_name} %{WORD:last_name_again}" +; + +last_name:keyword | left:keyword | foo:keyword | middle:keyword | right:keyword | bar:keyword | first_name:keyword | last_name_again:keyword +Facello | left | Georgi1 Georgi2 Facello | middle | right | Georgi1 | Georgi2 | Facello +; + + complexPattern ROW a = "1953-01-23T12:15:00Z 127.0.0.1 some.email@foo.com 42" | GROK a "%{TIMESTAMP_ISO8601:date} %{IP:ip} %{EMAILADDRESS:email} %{NUMBER:num:int}" diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec index 2e45febe0de1d..c8cb6cf88a4f0 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ints.csv-spec @@ -4,8 +4,8 @@ inLongAndInt required_capability: mv_warn from employees | where avg_worked_seconds in (372957040, salary_change.long, 236703986) | where emp_no in (10017, emp_no - 1) | keep emp_no, avg_worked_seconds; -warning:Line 1:24: evaluation of [avg_worked_seconds in (372957040, salary_change.long, 236703986)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[avg_worked_seconds in \(372957040, salary_change.long, 236703986\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value emp_no:integer |avg_worked_seconds:long 10017 |236703986 @@ -268,8 +268,8 @@ lessThanMultivalue required_capability: mv_warn from employees | where salary_change.int < 1 | keep emp_no, salary_change.int | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [salary_change.int < 1] failed, treating result as null. Only first 20 failures recorded. 
-warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[salary_change.int < 1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't less than 1 - they are null - so they aren't included emp_no:integer |salary_change.int:integer @@ -284,8 +284,8 @@ greaterThanMultivalue required_capability: mv_warn from employees | where salary_change.int > 1 | keep emp_no, salary_change.int | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [salary_change.int > 1] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[salary_change.int > 1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't greater than 1 - they are null - so they aren't included emp_no:integer |salary_change.int:integer @@ -300,8 +300,8 @@ equalToMultivalue required_capability: mv_warn from employees | where salary_change.int == 0 | keep emp_no, salary_change.int | sort emp_no; -warning:Line 1:24: evaluation of [salary_change.int == 0] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[salary_change.int == 0\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't greater than 1 - they are null - so they aren't included emp_no:integer |salary_change.int:integer @@ -315,8 +315,8 @@ equalToOrEqualToMultivalue required_capability: mv_warn from employees | where salary_change.int == 1 or salary_change.int == 8 | keep emp_no, salary_change.int | sort emp_no; -warning:Line 1:24: evaluation of [salary_change.int] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[salary_change.int\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries are filtered out emp_no:integer |salary_change.int:integer @@ -328,8 +328,8 @@ inMultivalue required_capability: mv_warn from employees | where salary_change.int in (1, 7) | keep emp_no, salary_change.int | sort emp_no; -warning:Line 1:24: evaluation of [salary_change.int in (1, 7)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[salary_change.int in \(1, 7\)\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries are filtered out emp_no:integer |salary_change.int:integer @@ -341,8 +341,8 @@ notLessThanMultivalue required_capability: mv_warn from employees | where not(salary_change.int < 1) | keep emp_no, salary_change.int | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [not(salary_change.int < 1)] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [salary_change.int < 1] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*salary_change.int < 1.*\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't less than 1 - they are null - so they aren't included emp_no:integer |salary_change.int:integer @@ -357,8 +357,8 @@ notGreaterThanMultivalue required_capability: mv_warn from employees | where not(salary_change.int > 1) | keep emp_no, salary_change.int | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [not(salary_change.int > 1)] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [salary_change.int > 1] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*salary_change.int > 1.*\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't less than 1 - they are null - so they aren't included emp_no:integer |salary_change.int:integer @@ -373,8 +373,8 @@ notEqualToMultivalue required_capability: mv_warn from employees | where not(salary_change.int == 1) | keep emp_no, salary_change.int | sort emp_no | limit 5; -warning:Line 1:24: evaluation of [not(salary_change.int == 1)] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [salary_change.int == 1] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*salary_change.int == 1.*\] failed, treating result as null. 
Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued salaries aren't greater than 1 - they are null - so they aren't included emp_no:integer |salary_change.int:integer diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec index 61f529d60bf90..54d5484bb4172 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec @@ -19,8 +19,8 @@ equals required_capability: mv_warn from hosts | sort host, card | where ip0 == ip1 | keep card, host, ip0, ip1; -warning:Line 1:38: evaluation of [ip0 == ip1] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:38: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[ip0 == ip1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip eth0 |alpha |127.0.0.1 |127.0.0.1 @@ -63,8 +63,8 @@ lessThan required_capability: mv_warn from hosts | sort host, card, ip1 | where ip0 < ip1 | keep card, host, ip0, ip1; -warning:Line 1:43: evaluation of [ip0 < ip1] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:43: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[ip0 < ip1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip eth1 |beta |127.0.0.1 |127.0.0.2 @@ -76,8 +76,8 @@ notEquals required_capability: mv_warn from hosts | sort host, card, ip1 | where ip0 != ip1 | keep card, host, ip0, ip1; -warning:Line 1:43: evaluation of [ip0 != ip1] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:43: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[ip0 != ip1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip eth0 |beta |127.0.0.1 |::1 @@ -150,10 +150,10 @@ required_capability: mv_warn from hosts | eval eq=case(ip0==ip1, ip0, ip1) | where eq in (ip0, ip1) | keep card, host, ip0, ip1, eq; ignoreOrder:true -warning:Line 1:27: evaluation of [ip0==ip1] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:27: java.lang.IllegalArgumentException: single-value function encountered multi-value -warning:Line 1:55: evaluation of [eq in (ip0, ip1)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:55: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[ip0==ip1\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[eq in \(ip0, ip1\)\] failed, treating result as null. Only first 20 failures recorded.
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip |eq:ip eth0 |alpha |127.0.0.1 |127.0.0.1 |127.0.0.1 @@ -191,8 +191,8 @@ cidrMatchSimple required_capability: mv_warn from hosts | where cidr_match(ip1, "127.0.0.2/32") | keep card, host, ip0, ip1; -warning:Line 1:20: evaluation of [cidr_match(ip1, \"127.0.0.2/32\")] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:20: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[cidr_match\(ip1, \\\"127.0.0.2/32\\\"\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip eth1 |beta |127.0.0.1 |127.0.0.2 @@ -203,8 +203,8 @@ required_capability: mv_warn from hosts | where cidr_match(ip0, "127.0.0.2/32") is null | keep card, host, ip0, ip1; ignoreOrder:true -warning:Line 1:20: evaluation of [cidr_match(ip0, \"127.0.0.2/32\")] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:20: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[cidr_match\(ip0, \\\"127.0.0.2/32\\\"\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip eth0 |epsilon |[fe80::cae2:65ff:fece:feb9, fe80::cae2:65ff:fece:fec0, fe80::cae2:65ff:fece:fec1]|fe80::cae2:65ff:fece:fec1 @@ -312,8 +312,8 @@ required_capability: mv_warn from hosts | where ip1 > to_ip("127.0.0.1") | keep card, ip1; ignoreOrder:true -warning:Line 1:20: evaluation of [ip1 > to_ip(\"127.0.0.1\")] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:20: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[ip1 > to_ip\(\\\"127.0.0.1\\\"\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |ip1:ip eth1 |127.0.0.2 @@ -553,8 +553,8 @@ required_capability: fn_ip_prefix from hosts | stats count(*) by ip_prefix(ip1, 24, 120) | sort `ip_prefix(ip1, 24, 120)`; -warning:Line 2:21: evaluation of [ip_prefix(ip1, 24, 120)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 2:21: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[ip_prefix\(ip1, 24, 120\)\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value count(*):long | ip_prefix(ip1, 24, 120):ip 2 | ::0 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-metrics.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-metrics.csv-spec index 91084726bfb25..3976329501894 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-metrics.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/k8s-metrics.csv-spec @@ -19,3 +19,72 @@ max_bytes:long | cluster: keyword 10277 | prod 7403 | staging ; + +maxRate +required_capability: metrics_syntax +METRICS k8s max(rate(network.total_bytes_in, 1minute)); + +max(rate(network.total_bytes_in, 1minute)): double +790.4235090751945 +; + +maxCost +required_capability: metrics_syntax +METRICS k8s max_cost=max(rate(network.total_cost)); + +max_cost: double +0.16151685393258428 +; + +sumRate +required_capability: metrics_syntax +METRICS k8s bytes=sum(rate(network.total_bytes_in)), sum(rate(network.total_cost)) BY cluster | SORT cluster; + +bytes: double | sum(rate(network.total_cost)): double | cluster: keyword +24.49149357711476 | 0.3018995503437827 | prod +33.308519044441084 | 0.4474920369252062 | qa +18.610708062016123 | 0.24387090901805775 | staging +; + +oneRateWithBucket +required_capability: metrics_syntax +METRICS k8s max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute) | SORT time_bucket DESC | LIMIT 2; + +max(rate(network.total_bytes_in)): double | time_bucket:date +10.594594594594595 | 2024-05-10T00:20:00.000Z +23.702205882352942 | 2024-05-10T00:15:00.000Z +; + +twoRatesWithBucket +required_capability: metrics_syntax +METRICS k8s max(rate(network.total_bytes_in)), sum(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute) | SORT time_bucket DESC | LIMIT 3; + +max(rate(network.total_bytes_in)): double | sum(rate(network.total_bytes_in)): double | time_bucket:date +10.594594594594595 | 42.70864495221802 | 2024-05-10T00:20:00.000Z +23.702205882352942 | 112.36715680313907 | 2024-05-10T00:15:00.000Z +17.90625 | 85.18387414067914 | 2024-05-10T00:10:00.000Z +; + + +oneRateWithBucketAndCluster +required_capability: metrics_syntax +METRICS k8s max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster | SORT time_bucket DESC, cluster | LIMIT 6; + +max(rate(network.total_bytes_in)): double | time_bucket:date | cluster: keyword +10.594594594594595 | 2024-05-10T00:20:00.000Z | prod +5.586206896551724 | 2024-05-10T00:20:00.000Z | qa +5.37037037037037 | 2024-05-10T00:20:00.000Z | staging +15.913978494623656 | 2024-05-10T00:15:00.000Z | prod +23.702205882352942 | 2024-05-10T00:15:00.000Z | qa +9.823232323232324 | 2024-05-10T00:15:00.000Z | staging +; + +oneRateWithBucketAndClusterThenFilter +required_capability: metrics_syntax +METRICS k8s max(rate(network.total_bytes_in)) BY time_bucket = bucket(@timestamp,5minute), cluster | WHERE cluster=="prod" | SORT time_bucket DESC | LIMIT 3; + +max(rate(network.total_bytes_in)): double | time_bucket:date | cluster: keyword +10.594594594594595 | 2024-05-10T00:20:00.000Z | prod +15.913978494623656 | 2024-05-10T00:15:00.000Z | prod +11.562737642585551 | 2024-05-10T00:10:00.000Z | prod +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup.csv-spec index 35e1101becbf9..77d8e48d9e81f 100644 --- 
a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup.csv-spec @@ -1,5 +1,5 @@ keywordByInt -required_capability: tables_types +required_capability: lookup_v3 FROM employees | SORT emp_no | LIMIT 4 @@ -17,7 +17,7 @@ emp_no:integer | languages:integer | lang_name:keyword ; keywordByMvInt -required_capability: tables_types +required_capability: lookup_v3 ROW int=[1, 2, 3] | LOOKUP int_number_names ON int ; @@ -27,7 +27,7 @@ int:integer | name:keyword ; keywordByDupeInt -required_capability: tables_types +required_capability: lookup_v3 ROW int=[1, 1, 1] | LOOKUP int_number_names ON int ; @@ -37,7 +37,7 @@ int:integer | name:keyword ; intByKeyword -required_capability: tables_types +required_capability: lookup_v3 ROW name="two" | LOOKUP int_number_names ON name ; @@ -48,7 +48,7 @@ name:keyword | int:integer keywordByLong -required_capability: tables_types +required_capability: lookup_v3 FROM employees | SORT emp_no | LIMIT 4 @@ -66,7 +66,7 @@ emp_no:integer | languages:long | lang_name:keyword ; longByKeyword -required_capability: tables_types +required_capability: lookup_v3 ROW name="two" | LOOKUP long_number_names ON name ; @@ -76,7 +76,7 @@ name:keyword | long:long ; keywordByFloat -required_capability: tables_types +required_capability: lookup_v3 FROM employees | SORT emp_no | LIMIT 4 @@ -94,7 +94,7 @@ emp_no:integer | height:double | height_name:keyword ; floatByKeyword -required_capability: tables_types +required_capability: lookup_v3 ROW name="two point zero eight" | LOOKUP double_number_names ON name ; @@ -104,7 +104,7 @@ two point zero eight | 2.08 ; floatByNullMissing -required_capability: tables_types +required_capability: lookup_v3 ROW name=null | LOOKUP double_number_names ON name ; @@ -114,7 +114,7 @@ name:null | double:double ; floatByNullMatching -required_capability: tables_types +required_capability: lookup_v3 ROW name=null | LOOKUP double_number_names_with_null ON name ; @@ -124,7 +124,7 @@ name:null | double:double ; intIntByKeywordKeyword -required_capability: tables_types +required_capability: lookup_v3 ROW aa="foo", ab="zoo" | LOOKUP big ON aa, ab ; @@ -134,7 +134,7 @@ foo | zoo | 1 | -1 ; intIntByKeywordKeywordMissing -required_capability: tables_types +required_capability: lookup_v3 ROW aa="foo", ab="zoi" | LOOKUP big ON aa, ab ; @@ -144,7 +144,7 @@ foo | zoi | null | null ; intIntByKeywordKeywordSameValues -required_capability: tables_types +required_capability: lookup_v3 ROW aa="foo", ab="foo" | LOOKUP big ON aa, ab ; @@ -154,7 +154,7 @@ foo | foo | 2 | -2 ; intIntByKeywordKeywordSameValuesMissing -required_capability: tables_types +required_capability: lookup_v3 ROW aa="bar", ab="bar" | LOOKUP big ON aa, ab ; @@ -164,7 +164,7 @@ bar | bar | null | null ; lookupBeforeStats -required_capability: tables_types +required_capability: lookup_v3 FROM employees | RENAME languages AS int | LOOKUP int_number_names ON int @@ -182,7 +182,7 @@ height:double | languages:keyword ; lookupAfterStats -required_capability: tables_types +required_capability: lookup_v3 FROM employees | STATS int=TO_INT(AVG(height)) | LOOKUP int_number_names ON int @@ -194,7 +194,7 @@ two // Makes sure the LOOKUP squashes previous names doesNotDuplicateNames -required_capability: tables_types +required_capability: lookup_v3 FROM employees | SORT emp_no | LIMIT 4 @@ -213,7 +213,7 @@ emp_no:integer | languages:long | name:keyword ; lookupBeforeSort -required_capability: tables_types +required_capability: lookup_v3 FROM 
employees | WHERE emp_no < 10005 | RENAME languages AS int @@ -231,7 +231,7 @@ languages:keyword | emp_no:integer ; lookupAfterSort -required_capability: tables_types +required_capability: lookup_v3 FROM employees | WHERE emp_no < 10005 | SORT languages ASC, emp_no ASC @@ -248,12 +248,38 @@ languages:keyword | emp_no:integer five | 10004 ; +shadowing +required_capability: lookup_v3 +FROM employees +| KEEP emp_no +| WHERE emp_no == 10001 +| EVAL left = "left", int = emp_no - 10000, name = "name", right = "right" +| LOOKUP int_number_names ON int +; + +emp_no:integer | left:keyword | int:integer | right:keyword | name:keyword + 10001 | left | 1 | right | one +; + +shadowingMulti +required_capability: lookup_v3 +FROM employees +| KEEP emp_no +| WHERE emp_no == 10001 +| EVAL left = "left", nb = -10011+emp_no, na = "na", middle = "middle", ab = "ab", aa = "bar", right = "right" +| LOOKUP big ON aa, nb +; + +emp_no:integer | left:keyword | nb:integer | middle:keyword | aa:keyword | right:keyword | ab:keyword | na:integer + 10001 | left | -10 | middle | bar | right | zop | 10 +; + // // Make sure that the new LOOKUP syntax doesn't clash with any existing things // named "lookup" // rowNamedLookup -required_capability: tables_types +required_capability: lookup_v3 ROW lookup = "a" ; @@ -262,7 +288,7 @@ lookup:keyword ; rowNamedLOOKUP -required_capability: tables_types +required_capability: lookup_v3 ROW LOOKUP = "a" ; @@ -271,7 +297,7 @@ LOOKUP:keyword ; evalNamedLookup -required_capability: tables_types +required_capability: lookup_v3 ROW a = "a" | EVAL lookup = CONCAT(a, "1") ; @@ -280,7 +306,7 @@ a:keyword | lookup:keyword ; dissectNamedLookup -required_capability: tables_types +required_capability: lookup_v3 row a = "foo bar" | dissect a "foo %{lookup}"; a:keyword | lookup:keyword @@ -288,7 +314,7 @@ a:keyword | lookup:keyword ; renameIntoLookup -required_capability: tables_types +required_capability: lookup_v3 row a = "foo bar" | RENAME a AS lookup; lookup:keyword @@ -296,7 +322,7 @@ lookup:keyword ; sortOnLookup -required_capability: tables_types +required_capability: lookup_v3 ROW lookup = "a" | SORT lookup ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec index be6cd058d24e9..8337af42df5ea 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec @@ -443,8 +443,8 @@ from employees | keep emp_no, languages, salary, avg_worked_seconds, l1, l2, l3 | limit 5; -warning:Line 2:13: evaluation of [LOG(languages, salary)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 2:13: java.lang.ArithmeticException: Log of base 1 +warningRegex:evaluation of \[LOG\(languages, salary\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.ArithmeticException: Log of base 1 emp_no:integer | languages:integer | salary:integer | avg_worked_seconds:long | l1:double | l2:double | l3:double 10001 | 2 | 57305 | 268728049 | 15.806373402659007 | 19.409210455930772 | 35.21558385858978 @@ -481,8 +481,8 @@ from employees | keep l1, l2, emp_no, languages, salary, avg_worked_seconds, l3 | limit 5; -warning:Line 2:13: evaluation of [LOG(languages, salary)] failed, treating result as null. Only first 20 failures recorded. 
-warning:Line 2:13: java.lang.ArithmeticException: Log of base 1 +warningRegex:evaluation of \[LOG\(languages, salary\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.ArithmeticException: Log of base 1 l1:double | l2:double | emp_no:integer | languages:integer | salary:integer | avg_worked_seconds:long | l3:double 6.300030441266983 | 19.782340222815456 | 10015 | 5 | 25324 | 390266432 | 26.08237066408244 @@ -502,8 +502,8 @@ from employees | keep l1, l2, emp_no, base1, salary, avg_worked_seconds, l3 | limit 5; -warning:Line 3:13: evaluation of [LOG(base1, salary)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 3:13: java.lang.ArithmeticException: Log of base 1 +warningRegex:evaluation of \[LOG\(base1, salary\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.ArithmeticException: Log of base 1 l1:double | l2:double | emp_no:integer | base1:integer | salary:integer | avg_worked_seconds:long | l3:double null | 19.774989878141827 | 10044 | 1 | 39728 | 387408356 | null diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec index 53d7d1fd0d352..b77c9d7200af9 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/string.csv-spec @@ -342,8 +342,8 @@ required_capability: mv_warn from employees | where job_positions in ("Internship", first_name) | keep emp_no, job_positions; ignoreOrder:true -warning:Line 1:24: evaluation of [job_positions in (\"Internship\", first_name)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions in \(\\\"Internship\\\", first_name\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value emp_no:integer |job_positions:keyword 10048 |Internship @@ -533,8 +533,8 @@ lessThanMultivalue required_capability: mv_warn from employees | where job_positions < "C" | keep emp_no, job_positions | sort emp_no; -warning:Line 1:24: evaluation of [job_positions < \"C\"] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions < \\\"C\\\"\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword @@ -546,8 +546,8 @@ greaterThanMultivalue required_capability: mv_warn from employees | where job_positions > "C" | keep emp_no, job_positions | sort emp_no | limit 6; -warning:Line 1:24: evaluation of [job_positions > \"C\"] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions > \\\"C\\\"\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword @@ -563,8 +563,8 @@ equalToMultivalue required_capability: mv_warn from employees | where job_positions == "Accountant" | keep emp_no, job_positions | sort emp_no; -warning:Line 1:24: evaluation of [job_positions == \"Accountant\"] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions == \\\"Accountant\\\"\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword @@ -575,8 +575,8 @@ equalToOrEqualToMultivalue required_capability: mv_warn from employees | where job_positions == "Accountant" or job_positions == "Tech Lead" | keep emp_no, job_positions | sort emp_no; -warning:Line 1:24: evaluation of [job_positions] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword @@ -588,8 +588,8 @@ inMultivalue required_capability: mv_warn from employees | where job_positions in ("Accountant", "Tech Lead") | keep emp_no, job_positions | sort emp_no; -warning:Line 1:24: evaluation of [job_positions in (\"Accountant\", \"Tech Lead\")] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions in \(\\\"Accountant\\\", \\\"Tech Lead\\\"\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword @@ -601,8 +601,8 @@ notLessThanMultivalue required_capability: mv_warn from employees | where not(job_positions < "C") | keep emp_no, job_positions | sort emp_no | limit 6; -warning:Line 1:24: evaluation of [not(job_positions < \"C\")] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [job_positions < \"C\"] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*job_positions < \\\"C\\\".*\] failed, treating result as null. Only first 20 failures recorded.
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword @@ -618,8 +618,8 @@ notGreaterThanMultivalue required_capability: mv_warn from employees | where not(job_positions > "C") | keep emp_no, job_positions | sort emp_no | limit 6; -warning:Line 1:24: evaluation of [not(job_positions > \"C\")] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [job_positions > \"C\"] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*job_positions > \\\"C\\\".*\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword @@ -631,8 +631,8 @@ notEqualToMultivalue required_capability: mv_warn from employees | where not(job_positions == "Accountant") | keep emp_no, job_positions | sort emp_no | limit 6; -warning:Line 1:24: evaluation of [not(job_positions == \"Accountant\")] failed, treating result as null. Only first 20 failures recorded.#[Emulated:Line 1:28: evaluation of [job_positions == \"Accountant\"] failed, treating result as null. Only first 20 failures recorded.] -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value#[Emulated:Line 1:28: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[.*job_positions == \\\"Accountant\\\".*\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value // Note that multivalued job_positions aren't included because they aren't less than or greater than C - that comparison is null emp_no:integer |job_positions:keyword diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unsigned_long.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unsigned_long.csv-spec index 38f3d439e7504..03d0b71894d9b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unsigned_long.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/unsigned_long.csv-spec @@ -49,8 +49,8 @@ filterPushDownGT required_capability: mv_warn from ul_logs | where bytes_in >= to_ul(74330435873664882) | sort bytes_in | eval div = bytes_in / to_ul(pow(10., 15)) | keep bytes_in, div, id | limit 12; -warning:Line 1:22: evaluation of [bytes_in >= to_ul(74330435873664882)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:22: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[bytes_in >= to_ul\(74330435873664882\)\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value bytes_in:ul | div:ul |id:i 74330435873664882 |74 |82 @@ -71,10 +71,8 @@ filterPushDownRange required_capability: mv_warn from ul_logs | where bytes_in >= to_ul(74330435873664882) | where bytes_in <= to_ul(316080452389500167) | sort bytes_in | eval div = bytes_in / to_ul(pow(10., 15)) | keep bytes_in, div, id | limit 12; -warning:Line 1:22: evaluation of [bytes_in >= to_ul(74330435873664882)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:22: java.lang.IllegalArgumentException: single-value function encountered multi-value -warning:#[Emulated:Line 1:67: evaluation of [bytes_in <= to_ul(316080452389500167)] failed, treating result as null. Only first 20 failures recorded.] -warning:#[Emulated:Line 1:67: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[bytes_in .* to_ul\(.*\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value bytes_in:ul | div:ul |id:i 74330435873664882 |74 |82 @@ -88,8 +86,8 @@ required_capability: mv_warn // TODO: testing framework doesn't perform implicit conversion to UL of given values, needs explicit conversion from ul_logs | where bytes_in in (to_ul(74330435873664882), to_ul(154551962150890564), to_ul(195161570976258241)) | sort bytes_in | keep bytes_in, id; -warning:Line 1:22: evaluation of [bytes_in in (to_ul(74330435873664882), to_ul(154551962150890564), to_ul(195161570976258241))] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:22: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[bytes_in in \(to_ul\(74330435873664882\), to_ul\(154551962150890564\), to_ul\(195161570976258241\)\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value bytes_in:ul |id:i 74330435873664882 |82 @@ -101,8 +99,8 @@ filterOnFieldsEquality required_capability: mv_warn from ul_logs | where bytes_in == bytes_out; -warning:Line 1:22: evaluation of [bytes_in == bytes_out] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:22: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[bytes_in == bytes_out\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value @timestamp:date | bytes_in:ul | bytes_out:ul | id:i | status:k 2017-11-10T21:12:17.000Z|16002960716282089759|16002960716282089759|34 |OK @@ -112,8 +110,8 @@ filterOnFieldsInequality required_capability: mv_warn from ul_logs | sort id | where bytes_in < bytes_out | eval b_in = bytes_in / to_ul(pow(10.,15)), b_out = bytes_out / to_ul(pow(10.,15)) | limit 5; -warning:Line 1:32: evaluation of [bytes_in < bytes_out] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:32: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[bytes_in < bytes_out\] failed, treating result as null. Only first 20 failures recorded. 
+warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value @timestamp:date | bytes_in:ul | bytes_out:ul | id:i | status:k | b_in:ul | b_out:ul 2017-11-10T21:15:54.000Z|4348801185987554667 |12749081495402663265|1 |OK |4348 |12749 @@ -143,8 +141,8 @@ case required_capability: mv_warn from ul_logs | where case(bytes_in == to_ul(154551962150890564), true, false); -warning:Line 1:27: evaluation of [bytes_in == to_ul(154551962150890564)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:27: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[bytes_in == to_ul\(154551962150890564\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value @timestamp:date | bytes_in:ul | bytes_out:ul | id:i | status:k 2017-11-10T20:21:58.000Z|154551962150890564|9382204513185396493|63 |OK @@ -155,8 +153,8 @@ required_capability: mv_warn FROM ul_logs | WHERE bytes_in == bytes_out | EVAL deg = TO_DEGREES(bytes_in) | KEEP bytes_in, deg ; -warning:Line 1:22: evaluation of [bytes_in == bytes_out] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:22: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[bytes_in == bytes_out\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value bytes_in:ul | deg:double 16002960716282089759 | 9.169021087566165E20 @@ -167,8 +165,8 @@ required_capability: mv_warn FROM ul_logs | WHERE bytes_in == bytes_out | EVAL rad = TO_RADIANS(bytes_in) | KEEP bytes_in, rad ; -warning:Line 1:22: evaluation of [bytes_in == bytes_out] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:22: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[bytes_in == bytes_out\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value bytes_in:ul | rad:double 16002960716282089759 | 2.79304354566432608E17 @@ -197,8 +195,8 @@ keep s, bytes_in, bytes_out | sort bytes_out, s | limit 2; -warning:Line 2:7: evaluation of [signum(bytes_in)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 2:7: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[signum\(bytes_in\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value s:double | bytes_in:ul | bytes_out:ul 1.0 | 1957665857956635540 | 352442273299370793 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec index 160fc46dafcf2..2a62117be8169 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/where-like.csv-spec @@ -302,8 +302,8 @@ FROM sample_data multiValueLike#[skip:-8.12.99] from employees | where job_positions like "Account*" | keep emp_no, job_positions; -warning:Line 1:24: evaluation of [job_positions like \"Account*\"] failed, treating result as null. 
Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions like \\\"Account\*\\\"\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value emp_no:integer | job_positions:keyword 10025 | Accountant @@ -313,8 +313,8 @@ emp_no:integer | job_positions:keyword multiValueRLike#[skip:-8.12.99] from employees | where job_positions rlike "Account.*" | keep emp_no, job_positions; -warning:Line 1:24: evaluation of [job_positions rlike \"Account.*\"] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:24: java.lang.IllegalArgumentException: single-value function encountered multi-value +warningRegex:evaluation of \[job_positions rlike \\\"Account.*\\\"\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value emp_no:integer | job_positions:keyword 10025 | Accountant diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java index d3471450e4728..9778756176574 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java @@ -323,7 +323,10 @@ private void assertCancelled(ActionFuture<EsqlQueryResponse> response) throws Ex * or the cancellation chained from another cancellation and has
* "task cancelled". */ - assertThat(cancelException.getMessage(), either(equalTo("test cancel")).or(equalTo("task cancelled"))); + assertThat( + cancelException.getMessage(), + either(equalTo("test cancel")).or(equalTo("task cancelled")).or(equalTo("request cancelled test cancel")) + ); assertBusy( () -> assertThat( client().admin() diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java index 26ffdf0e13ccd..d7c15ad07e350 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/TimeSeriesIT.java @@ -8,14 +8,24 @@ package org.elasticsearch.xpack.esql.action; import org.elasticsearch.Build; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Rounding; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.mapper.DateFieldMapper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.core.esql.action.ColumnInfo; import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.junit.Before; +import java.time.ZoneOffset; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Objects; +import static org.elasticsearch.index.mapper.DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -28,77 +38,601 @@ protected EsqlQueryResponse run(EsqlQueryRequest request) { } public void testEmpty() { - Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("pod")).build(); + Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host")).build(); client().admin() .indices() - .prepareCreate("pods") + .prepareCreate("empty_index") .setSettings(settings) .setMapping( "@timestamp", "type=date", - "pod", + "host", "type=keyword,time_series_dimension=true", "cpu", "type=long,time_series_metric=gauge" ) .get(); - run("METRICS pods | LIMIT 1").close(); + run("METRICS empty_index | LIMIT 1").close(); } - public void testSimpleMetrics() { - Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("pod")).build(); + record Doc(String host, String cluster, long timestamp, int requestCount, double cpu) {} + + final List<Doc> docs = new ArrayList<>(); + + record RequestCounter(long timestamp, long count) { + + } + + static Double computeRate(List<RequestCounter> values) { + List<RequestCounter> sorted = values.stream().sorted(Comparator.comparingLong(RequestCounter::timestamp)).toList(); + if (sorted.size() < 2) { + return null; + } + long resets = 0; + for (int i = 0; i < sorted.size() - 1; i++) { + if (sorted.get(i).count > sorted.get(i + 1).count) { + resets += sorted.get(i).count; + } + } + RequestCounter last = sorted.get(sorted.size() - 1); + RequestCounter first = sorted.get(0); + double dv = resets + last.count - first.count; + double dt = last.timestamp - first.timestamp; + return dv * 1000 / dt; + } + + @Before + public void populateIndex() { + // this can be expensive, do it once + Settings settings = Settings.builder().put("mode", "time_series").putList("routing_path", List.of("host", "cluster")).build(); client().admin() .indices() -
.prepareCreate("pods") + .prepareCreate("hosts") .setSettings(settings) .setMapping( "@timestamp", "type=date", - "pod", + "host", + "type=keyword,time_series_dimension=true", + "cluster", "type=keyword,time_series_dimension=true", "cpu", - "type=double,time_series_metric=gauge" + "type=double,time_series_metric=gauge", + "request_count", + "type=integer,time_series_metric=counter" ) .get(); - List pods = List.of("p1", "p2", "p3"); - long startTime = DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.parseMillis("2024-04-15T00:00:00Z"); - int numDocs = between(10, 100); - record Doc(String pod, long timestamp, double cpu) {} - List docs = new ArrayList<>(); + Map hostToClusters = new HashMap<>(); + for (int i = 0; i < 5; i++) { + hostToClusters.put("p" + i, randomFrom("qa", "prod")); + } + long timestamp = DEFAULT_DATE_TIME_FORMATTER.parseMillis("2024-04-15T00:00:00Z"); + int numDocs = between(20, 100); + docs.clear(); + Map requestCounts = new HashMap<>(); for (int i = 0; i < numDocs; i++) { - String pod = randomFrom(pods); - int cpu = randomIntBetween(0, 100); - long timestamp = startTime + (1000L * i); - docs.add(new Doc(pod, timestamp, cpu)); - client().prepareIndex("pods").setSource("@timestamp", timestamp, "pod", pod, "cpu", cpu).get(); - } - List sortedGroups = docs.stream().map(d -> d.pod).distinct().sorted().toList(); - client().admin().indices().prepareRefresh("pods").get(); - try (EsqlQueryResponse resp = run("METRICS pods load=avg(cpu) BY pod | SORT pod")) { + List hosts = randomSubsetOf(between(1, hostToClusters.size()), hostToClusters.keySet()); + timestamp += between(1, 10) * 1000L; + for (String host : hosts) { + var requestCount = requestCounts.compute(host, (k, curr) -> { + if (curr == null || randomInt(100) <= 20) { + return randomIntBetween(0, 10); + } else { + return curr + randomIntBetween(1, 10); + } + }); + int cpu = randomIntBetween(0, 100); + docs.add(new Doc(host, hostToClusters.get(host), timestamp, requestCount, cpu)); + } + } + Randomness.shuffle(docs); + for (Doc doc : docs) { + client().prepareIndex("hosts") + .setSource( + "@timestamp", + doc.timestamp, + "host", + doc.host, + "cluster", + doc.cluster, + "cpu", + doc.cpu, + "request_count", + doc.requestCount + ) + .get(); + } + client().admin().indices().prepareRefresh("hosts").get(); + } + + public void testSimpleMetrics() { + List sortedGroups = docs.stream().map(d -> d.host).distinct().sorted().toList(); + client().admin().indices().prepareRefresh("hosts").get(); + try (EsqlQueryResponse resp = run("METRICS hosts load=avg(cpu) BY host | SORT host")) { List> rows = EsqlTestUtils.getValuesList(resp); assertThat(rows, hasSize(sortedGroups.size())); for (int i = 0; i < rows.size(); i++) { List r = rows.get(i); String pod = (String) r.get(1); assertThat(pod, equalTo(sortedGroups.get(i))); - List values = docs.stream().filter(d -> d.pod.equals(pod)).map(d -> d.cpu).toList(); + List values = docs.stream().filter(d -> d.host.equals(pod)).map(d -> d.cpu).toList(); double avg = values.stream().mapToDouble(n -> n).sum() / values.size(); assertThat((double) r.get(0), equalTo(avg)); } } - try (EsqlQueryResponse resp = run("METRICS pods | SORT @timestamp DESC | KEEP @timestamp, pod, cpu | LIMIT 5")) { + try (EsqlQueryResponse resp = run("METRICS hosts | SORT @timestamp DESC, host | KEEP @timestamp, host, cpu | LIMIT 5")) { List> rows = EsqlTestUtils.getValuesList(resp); - List topDocs = docs.stream().sorted(Comparator.comparingLong(Doc::timestamp).reversed()).limit(5).toList(); + List topDocs = docs.stream() + 
.sorted(Comparator.comparingLong(Doc::timestamp).reversed().thenComparing(Doc::host)) + .limit(5) + .toList(); assertThat(rows, hasSize(topDocs.size())); for (int i = 0; i < rows.size(); i++) { List<Object> r = rows.get(i); - long timestamp = DateFieldMapper.DEFAULT_DATE_TIME_FORMATTER.parseMillis((String) r.get(0)); + long timestamp = DEFAULT_DATE_TIME_FORMATTER.parseMillis((String) r.get(0)); String pod = (String) r.get(1); double cpu = (Double) r.get(2); assertThat(topDocs.get(i).timestamp, equalTo(timestamp)); - assertThat(topDocs.get(i).pod, equalTo(pod)); + assertThat(topDocs.get(i).host, equalTo(pod)); assertThat(topDocs.get(i).cpu, equalTo(cpu)); } } } + + public void testRateWithoutGrouping() { + record RateKey(String cluster, String host) { + + } + Map<RateKey, List<RequestCounter>> groups = new HashMap<>(); + for (Doc doc : docs) { + RateKey key = new RateKey(doc.cluster, doc.host); + groups.computeIfAbsent(key, k -> new ArrayList<>()).add(new RequestCounter(doc.timestamp, doc.requestCount)); + } + List<Double> rates = new ArrayList<>(); + for (List<RequestCounter> group : groups.values()) { + Double v = computeRate(group); + if (v != null) { + rates.add(v); + } + } + try (var resp = run("METRICS hosts sum(rate(request_count, 1second))")) { + assertThat(resp.columns(), equalTo(List.of(new ColumnInfo("sum(rate(request_count, 1second))", "double")))); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(1)); + assertThat(values.get(0), hasSize(1)); + assertThat((double) values.get(0).get(0), closeTo(rates.stream().mapToDouble(d -> d).sum(), 0.1)); + } + try (var resp = run("METRICS hosts max(rate(request_count)), min(rate(request_count))")) { + assertThat( + resp.columns(), + equalTo(List.of(new ColumnInfo("max(rate(request_count))", "double"), new ColumnInfo("min(rate(request_count))", "double"))) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(1)); + assertThat(values.get(0), hasSize(2)); + assertThat((double) values.get(0).get(0), closeTo(rates.stream().mapToDouble(d -> d).max().orElse(0.0), 0.1)); + assertThat((double) values.get(0).get(1), closeTo(rates.stream().mapToDouble(d -> d).min().orElse(0.0), 0.1)); + } + try (var resp = run("METRICS hosts max(rate(request_count)), avg(rate(request_count)), max(rate(request_count, 1minute))")) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("max(rate(request_count))", "double"), + new ColumnInfo("avg(rate(request_count))", "double"), + new ColumnInfo("max(rate(request_count, 1minute))", "double") + ) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(1)); + assertThat(values.get(0), hasSize(3)); + assertThat((double) values.get(0).get(0), closeTo(rates.stream().mapToDouble(d -> d).max().orElse(0.0), 0.1)); + final double avg = rates.isEmpty() ? 0.0 : rates.stream().mapToDouble(d -> d).sum() / rates.size(); + assertThat((double) values.get(0).get(1), closeTo(avg, 0.1)); + assertThat((double) values.get(0).get(2), closeTo(rates.stream().mapToDouble(d -> d * 60.0).max().orElse(0.0), 0.1)); + } + try (var resp = run("METRICS hosts avg(rate(request_count)), avg(rate(request_count, 1second))")) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("avg(rate(request_count))", "double"), + new ColumnInfo("avg(rate(request_count, 1second))", "double") + ) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(1)); + assertThat(values.get(0), hasSize(2)); + final double avg = rates.isEmpty() ?
0.0 : rates.stream().mapToDouble(d -> d).sum() / rates.size(); + assertThat((double) values.get(0).get(0), closeTo(avg, 0.1)); + assertThat((double) values.get(0).get(1), closeTo(avg, 0.1)); + } + } + + public void testRateGroupedByCluster() { + record RateKey(String cluster, String host) { + + } + Map<RateKey, List<RequestCounter>> groups = new HashMap<>(); + for (Doc doc : docs) { + RateKey key = new RateKey(doc.cluster, doc.host); + groups.computeIfAbsent(key, k -> new ArrayList<>()).add(new RequestCounter(doc.timestamp, doc.requestCount)); + } + Map<String, List<Double>> bucketToRates = new HashMap<>(); + for (Map.Entry<RateKey, List<RequestCounter>> e : groups.entrySet()) { + List<Double> values = bucketToRates.computeIfAbsent(e.getKey().cluster, k -> new ArrayList<>()); + Double rate = computeRate(e.getValue()); + values.add(Objects.requireNonNullElse(rate, 0.0)); + } + List<String> sortedKeys = bucketToRates.keySet().stream().sorted().toList(); + try (var resp = run("METRICS hosts sum(rate(request_count)) BY cluster | SORT cluster")) { + assertThat( + resp.columns(), + equalTo(List.of(new ColumnInfo("sum(rate(request_count))", "double"), new ColumnInfo("cluster", "keyword"))) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(bucketToRates.size())); + for (int i = 0; i < bucketToRates.size(); i++) { + List<Object> row = values.get(i); + assertThat(row, hasSize(2)); + String key = sortedKeys.get(i); + assertThat(row.get(1), equalTo(key)); + assertThat((double) row.get(0), closeTo(bucketToRates.get(key).stream().mapToDouble(d -> d).sum(), 0.1)); + } + } + try (var resp = run("METRICS hosts avg(rate(request_count)) BY cluster | SORT cluster")) { + assertThat( + resp.columns(), + equalTo(List.of(new ColumnInfo("avg(rate(request_count))", "double"), new ColumnInfo("cluster", "keyword"))) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(bucketToRates.size())); + for (int i = 0; i < bucketToRates.size(); i++) { + List<Object> row = values.get(i); + assertThat(row, hasSize(2)); + String key = sortedKeys.get(i); + assertThat(row.get(1), equalTo(key)); + List<Double> rates = bucketToRates.get(key); + if (rates.isEmpty()) { + assertThat(row.get(0), equalTo(0.0)); + } else { + double avg = rates.stream().mapToDouble(d -> d).sum() / rates.size(); + assertThat((double) row.get(0), closeTo(avg, 0.1)); + } + } + } + try (var resp = run("METRICS hosts avg(rate(request_count, 1minute)), avg(rate(request_count)) BY cluster | SORT cluster")) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("avg(rate(request_count, 1minute))", "double"), + new ColumnInfo("avg(rate(request_count))", "double"), + new ColumnInfo("cluster", "keyword") + ) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(bucketToRates.size())); + for (int i = 0; i < bucketToRates.size(); i++) { + List<Object> row = values.get(i); + assertThat(row, hasSize(3)); + String key = sortedKeys.get(i); + assertThat(row.get(2), equalTo(key)); + List<Double> rates = bucketToRates.get(key); + if (rates.isEmpty()) { + assertThat(row.get(0), equalTo(0.0)); + assertThat(row.get(1), equalTo(0.0)); + } else { + double avg = rates.stream().mapToDouble(d -> d).sum() / rates.size(); + assertThat((double) row.get(0), closeTo(avg * 60.0f, 0.1)); + assertThat((double) row.get(1), closeTo(avg, 0.1)); + } + } + } + } + + public void testRateWithTimeBucket() { + var rounding = new Rounding.Builder(TimeValue.timeValueSeconds(60)).timeZone(ZoneOffset.UTC).build().prepareForUnknown(); + record RateKey(String host, String cluster, long interval) {} + Map<RateKey, List<RequestCounter>> groups = new
HashMap<>(); + for (Doc doc : docs) { + RateKey key = new RateKey(doc.host, doc.cluster, rounding.round(doc.timestamp)); + groups.computeIfAbsent(key, k -> new ArrayList<>()).add(new RequestCounter(doc.timestamp, doc.requestCount)); + } + Map> bucketToRates = new HashMap<>(); + for (Map.Entry> e : groups.entrySet()) { + List values = bucketToRates.computeIfAbsent(e.getKey().interval, k -> new ArrayList<>()); + Double rate = computeRate(e.getValue()); + if (rate != null) { + values.add(rate); + } + } + List sortedKeys = bucketToRates.keySet().stream().sorted().limit(5).toList(); + try (var resp = run("METRICS hosts sum(rate(request_count)) BY ts=bucket(@timestamp, 1 minute) | SORT ts | LIMIT 5")) { + assertThat( + resp.columns(), + equalTo(List.of(new ColumnInfo("sum(rate(request_count))", "double"), new ColumnInfo("ts", "date"))) + ); + List> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(sortedKeys.size())); + for (int i = 0; i < sortedKeys.size(); i++) { + List row = values.get(i); + assertThat(row, hasSize(2)); + long key = sortedKeys.get(i); + assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key))); + List bucketValues = bucketToRates.get(key); + if (bucketValues.isEmpty()) { + assertNull(row.get(0)); + } else { + assertThat((double) row.get(0), closeTo(bucketValues.stream().mapToDouble(d -> d).sum(), 0.1)); + } + } + } + try (var resp = run("METRICS hosts avg(rate(request_count)) BY ts=bucket(@timestamp, 1minute) | SORT ts | LIMIT 5")) { + assertThat( + resp.columns(), + equalTo(List.of(new ColumnInfo("avg(rate(request_count))", "double"), new ColumnInfo("ts", "date"))) + ); + List> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(sortedKeys.size())); + for (int i = 0; i < sortedKeys.size(); i++) { + List row = values.get(i); + assertThat(row, hasSize(2)); + long key = sortedKeys.get(i); + assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key))); + List bucketValues = bucketToRates.get(key); + if (bucketValues.isEmpty()) { + assertNull(row.get(0)); + } else { + double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); + assertThat((double) row.get(0), closeTo(avg, 0.1)); + } + } + } + try (var resp = run(""" + METRICS hosts avg(rate(request_count, 1minute)), avg(rate(request_count)) BY ts=bucket(@timestamp, 1minute) + | SORT ts + | LIMIT 5 + """)) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("avg(rate(request_count, 1minute))", "double"), + new ColumnInfo("avg(rate(request_count))", "double"), + new ColumnInfo("ts", "date") + ) + ) + ); + List> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(sortedKeys.size())); + for (int i = 0; i < sortedKeys.size(); i++) { + List row = values.get(i); + assertThat(row, hasSize(3)); + long key = sortedKeys.get(i); + assertThat(row.get(2), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key))); + List bucketValues = bucketToRates.get(key); + if (bucketValues.isEmpty()) { + assertNull(row.get(0)); + assertNull(row.get(1)); + } else { + double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); + assertThat((double) row.get(0), closeTo(avg * 60.0f, 0.1)); + assertThat((double) row.get(1), closeTo(avg, 0.1)); + } + } + } + } + + public void testRateWithTimeBucketAndCluster() { + var rounding = new Rounding.Builder(TimeValue.timeValueSeconds(60)).timeZone(ZoneOffset.UTC).build().prepareForUnknown(); + record RateKey(String host, String cluster, long interval) 
{} + Map<RateKey, List<RequestCounter>> groups = new HashMap<>(); + for (Doc doc : docs) { + RateKey key = new RateKey(doc.host, doc.cluster, rounding.round(doc.timestamp)); + groups.computeIfAbsent(key, k -> new ArrayList<>()).add(new RequestCounter(doc.timestamp, doc.requestCount)); + } + record GroupKey(String cluster, long interval) {} + Map<GroupKey, List<Double>> buckets = new HashMap<>(); + for (Map.Entry<RateKey, List<RequestCounter>> e : groups.entrySet()) { + RateKey key = e.getKey(); + List<Double> values = buckets.computeIfAbsent(new GroupKey(key.cluster, key.interval), k -> new ArrayList<>()); + Double rate = computeRate(e.getValue()); + if (rate != null) { + values.add(rate); + } + } + List<GroupKey> sortedKeys = buckets.keySet() + .stream() + .sorted(Comparator.comparing(GroupKey::interval).thenComparing(GroupKey::cluster)) + .limit(5) + .toList(); + try (var resp = run(""" + METRICS hosts sum(rate(request_count)) BY ts=bucket(@timestamp, 1 minute), cluster + | SORT ts, cluster + | LIMIT 5""")) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("sum(rate(request_count))", "double"), + new ColumnInfo("ts", "date"), + new ColumnInfo("cluster", "keyword") + ) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(sortedKeys.size())); + for (int i = 0; i < sortedKeys.size(); i++) { + List<Object> row = values.get(i); + assertThat(row, hasSize(3)); + var key = sortedKeys.get(i); + assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); + assertThat(row.get(2), equalTo(key.cluster)); + List<Double> bucketValues = buckets.get(key); + if (bucketValues.isEmpty()) { + assertNull(row.get(0)); + } else { + assertThat((double) row.get(0), closeTo(bucketValues.stream().mapToDouble(d -> d).sum(), 0.1)); + } + } + } + try (var resp = run(""" + METRICS hosts avg(rate(request_count)) BY ts=bucket(@timestamp, 1minute), cluster + | SORT ts, cluster + | LIMIT 5""")) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("avg(rate(request_count))", "double"), + new ColumnInfo("ts", "date"), + new ColumnInfo("cluster", "keyword") + ) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(sortedKeys.size())); + for (int i = 0; i < sortedKeys.size(); i++) { + List<Object> row = values.get(i); + assertThat(row, hasSize(3)); + var key = sortedKeys.get(i); + assertThat(row.get(1), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); + assertThat(row.get(2), equalTo(key.cluster)); + List<Double> bucketValues = buckets.get(key); + if (bucketValues.isEmpty()) { + assertNull(row.get(0)); + } else { + double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); + assertThat((double) row.get(0), closeTo(avg, 0.1)); + } + } + } + try (var resp = run(""" + METRICS hosts avg(rate(request_count, 1minute)), avg(rate(request_count)) BY ts=bucket(@timestamp, 1minute), cluster + | SORT ts, cluster + | LIMIT 5""")) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("avg(rate(request_count, 1minute))", "double"), + new ColumnInfo("avg(rate(request_count))", "double"), + new ColumnInfo("ts", "date"), + new ColumnInfo("cluster", "keyword") + ) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(sortedKeys.size())); + for (int i = 0; i < sortedKeys.size(); i++) { + List<Object> row = values.get(i); + assertThat(row, hasSize(4)); + var key = sortedKeys.get(i); + assertThat(row.get(2), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); + assertThat(row.get(3), equalTo(key.cluster)); + List<Double> bucketValues =
buckets.get(key); + if (bucketValues.isEmpty()) { + assertNull(row.get(0)); + assertNull(row.get(1)); + } else { + double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); + assertThat((double) row.get(0), closeTo(avg * 60.0f, 0.1)); + assertThat((double) row.get(1), closeTo(avg, 0.1)); + } + } + } + try (var resp = run(""" + METRICS hosts + s = sum(rate(request_count)), + c = count(rate(request_count)), + max(rate(request_count)), + avg(rate(request_count)) + BY ts=bucket(@timestamp, 1minute), cluster + | SORT ts, cluster + | LIMIT 5 + | EVAL avg_rate= s/c + | KEEP avg_rate, `max(rate(request_count))`, `avg(rate(request_count))`, ts, cluster + """)) { + assertThat( + resp.columns(), + equalTo( + List.of( + new ColumnInfo("avg_rate", "double"), + new ColumnInfo("max(rate(request_count))", "double"), + new ColumnInfo("avg(rate(request_count))", "double"), + new ColumnInfo("ts", "date"), + new ColumnInfo("cluster", "keyword") + ) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(sortedKeys.size())); + for (int i = 0; i < sortedKeys.size(); i++) { + List<Object> row = values.get(i); + assertThat(row, hasSize(5)); + var key = sortedKeys.get(i); + assertThat(row.get(3), equalTo(DEFAULT_DATE_TIME_FORMATTER.formatMillis(key.interval))); + assertThat(row.get(4), equalTo(key.cluster)); + List<Double> bucketValues = buckets.get(key); + if (bucketValues.isEmpty()) { + assertNull(row.get(0)); + assertNull(row.get(1)); + } else { + double avg = bucketValues.stream().mapToDouble(d -> d).sum() / bucketValues.size(); + assertThat((double) row.get(0), closeTo(avg, 0.1)); + double max = bucketValues.stream().mapToDouble(d -> d).max().orElse(0.0); + assertThat((double) row.get(1), closeTo(max, 0.1)); + } + assertEquals(row.get(0), row.get(2)); + } + } + } + + public void testApplyRateBeforeFinalGrouping() { + record RateKey(String cluster, String host) { + + } + Map<RateKey, List<RequestCounter>> groups = new HashMap<>(); + for (Doc doc : docs) { + RateKey key = new RateKey(doc.cluster, doc.host); + groups.computeIfAbsent(key, k -> new ArrayList<>()).add(new RequestCounter(doc.timestamp, doc.requestCount)); + } + List<Double> rates = new ArrayList<>(); + for (List<RequestCounter> group : groups.values()) { + Double v = computeRate(group); + if (v != null) { + rates.add(v); + } + } + try (var resp = run("METRICS hosts sum(abs(rate(request_count, 1second)))")) { + assertThat(resp.columns(), equalTo(List.of(new ColumnInfo("sum(abs(rate(request_count, 1second)))", "double")))); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(1)); + assertThat(values.get(0), hasSize(1)); + assertThat((double) values.get(0).get(0), closeTo(rates.stream().mapToDouble(d -> d).sum(), 0.1)); + } + try (var resp = run("METRICS hosts sum(10.0 * rate(request_count, 1second))")) { + assertThat(resp.columns(), equalTo(List.of(new ColumnInfo("sum(10.0 * rate(request_count, 1second))", "double")))); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(1)); + assertThat(values.get(0), hasSize(1)); + assertThat((double) values.get(0).get(0), closeTo(rates.stream().mapToDouble(d -> d * 10.0).sum(), 0.1)); + } + try (var resp = run("METRICS hosts sum(20 * rate(request_count, 1second) + 10 * floor(rate(request_count, 1second)))")) { + assertThat( + resp.columns(), + equalTo( + List.of(new ColumnInfo("sum(20 * rate(request_count, 1second) + 10 * floor(rate(request_count, 1second)))", "double")) + ) + ); + List<List<Object>> values = EsqlTestUtils.getValuesList(resp); + assertThat(values, hasSize(1)); +
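The expected values throughout these tests come from a computeRate helper defined earlier in the test class, outside the visible hunks. A minimal sketch of the assumed semantics — a per-series counter rate in events per second with Prometheus-style reset handling; the accessor names are taken from the RequestCounter usage above, the body is illustrative, not the committed implementation:

    // Sketch only: assumes record RequestCounter(long timestamp, long requestCount).
    static Double computeRate(List<RequestCounter> counters) {
        if (counters.size() < 2) {
            return null; // a single sample yields no rate; callers skip nulls
        }
        List<RequestCounter> sorted = counters.stream()
            .sorted(Comparator.comparingLong(RequestCounter::timestamp))
            .toList();
        double delta = 0;
        for (int i = 1; i < sorted.size(); i++) {
            double prev = sorted.get(i - 1).requestCount();
            double curr = sorted.get(i).requestCount();
            // a decreasing counter signals a reset; count the post-reset value from zero
            delta += curr >= prev ? curr - prev : curr;
        }
        long elapsedMillis = sorted.get(sorted.size() - 1).timestamp() - sorted.get(0).timestamp();
        return delta / (elapsedMillis / 1000.0);
    }

This also explains the avg * 60.0f assertions above: rate(request_count, 1minute) is simply the per-second rate scaled to a one-minute unit.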
assertThat(values.get(0), hasSize(1)); + assertThat((double) values.get(0).get(0), closeTo(rates.stream().mapToDouble(d -> 20. * d + 10.0 * Math.floor(d)).sum(), 0.1)); + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index f30f7f84128c9..1caf94dde5c30 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -60,9 +60,11 @@ public enum Cap { METADATA_IGNORED_FIELD, /** - * Support for the syntax {@code "tables": {"type": []}}. + * LOOKUP command with + * - tables using syntax {@code "tables": {"type": []}} + * - fixed variable shadowing */ - TABLES_TYPES(true), + LOOKUP_V3(true), /** * Support for requesting the "REPEAT" command. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/ResponseValueUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/ResponseValueUtils.java index 98f2bbf95d3de..70ec7504ed3d2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/ResponseValueUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/ResponseValueUtils.java @@ -27,8 +27,8 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.esql.action.ColumnInfo; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.planner.PlannerUtils; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import java.io.IOException; import java.io.UncheckedIOException; @@ -163,7 +163,7 @@ private static Object valueAt(String dataType, Block block, int offset, BytesRef static Page valuesToPage(BlockFactory blockFactory, List<ColumnInfo> columns, List<List<Object>> values) { List<String> dataTypes = columns.stream().map(ColumnInfo::type).toList(); List<Block.Builder> results = dataTypes.stream() - .map(c -> PlannerUtils.toElementType(EsqlDataTypes.fromName(c)).newBlockBuilder(values.size(), blockFactory)) + .map(c -> PlannerUtils.toElementType(DataType.fromEs(c)).newBlockBuilder(values.size(), blockFactory)) .toList(); for (List<Object> row : values) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 35aa8a6d42cca..0d556efbea5db 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -451,7 +451,7 @@ private LogicalPlan resolveAggregate(Aggregate a, List<Attribute> childrenOutput) { } groupings = newGroupings; if (changed.get()) { - a = new EsqlAggregate(a.source(), a.child(), newGroupings, a.aggregates()); + a = new EsqlAggregate(a.source(), a.child(), a.aggregateType(), newGroupings, a.aggregates()); changed.set(false); } } @@ -480,7 +480,7 @@ private LogicalPlan resolveAggregate(Aggregate a, List<Attribute> childrenOutput) { newAggregates.add(agg); } - a = changed.get() ? new EsqlAggregate(a.source(), a.child(), groupings, newAggregates) : a; + a = changed.get() ?
new EsqlAggregate(a.source(), a.child(), a.aggregateType(), groupings, newAggregates) : a; } return a; @@ -519,13 +519,13 @@ private LogicalPlan resolveLookup(Lookup l, List<Attribute> childrenOutput) { } // check the on field against both the child output and the inner relation - List<NamedExpression> matchFields = new ArrayList<>(l.matchFields().size()); + List<Attribute> matchFields = new ArrayList<>(l.matchFields().size()); List<Attribute> localOutput = l.localRelation().output(); boolean modified = false; - for (NamedExpression ne : l.matchFields()) { - NamedExpression matchFieldChildReference = ne; - if (ne instanceof UnresolvedAttribute ua && ua.customMessage() == false) { + for (Attribute matchField : l.matchFields()) { + Attribute matchFieldChildReference = matchField; + if (matchField instanceof UnresolvedAttribute ua && ua.customMessage() == false) { modified = true; Attribute joinedAttribute = maybeResolveAttribute(ua, localOutput); // can't find the field inside the local relation diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index e36b5f7b9d69c..514a53b0933e9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; @@ -225,6 +226,14 @@ private static void checkAggregate(LogicalPlan p, Set<Failure> failures) { // traverse the tree to find invalid matches checkInvalidNamedExpressionUsage(exp, groupings, groupRefs, failures, 0); }); + if (agg.aggregateType() == Aggregate.AggregateType.METRICS) { + aggs.forEach(a -> checkRateAggregates(a, 0, failures)); + } else { + agg.forEachExpression( + Rate.class, + r -> failures.add(fail(r, "the rate aggregate [{}] can only be used within the metrics command", r.sourceText())) + ); + } } else { p.forEachExpression( GroupingFunction.class, @@ -233,6 +242,26 @@ } } + private static void checkRateAggregates(Expression expr, int nestedLevel, Set<Failure> failures) { + if (expr instanceof AggregateFunction) { + nestedLevel++; + } + if (expr instanceof Rate r) { + if (nestedLevel != 2) { + failures.add( + fail( + expr, + "the rate aggregate [{}] can only be used within the metrics command and inside another aggregate", + r.sourceText() + ) + ); + } + } + for (Expression child : expr.children()) { + checkRateAggregates(child, nestedLevel, failures); + } + } + // traverse the expression and look either for an agg function or a grouping match // stop either when no children are left, the leafs are literals or a reference attribute is given private static void checkInvalidNamedExpressionUsage( @@ -245,7 +274,10 @@ // found an aggregate, constant or a group, bail out if (e instanceof AggregateFunction af) { af.field().forEachDown(AggregateFunction.class, f -> { - failures.add(fail(f,
"nested aggregations [{}] not allowed inside other aggregations [{}]", f, af)); + // rate aggregate is allowed to be inside another aggregate + if (f instanceof Rate == false) { + failures.add(fail(f, "nested aggregations [{}] not allowed inside other aggregations [{}]", f, af)); + } }); } else if (e instanceof GroupingFunction gf) { // optimizer will later unroll expressions with aggs and non-aggs with a grouping function into an EVAL, but that will no longer @@ -337,7 +369,11 @@ private static void checkEvalFields(LogicalPlan p, Set failures) { } // check no aggregate functions are used field.forEachDown(AggregateFunction.class, af -> { - failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText())); + if (af instanceof Rate) { + failures.add(fail(af, "aggregate function [{}] not allowed outside METRICS command", af.sourceText())); + } else { + failures.add(fail(af, "aggregate function [{}] not allowed outside STATS command", af.sourceText())); + } }); }); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java index 05b78c8b5f309..87c558fe5bd1e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java @@ -299,7 +299,7 @@ private void doLookup( ); BlockLoader loader = ctx.blockLoader( extractField instanceof Alias a ? ((NamedExpression) a.child()).name() : extractField.name(), - EsqlDataTypes.isUnsupported(extractField.dataType()), + extractField.dataType() == DataType.UNSUPPORTED, MappedFieldType.FieldExtractPreference.NONE ); fields.add( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java index e53d11854cc63..44443973764e6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xpack.esql.core.type.EsField; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import java.io.IOException; import java.util.List; @@ -30,7 +29,7 @@ public ResolvedEnrichPolicy(StreamInput in) throws IOException { in.readString(), in.readStringCollectionAsList(), in.readMap(StreamInput::readString), - in.readMap(StreamInput::readString, ResolvedEnrichPolicy::readEsField) + in.readMap(EsField::new) ); } @@ -40,25 +39,13 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(matchType); out.writeStringCollection(enrichFields); out.writeMap(concreteIndices, StreamOutput::writeString); - out.writeMap(mapping, ResolvedEnrichPolicy::writeEsField); - } - - // TODO: we should have made EsField and DataType Writable, but write it as NamedWritable in PlanStreamInput - private static void writeEsField(StreamOutput out, EsField field) throws IOException { - out.writeString(field.getName()); - out.writeString(field.getDataType().typeName()); - out.writeMap(field.getProperties(), ResolvedEnrichPolicy::writeEsField); - out.writeBoolean(field.isAggregatable()); - out.writeBoolean(field.isAlias()); - } - - private 
static EsField readEsField(StreamInput in) throws IOException { - return new EsField( - in.readString(), - EsqlDataTypes.fromTypeName(in.readString()), - in.readMap(ResolvedEnrichPolicy::readEsField), - in.readBoolean(), - in.readBoolean() + out.writeMap( + mapping, + /* + * There are lots of subtypes of ESField, but we always write the field + * as though it were the base class. + */ + (o, v) -> new EsField(v.getName(), v.getDataType(), v.getProperties(), v.isAggregatable(), v.isAlias()).writeTo(o) ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 0df1ae078171d..7c4c8d63ea96f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -6,12 +6,18 @@ */ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.List; import java.util.Objects; @@ -23,6 +29,23 @@ * A type of {@code Function} that takes multiple values and extracts a single value out of them. For example, {@code AVG()}. 
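The aggregate subclasses in this diff (Max, Median, MedianAbsoluteDeviation, Min, Sum, Values, ...) all deserialize through super(in), but the matching constructor and base writeTo added to AggregateFunction fall outside the visible hunks. A sketch of their assumed shape, inferred from the subclass counterparts here; note some subclasses write Source.EMPTY instead of source(), so the exact Source handling is an assumption:

    // Assumed shape only; inferred from this diff, not copied from it:
    // the base class streams the Source marker followed by the field expression.
    protected AggregateFunction(StreamInput in) throws IOException {
        this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readExpression());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        source().writeTo(out);                          // counterpart of Source.readFrom above
        ((PlanStreamOutput) out).writeExpression(field());
    }

Subclasses with extra state, like Rate below, call super.writeTo(out) and then append their own fields.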
*/ public abstract class AggregateFunction extends Function { + public static List getNamedWriteables() { + return List.of( + Avg.ENTRY, + Count.ENTRY, + CountDistinct.ENTRY, + Max.ENTRY, + Median.ENTRY, + MedianAbsoluteDeviation.ENTRY, + Min.ENTRY, + Percentile.ENTRY, + SpatialCentroid.ENTRY, + Sum.ENTRY, + TopList.ENTRY, + Values.ENTRY, + Rate.ENTRY + ); + } private final Expression field; private final List parameters; @@ -37,6 +60,16 @@ protected AggregateFunction(Source source, Expression field, List info() { return NodeInfo.create(this, Count::new, field()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java index c91b9c37ae0a3..9700cb2330d1c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.CountDistinctBooleanAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.CountDistinctBytesRefAggregatorFunctionSupplier; @@ -28,8 +31,11 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvDedupe; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.planner.ToAggregator; +import java.io.IOException; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; @@ -39,6 +45,12 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; public class CountDistinct extends AggregateFunction implements OptionalArgument, ToAggregator, SurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "CountDistinct", + CountDistinct::new + ); + private static final int DEFAULT_PRECISION = 3000; private final Expression precision; @@ -56,6 +68,26 @@ public CountDistinct( this.precision = precision; } + private CountDistinct(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + in.readOptionalWriteable(i -> ((PlanStreamInput) i).readExpression()) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + ((PlanStreamOutput) out).writeExpression(field()); + ((PlanStreamOutput) out).writeOptionalExpression(precision); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, CountDistinct::new, field(), precision); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java index 1c1139c197ac0..97a6f6b4b5e1f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MaxIntAggregatorFunctionSupplier; @@ -20,9 +22,11 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax; +import java.io.IOException; import java.util.List; public class Max extends NumericAggregate implements SurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Max", Max::new); @FunctionInfo( returnType = { "double", "integer", "long", "date" }, @@ -33,6 +37,15 @@ public Max(Source source, @Param(name = "number", type = { "double", "integer", super(source, field); } + private Max(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, Max::new, field()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java index c381693dbe2ce..36207df331e47 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.QuantileStates; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -19,12 +21,15 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian; +import java.io.IOException; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; public class Median extends AggregateFunction implements SurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Median", Median::new); + // TODO: Add the compression parameter @FunctionInfo( returnType = { "double", "integer", "long" }, @@ -46,6 +51,15 @@ protected Expression.TypeResolution resolveType() { ); } + private Median(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public DataType dataType() { return DataType.DOUBLE; diff --git 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java index db25ad6c8c41f..23d55942cc72f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviation.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MedianAbsoluteDeviationDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MedianAbsoluteDeviationIntAggregatorFunctionSupplier; @@ -17,9 +19,15 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; +import java.io.IOException; import java.util.List; public class MedianAbsoluteDeviation extends NumericAggregate { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "MedianAbsoluteDeviation", + MedianAbsoluteDeviation::new + ); // TODO: Add parameter @FunctionInfo( @@ -31,6 +39,15 @@ public MedianAbsoluteDeviation(Source source, @Param(name = "number", type = { " super(source, field); } + private MedianAbsoluteDeviation(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, MedianAbsoluteDeviation::new, field()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java index ecfc2200a3643..2dd3e973937f5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MinIntAggregatorFunctionSupplier; @@ -20,9 +22,11 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; +import java.io.IOException; import java.util.List; public class Min extends NumericAggregate implements SurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Min", Min::new); @FunctionInfo( returnType = { "double", "integer", "long", "date" }, @@ -33,6 +37,15 @@ public Min(Source source, @Param(name = "number", type = { "double", "integer", super(source, field); } + private Min(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + 
return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, Min::new, field()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java index 390cd0d68018e..e7825a1d11704 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/NumericAggregate.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -14,6 +15,7 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.planner.ToAggregator; +import java.io.IOException; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; @@ -51,6 +53,10 @@ public abstract class NumericAggregate extends AggregateFunction implements ToAg super(source, field); } + NumericAggregate(StreamInput in) throws IOException { + super(in); + } + @Override protected TypeResolution resolveType() { if (supportsDates()) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java index d21247a77d9cf..e2156f4d3b97d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.PercentileDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.PercentileIntAggregatorFunctionSupplier; @@ -17,7 +20,10 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; @@ -27,6 +33,12 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; public class Percentile extends NumericAggregate { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "Percentile", + Percentile::new + ); + private final Expression percentile; @FunctionInfo( @@ -43,6 +55,22 @@ public Percentile( this.percentile = percentile; } + private Percentile(StreamInput in) throws IOException { + 
this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readExpression(), ((PlanStreamInput) in).readExpression()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + ((PlanStreamOutput) out).writeExpression(children().get(0)); + ((PlanStreamOutput) out).writeExpression(children().get(1)); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, Percentile::new, field(), percentile); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java index cc65be77b9924..3d38c66119ead 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.RateDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.RateIntAggregatorFunctionSupplier; @@ -35,6 +38,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; public class Rate extends AggregateFunction implements OptionalArgument, ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Rate", Rate::readFrom); private static final TimeValue DEFAULT_UNIT = TimeValue.timeValueSeconds(1); private final Expression timestamp; @@ -61,19 +65,21 @@ public static Rate withUnresolvedTimestamp(Source source, Expression field, Expr return new Rate(source, field, new UnresolvedAttribute(source, "@timestamp"), unit); } - public static Rate readRate(PlanStreamInput in) throws IOException { - Source source = Source.readFrom(in); - Expression field = in.readExpression(); - Expression timestamp = in.readOptionalWriteable(i -> in.readExpression()); - Expression unit = in.readOptionalNamed(Expression.class); + private static Rate readFrom(StreamInput in) throws IOException { + PlanStreamInput planIn = (PlanStreamInput) in; + Source source = Source.readFrom(planIn); + Expression field = planIn.readExpression(); + Expression timestamp = planIn.readExpression(); + Expression unit = planIn.readOptionalNamed(Expression.class); return new Rate(source, field, timestamp, unit); } - public static void writeRate(PlanStreamOutput out, Rate rate) throws IOException { - rate.source().writeTo(out); - out.writeExpression(rate.field()); - out.writeExpression(rate.timestamp); - out.writeOptionalExpression(rate.unit); + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + PlanStreamOutput planOut = (PlanStreamOutput) out; + planOut.writeExpression(timestamp); + planOut.writeOptionalExpression(unit); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialAggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialAggregateFunction.java index 
66a7e0ca436d6..d54d20eb4115f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialAggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialAggregateFunction.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; +import java.io.IOException; import java.util.Objects; /** @@ -25,6 +27,11 @@ protected SpatialAggregateFunction(Source source, Expression field, boolean useD this.useDocValues = useDocValues; } + protected SpatialAggregateFunction(StreamInput in, boolean useDocValues) throws IOException { + super(in); + this.useDocValues = useDocValues; + } + public abstract SpatialAggregateFunction withDocValues(); @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroid.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroid.java index 418f92284cca0..d5681ba8d856e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroid.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroid.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.spatial.SpatialCentroidCartesianPointDocValuesAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.spatial.SpatialCentroidCartesianPointSourceValuesAggregatorFunctionSupplier; @@ -20,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.planner.ToAggregator; +import java.io.IOException; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; @@ -29,6 +32,11 @@ * Calculate spatial centroid of all geo_point or cartesian point values of a field in matching documents. 
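Note how the new SpatialAggregateFunction(StreamInput, boolean) constructor takes useDocValues as a plain argument instead of reading it from the stream: the flag looks like a planner-local optimization rather than persisted state. A hedged sketch of the presumed lifecycle, given a Source source and an Expression field in scope; only withDocValues() is taken from this diff:

    // Assumed lifecycle, not confirmed by this hunk:
    SpatialCentroid centroid = new SpatialCentroid(source, field); // source-values variant
    SpatialCentroid optimized = centroid.withDocValues();          // doc-values execution path
    // After a wire round-trip the function reverts to the conservative
    // source-values variant, and the physical planner presumably re-applies
    // withDocValues() on nodes where the field actually has doc values.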
*/ public class SpatialCentroid extends SpatialAggregateFunction implements ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "SpatialCentroid", + SpatialCentroid::new + ); @FunctionInfo(returnType = { "geo_point", "cartesian_point" }, description = "The centroid of a spatial field.", isAggregation = true) public SpatialCentroid(Source source, @Param(name = "field", type = { "geo_point", "cartesian_point" }) Expression field) { @@ -39,6 +47,15 @@ private SpatialCentroid(Source source, Expression field, boolean useDocValues) { super(source, field, useDocValues); } + private SpatialCentroid(StreamInput in) throws IOException { + super(in, false); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public SpatialCentroid withDocValues() { return new SpatialCentroid(source(), field(), true); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index be9ae295f6fbc..34669454a2fa4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; @@ -22,6 +24,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import java.io.IOException; import java.util.List; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; @@ -32,12 +35,22 @@ * Sum all values of a field in matching documents. 
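All of these ENTRY constants feed name-keyed registries such as the AggregateFunction.getNamedWriteables() list above. An illustration of the round-trip they enable, given a Sum instance sum; in the real plan serialization the streams are PlanStreamInput/PlanStreamOutput wrappers (these writeTo implementations cast to them), which this plain-stream sketch glosses over:

    // Illustration only: name-based dispatch through a NamedWriteableRegistry.
    NamedWriteableRegistry registry = new NamedWriteableRegistry(AggregateFunction.getNamedWriteables());
    BytesStreamOutput out = new BytesStreamOutput();
    out.writeNamedWriteable(sum);                 // writes the name "Sum", then the payload
    StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry);
    Expression back = in.readNamedWriteable(Expression.class); // looks up "Sum" -> Sum::new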
*/ public class Sum extends NumericAggregate implements SurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new); @FunctionInfo(returnType = "long", description = "The sum of a numeric field.", isAggregation = true) public Sum(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { super(source, field); } + private Sum(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, Sum::new, field()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java index 79893b1c7de07..93e3da7c19cf8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopList.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.TopListDoubleAggregatorFunctionSupplier; @@ -38,6 +41,8 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; public class TopList extends AggregateFunction implements ToAggregator, SurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "TopList", TopList::new); + private static final String ORDER_ASC = "ASC"; private static final String ORDER_DESC = "DESC"; @@ -64,24 +69,35 @@ public TopList( super(source, field, Arrays.asList(limit, order)); } - public static TopList readFrom(PlanStreamInput in) throws IOException { - return new TopList(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readExpression()); + private TopList(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression() + ); } - public void writeTo(PlanStreamOutput out) throws IOException { + @Override + public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); List fields = children(); assert fields.size() == 3; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - out.writeExpression(fields.get(2)); + ((PlanStreamOutput) out).writeExpression(fields.get(0)); + ((PlanStreamOutput) out).writeExpression(fields.get(1)); + ((PlanStreamOutput) out).writeExpression(fields.get(2)); + } + + @Override + public String getWriteableName() { + return ENTRY.name; } - private Expression limitField() { + Expression limitField() { return parameters().get(0); } - private Expression orderField() { + Expression orderField() { return parameters().get(1); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java index c76f60fe0f555..7d2fbcddb113b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Values.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.esql.expression.function.aggregate; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.ValuesBooleanAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier; @@ -23,11 +25,14 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.planner.ToAggregator; +import java.io.IOException; import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; public class Values extends AggregateFunction implements ToAggregator { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Values", Values::new); + @FunctionInfo( returnType = { "boolean|date|double|integer|ip|keyword|long|text|version" }, description = "Collect values for a field.", @@ -40,6 +45,15 @@ public Values( super(source, v); } + private Values(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, Values::new, field()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java index 431494534f4ec..dab2019a50682 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java @@ -9,6 +9,9 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.Rounding; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; @@ -29,8 +32,11 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.math.Floor; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.type.EsqlDataTypes; +import java.io.IOException; import java.time.ZoneId; import java.time.ZoneOffset; import java.util.List; @@ -53,6 +59,8 @@ * In the former case, two parameters will be provided, in the latter four. 
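Concretely, the two arities the Bucket javadoc describes look like this in ES|QL, written as they would be passed to run() in the tests earlier in this patch (index, field, and range values are illustrative):

    // Two-parameter form: a field plus a fixed span.
    run("FROM hosts | STATS c = count(*) BY b = bucket(@timestamp, 1 hour)");
    // Four-parameter form: field, target bucket count, and a from/to hint range.
    run("FROM hosts | STATS c = count(*) BY b = bucket(@timestamp, 20, \"2024-05-01T00:00:00Z\", \"2024-05-31T00:00:00Z\")");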
*/ public class Bucket extends GroupingFunction implements Validatable, TwoOptionalArguments { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Bucket", Bucket::new); + // TODO maybe we should just cover the whole of representable dates here - like ten years, 100 years, 1000 years, all the way up. // That way you never end up with more than the target number of buckets. private static final Rounding LARGEST_HUMAN_DATE_ROUNDING = Rounding.builder(Rounding.DateTimeUnit.YEAR_OF_CENTURY).build(); @@ -193,6 +201,31 @@ public Bucket( this.to = to; } + private Bucket(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readOptionalNamed(Expression.class), + ((PlanStreamInput) in).readOptionalNamed(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(field); + ((PlanStreamOutput) out).writeExpression(buckets); + ((PlanStreamOutput) out).writeOptionalExpression(from); + ((PlanStreamOutput) out).writeOptionalExpression(to); + + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public boolean foldable() { return field.foldable() && buckets.foldable() && (from == null || from.foldable()) && (to == null || to.foldable()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java index 17934c1729ad7..f8adf4e5d9e16 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java @@ -10,8 +10,11 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; +import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Least; @@ -21,10 +24,29 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateParse; import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateTrunc; import org.elasticsearch.xpack.esql.expression.function.scalar.date.Now; +import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; +import org.elasticsearch.xpack.esql.expression.function.scalar.ip.IpPrefix; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Atan2; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.E; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Log; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pi; +import 
org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round; +import org.elasticsearch.xpack.esql.expression.function.scalar.math.Tau; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.EndsWith; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Left; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Locate; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Repeat; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Replace; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Right; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Split; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.StartsWith; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower; import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InsensitiveEquals; import java.util.List; @@ -42,18 +64,40 @@ public abstract class EsqlScalarFunction extends ScalarFunction implements EvaluatorMapper { public static List getNamedWriteables() { return List.of( + And.ENTRY, + Atan2.ENTRY, + Bucket.ENTRY, Case.ENTRY, + CIDRMatch.ENTRY, Coalesce.ENTRY, Concat.ENTRY, + E.ENTRY, + EndsWith.ENTRY, Greatest.ENTRY, + In.ENTRY, InsensitiveEquals.ENTRY, DateExtract.ENTRY, DateDiff.ENTRY, DateFormat.ENTRY, DateParse.ENTRY, DateTrunc.ENTRY, + IpPrefix.ENTRY, Least.ENTRY, + Left.ENTRY, + Locate.ENTRY, + Log.ENTRY, Now.ENTRY, + Or.ENTRY, + Pi.ENTRY, + Pow.ENTRY, + Right.ENTRY, + Repeat.ENTRY, + Replace.ENTRY, + Round.ENTRY, + Split.ENTRY, + Substring.ENTRY, + StartsWith.ENTRY, + Tau.ENTRY, ToLower.ENTRY, ToUpper.ENTRY ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/UnaryScalarFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/UnaryScalarFunction.java index eb2e5ab94487f..6c43e74593335 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/UnaryScalarFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/UnaryScalarFunction.java @@ -54,8 +54,10 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StY; import org.elasticsearch.xpack.esql.expression.function.scalar.string.LTrim; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Length; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.RTrim; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; @@ -86,6 +88,7 @@ public static List getNamedWriteables() { LTrim.ENTRY, 
Neg.ENTRY, Not.ENTRY, + RLike.ENTRY, RTrim.ENTRY, Signum.ENTRY, Sin.ENTRY, @@ -111,7 +114,8 @@ public static List getNamedWriteables() { ToString.ENTRY, ToUnsignedLong.ENTRY, ToVersion.ENTRY, - Trim.ENTRY + Trim.ENTRY, + WildcardLike.ENTRY ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/CIDRMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/CIDRMatch.java index e2c2395446ed6..e24ee80fe7972 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/CIDRMatch.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/CIDRMatch.java @@ -8,6 +8,9 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.ip; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.network.CIDRUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.operator.EvalOperator; @@ -22,7 +25,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -32,6 +38,8 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.fromIndex; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isIPAndExact; import static org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions.isStringAndExact; +import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanReader.readerFromPlanReader; +import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanWriter.writerFromPlanWriter; /** * This function takes a first parameter of type IP, followed by one or more parameters evaluated to a CIDR specification: @@ -45,6 +53,11 @@ * Example: `| eval cidr="10.0.0.0/8" | where cidr_match(ip_field, "127.0.0.1/30", cidr)` */ public class CIDRMatch extends EsqlScalarFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "CIDRMatch", + CIDRMatch::new + ); private final Expression ipField; private final List matches; @@ -68,6 +81,27 @@ public CIDRMatch( this.matches = matches; } + private CIDRMatch(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + assert children().size() > 1; + ((PlanStreamOutput) out).writeExpression(children().get(0)); + out.writeCollection(children().subList(1, children().size()), writerFromPlanWriter(PlanStreamOutput::writeExpression)); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + public Expression ipField() { return ipField; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefix.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefix.java index d00d1b2c35fcb..696ba1c09d08a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefix.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefix.java @@ -8,6 +8,9 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.ip; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; @@ -40,6 +43,8 @@ * Truncates an IP value to a given prefix length. */ public class IpPrefix extends EsqlScalarFunction implements OptionalArgument { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "IpPrefix", IpPrefix::new); + // Borrowed from Lucene, rfc4291 prefix private static final byte[] IPV4_PREFIX = new byte[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1 }; @@ -76,17 +81,28 @@ public IpPrefix( this.prefixLengthV6Field = prefixLengthV6Field; } - public static IpPrefix readFrom(PlanStreamInput in) throws IOException { - return new IpPrefix(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readExpression()); + private IpPrefix(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression() + ); } - public void writeTo(PlanStreamOutput out) throws IOException { + @Override + public void writeTo(StreamOutput out) throws IOException { source().writeTo(out); List fields = children(); assert fields.size() == 3; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - out.writeExpression(fields.get(2)); + ((PlanStreamOutput) out).writeExpression(fields.get(0)); + ((PlanStreamOutput) out).writeExpression(fields.get(1)); + ((PlanStreamOutput) out).writeExpression(fields.get(2)); + } + + @Override + public String getWriteableName() { + return ENTRY.name; } public Expression ipField() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Atan2.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Atan2.java index a2af991a244c3..5370a31023522 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Atan2.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Atan2.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.math; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -19,7 +22,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import 
org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.List; import java.util.function.Function; @@ -29,6 +35,8 @@ * The two-argument inverse tangent: the angle between the positive x-axis and the ray from the origin to the point (x, y). */ public class Atan2 extends EsqlScalarFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Atan2", Atan2::new); + private final Expression y; private final Expression x; @@ -56,6 +64,22 @@ public Atan2( this.x = x; } + private Atan2(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readExpression(), ((PlanStreamInput) in).readExpression()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(y()); + ((PlanStreamOutput) out).writeExpression(x()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public Expression replaceChildren(List<Expression> newChildren) { return new Atan2(source(), newChildren.get(0), newChildren.get(1)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java index 9bcd8a2467b1d..757b67b47ce72 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java @@ -7,17 +7,24 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.math; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import java.io.IOException; import java.util.List; /** * Function that emits Euler's number.
*/ public class E extends DoubleConstantFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "E", E::new); + @FunctionInfo( returnType = "double", description = "Returns {wikipedia}/E_(mathematical_constant)[Euler's number].", @@ -27,6 +34,20 @@ public E(Source source) { super(source); } + private E(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public Object fold() { return Math.E; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log.java index 97007f10b31bc..d17f24cade17b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.math; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -18,7 +21,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -28,6 +34,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNumeric; public class Log extends EsqlScalarFunction implements OptionalArgument { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Log", Log::new); private final Expression base; private final Expression value; @@ -60,6 +67,27 @@ public Log( this.base = value != null ? base : null; } + private Log(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readOptionalNamed(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + assert children().size() == 1 || children().size() == 2; + ((PlanStreamOutput) out).writeExpression(children().get(0)); + out.writeOptionalWriteable(children().size() == 2 ? 
o -> ((PlanStreamOutput) o).writeExpression(children().get(1)) : null); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected TypeResolution resolveType() { if (childrenResolved() == false) { @@ -126,4 +154,12 @@ public ExpressionEvaluator.Factory toEvaluator(Function EVALUATOR_IDENTITY = (s, e) -> e; @@ -67,6 +74,26 @@ public Round( this.decimals = decimals; } + private Round(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readOptionalNamed(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(field); + ((PlanStreamOutput) out).writeOptionalExpression(decimals); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected TypeResolution resolveType() { if (childrenResolved() == false) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java index 7a2eb801be84a..17e5b027270d1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java @@ -7,17 +7,24 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.math; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import java.io.IOException; import java.util.List; /** * Function that emits tau, also known as 2 * pi. 
*/ public class Tau extends DoubleConstantFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Tau", Tau::new); + public static final double TAU = Math.PI * 2; @FunctionInfo( @@ -29,6 +36,20 @@ public Tau(Source source) { super(source); } + private Tau(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public Object fold() { return TAU; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java index 0735c18e04e1e..75d5641458e3f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/BinarySpatialFunction.java @@ -8,6 +8,9 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.geometry.Geometry; import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -15,11 +18,14 @@ import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import org.elasticsearch.xpack.esql.core.util.PlanStreamOutput; import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes; import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions; import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import java.io.IOException; +import java.util.List; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; @@ -34,6 +40,10 @@ * and of compatible CRS. For example geo_point and geo_shape can be compared, but not geo_point and cartesian_point. 
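 * For illustration only (hypothetical ESQL snippets, not taken from this change):
 *   | where st_intersects(geo_point_field, to_geoshape("POLYGON((0 0, 1 0, 1 1, 0 0))"))  // same CRS: resolves
 *   | where st_intersects(geo_point_field, cartesian_point_field)                         // mixed CRS: fails type resolution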
*/ public abstract class BinarySpatialFunction extends BinaryScalarFunction implements SpatialEvaluatorFactory.SpatialSourceResolution { + public static List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return List.of(SpatialContains.ENTRY, SpatialDisjoint.ENTRY, SpatialIntersects.ENTRY, SpatialWithin.ENTRY, StDistance.ENTRY); + } + private final SpatialTypeResolver spatialTypeResolver; protected SpatialCrsType crsType; protected final boolean leftDocValues; @@ -53,6 +63,23 @@ protected BinarySpatialFunction( this.spatialTypeResolver = new SpatialTypeResolver(this, pointsOnly); } + protected BinarySpatialFunction(StreamInput in, boolean leftDocValues, boolean rightDocValues, boolean pointsOnly) throws IOException { + this( + Source.EMPTY, + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression(), + leftDocValues, + rightDocValues, + pointsOnly + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + ((PlanStreamOutput) out).writeExpression(left()); + ((PlanStreamOutput) out).writeExpression(right()); + } + @Override protected TypeResolution resolveType() { return spatialTypeResolver.resolveType(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java index d6589d387479a..6c2d11ab0ad16 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java @@ -11,6 +11,8 @@ import org.apache.lucene.geo.Component2D; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.geo.Orientation; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.geometry.Geometry; @@ -51,6 +53,12 @@ * Here we simply wire the rules together specific to ST_CONTAINS and QueryRelation.CONTAINS.
*/ public class SpatialContains extends SpatialRelatesFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "SpatialContains", + SpatialContains::new + ); + // public for test access with reflection public static final SpatialRelationsContains GEO = new SpatialRelationsContains( SpatialCoordinateTypes.GEO, @@ -134,6 +142,15 @@ public SpatialContains( super(source, left, right, leftDocValues, rightDocValues); } + private SpatialContains(StreamInput in) throws IOException { + super(in, false, false); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public ShapeField.QueryRelation queryRelation() { return ShapeField.QueryRelation.CONTAINS; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java index 3476606fd3224..e5520079e1b10 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java @@ -11,6 +11,8 @@ import org.apache.lucene.geo.Component2D; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.geo.Orientation; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.geometry.Geometry; @@ -48,6 +50,12 @@ * Here we simply wire the rules together specific to ST_DISJOINT and QueryRelation.DISJOINT. 
*/ public class SpatialDisjoint extends SpatialRelatesFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "SpatialDisjoint", + SpatialDisjoint::new + ); + // public for test access with reflection public static final SpatialRelations GEO = new SpatialRelations( ShapeField.QueryRelation.DISJOINT, @@ -89,6 +97,15 @@ private SpatialDisjoint(Source source, Expression left, Expression right, boolea super(source, left, right, leftDocValues, rightDocValues); } + private SpatialDisjoint(StreamInput in) throws IOException { + super(in, false, false); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public ShapeField.QueryRelation queryRelation() { return ShapeField.QueryRelation.DISJOINT; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java index 8589468a9ec71..045690340f6ac 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java @@ -11,6 +11,8 @@ import org.apache.lucene.geo.Component2D; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.geo.Orientation; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.geometry.Geometry; @@ -48,6 +50,12 @@ * Here we simply wire the rules together specific to ST_INTERSECTS and QueryRelation.INTERSECTS. 
*/ public class SpatialIntersects extends SpatialRelatesFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "SpatialIntersects", + SpatialIntersects::new + ); + // public for test access with reflection public static final SpatialRelations GEO = new SpatialRelations( ShapeField.QueryRelation.INTERSECTS, @@ -87,6 +95,15 @@ private SpatialIntersects(Source source, Expression left, Expression right, bool super(source, left, right, leftDocValues, rightDocValues); } + private SpatialIntersects(StreamInput in) throws IOException { + super(in, false, false); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public ShapeField.QueryRelation queryRelation() { return ShapeField.QueryRelation.INTERSECTS; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java index ca9b6838cc2ce..68005ecbfed47 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesFunction.java @@ -10,6 +10,7 @@ import org.apache.lucene.document.ShapeField; import org.apache.lucene.geo.Component2D; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.geometry.Geometry; import org.elasticsearch.geometry.Point; @@ -46,6 +47,10 @@ protected SpatialRelatesFunction(Source source, Expression left, Expression righ super(source, left, right, leftDocValues, rightDocValues, false); } + protected SpatialRelatesFunction(StreamInput in, boolean leftDocValues, boolean rightDocValues) throws IOException { + super(in, leftDocValues, rightDocValues, false); + } + public abstract ShapeField.QueryRelation queryRelation(); @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java index 7792131e301ac..f72571a4b5250 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java @@ -11,6 +11,8 @@ import org.apache.lucene.geo.Component2D; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.geo.Orientation; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.geometry.Geometry; @@ -49,6 +51,12 @@ * Here we simply wire the rules together specific to ST_WITHIN and QueryRelation.WITHIN. 
*/ public class SpatialWithin extends SpatialRelatesFunction implements SurrogateExpression { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "SpatialWithin", + SpatialWithin::new + ); + // public for test access with reflection public static final SpatialRelations GEO = new SpatialRelations( ShapeField.QueryRelation.WITHIN, @@ -89,6 +97,15 @@ public SpatialWithin( super(source, left, right, leftDocValues, rightDocValues); } + private SpatialWithin(StreamInput in) throws IOException { + super(in, false, false); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public ShapeField.QueryRelation queryRelation() { return ShapeField.QueryRelation.WITHIN; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java index 89c048f7eace8..2e20fba74476b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java @@ -9,6 +9,8 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.SloppyMath; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.compute.operator.EvalOperator; @@ -40,6 +42,12 @@ * Alternatively it is described in PostGIS documentation at PostGIS:ST_Distance. */ public class StDistance extends BinarySpatialFunction implements EvaluatorMapper { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "StDistance", + StDistance::new + ); + // public for test access with reflection public static final DistanceCalculator GEO = new GeoDistanceCalculator(); // public for test access with reflection @@ -132,6 +140,15 @@ protected StDistance(Source source, Expression left, Expression right, boolean l super(source, left, right, leftDocValues, rightDocValues, true); } + private StDistance(StreamInput in) throws IOException { + super(in, false, false, true); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public DataType dataType() { return DOUBLE; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWith.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWith.java index 767563ed4112a..f117ddf9816ad 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWith.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWith.java @@ -8,6 +8,9 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.string; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; 
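Reviewer note: every class converted in this PR follows the same three-part contract: a static ENTRY binding a serialization name to the StreamInput constructor, a writeTo(...) that emits the Source and then each child in a fixed order, and getWriteableName() returning exactly ENTRY.name so the receiving side can dispatch to the matching reader. A self-contained toy model of that round trip (plain java.io; Lit, Reader and the registry map are hypothetical stand-ins, not ESQL classes):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Map;

public class RoundTripSketch {
    // Toy stand-in for Writeable: knows its registered name and how to serialize itself.
    interface Writeable {
        String name();
        void writeTo(DataOutputStream out) throws IOException;
    }

    // Toy stand-in for the reader bound by a registry entry (cf. "CIDRMatch::new" above).
    interface Reader {
        Writeable read(DataInputStream in) throws IOException;
    }

    record Lit(int value) implements Writeable {
        static Lit readFrom(DataInputStream in) throws IOException {
            return new Lit(in.readInt());
        }
        public String name() { return "Lit"; }
        public void writeTo(DataOutputStream out) throws IOException { out.writeInt(value); }
    }

    // Toy stand-in for NamedWriteableRegistry: name -> reader.
    static final Map<String, Reader> REGISTRY = Map.of("Lit", Lit::readFrom);

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        Writeable original = new Lit(42);
        out.writeUTF(original.name()); // the getWriteableName() half of the contract
        original.writeTo(out);         // the writeTo() half

        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        Writeable copy = REGISTRY.get(in.readUTF()).read(in); // dispatch by registered name
        System.out.println(copy.equals(original)); // prints true: records compare by value
    }
}

The constructor must consume the stream in exactly the order writeTo produced it, which is why the converted constructors above read the Source first and then each expression.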
@@ -18,7 +21,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -28,6 +34,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; public class EndsWith extends EsqlScalarFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "EndsWith", EndsWith::new); private final Expression str; private final Expression suffix; @@ -55,6 +62,22 @@ public EndsWith( this.suffix = suffix; } + private EndsWith(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readExpression(), ((PlanStreamInput) in).readExpression()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(str); + ((PlanStreamOutput) out).writeExpression(suffix); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public DataType dataType() { return DataType.BOOLEAN; @@ -107,4 +130,12 @@ protected NodeInfo info() { public ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { return new EndsWithEvaluator.Factory(source(), toEvaluator.apply(str), toEvaluator.apply(suffix)); } + + Expression str() { + return str; + } + + Expression suffix() { + return suffix; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Left.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Left.java index 384874e173658..4f93ec8525dc6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Left.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Left.java @@ -9,6 +9,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.UnicodeUtil; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; @@ -21,7 +24,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -35,6 +41,8 @@ * {code left(foo, len)} is an alias to {code substring(foo, 0, len)} */ public class Left extends EsqlScalarFunction { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Left", Left::new); + private final Expression str; private final Expression length; @@ -53,6 +61,22 @@ public Left( this.length = length; } + private 
Left(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readExpression(), ((PlanStreamInput) in).readExpression()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(str); + ((PlanStreamOutput) out).writeExpression(length); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Evaluator static BytesRef process( @Fixed(includeInToString = false, build = true) BytesRef out, @@ -120,4 +144,12 @@ protected TypeResolution resolveType() { public boolean foldable() { return str.foldable() && length.foldable(); } + + Expression str() { + return str; + } + + Expression length() { + return length; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Locate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Locate.java index 1669a64ec83d2..3ea741d3a42d4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Locate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Locate.java @@ -9,6 +9,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.UnicodeUtil; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -20,7 +23,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -35,6 +41,7 @@ * Locate function: given a string 'a' and a substring 'b', it returns the index of the first occurrence of 'b' within 'a'.
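 * For illustration, assuming the conventional LOCATE semantics: locate("hello", "ll") returns 3
 * (positions are 1-based) and locate("hello", "zz") returns 0 because the substring is absent.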
*/ public class Locate extends EsqlScalarFunction implements OptionalArgument { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Locate", Locate::new); private final Expression str; private final Expression substr; @@ -61,6 +68,28 @@ public Locate( this.start = start; } + private Locate(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readOptionalNamed(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(str); + ((PlanStreamOutput) out).writeExpression(substr); + ((PlanStreamOutput) out).writeOptionalExpression(start); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public DataType dataType() { return DataType.INTEGER; @@ -142,4 +171,16 @@ public ExpressionEvaluator.Factory toEvaluator(Function info() { return NodeInfo.create(this, RLike::new, field(), pattern(), caseInsensitive()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java index e8ad0a83829fe..9c5fee999c332 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java @@ -8,6 +8,9 @@ package org.elasticsearch.xpack.esql.expression.function.scalar.string; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; @@ -21,7 +24,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -33,6 +39,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; public class Repeat extends EsqlScalarFunction implements OptionalArgument { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Repeat", Repeat::new); static final long MAX_REPEATED_LENGTH = MB.toBytes(1); @@ -54,6 +61,22 @@ public Repeat( this.number = number; } + private Repeat(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), ((PlanStreamInput) in).readExpression(), ((PlanStreamInput) in).readExpression()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(str); + ((PlanStreamOutput) out).writeExpression(number); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public DataType dataType() { return 
DataType.KEYWORD; @@ -145,4 +168,12 @@ public ExpressionEvaluator.Factory toEvaluator(Function info() { public ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { return new StartsWithEvaluator.Factory(source(), toEvaluator.apply(str), toEvaluator.apply(prefix)); } + + Expression str() { + return str; + } + + Expression prefix() { + return prefix; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Substring.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Substring.java index 94b9f06b63b5d..cb8aa1c8e2a44 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Substring.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Substring.java @@ -9,6 +9,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.UnicodeUtil; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -21,7 +24,10 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -33,6 +39,11 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; public class Substring extends EsqlScalarFunction implements OptionalArgument { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "Substring", + Substring::new + ); private final Expression str, start, length; @@ -69,6 +80,28 @@ public Substring( this.length = length; } + private Substring(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readExpression(), + ((PlanStreamInput) in).readOptionalNamed(Expression.class) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(str); + ((PlanStreamOutput) out).writeExpression(start); + ((PlanStreamOutput) out).writeOptionalExpression(length); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public DataType dataType() { return DataType.KEYWORD; @@ -157,4 +190,16 @@ public ExpressionEvaluator.Factory toEvaluator(Function info() { return NodeInfo.create(this, WildcardLike::new, field(), pattern()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java index 17fca1e1cff88..a2024f9e9e7e4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java +++ 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.InProcessor; @@ -16,15 +19,22 @@ import org.elasticsearch.xpack.esql.expression.EsqlTypeResolutions; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; import org.elasticsearch.xpack.esql.type.EsqlDataTypes; +import java.io.IOException; import java.util.List; import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.util.StringUtils.ordinal; +import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanReader.readerFromPlanReader; +import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanWriter.writerFromPlanWriter; public class In extends org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.In { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "In", In::new); + @FunctionInfo( returnType = "boolean", description = "The `IN` operator allows testing whether a field or expression equals an element in a list of literals, " @@ -35,6 +45,26 @@ public In(Source source, Expression value, List list) { super(source, value, list); } + private In(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + ((PlanStreamInput) in).readExpression(), + in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + ((PlanStreamOutput) out).writeExpression(value()); + out.writeCollection(list(), writerFromPlanWriter(PlanStreamOutput::writeExpression)); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, In::new, value(), list()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java index 0a04cb5345ee8..fc23e0494732b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java @@ -9,7 +9,6 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; @@ -28,17 +27,10 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import 
org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Order; -import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.fulltext.FullTextPredicate; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RegexMatch; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.plan.logical.Filter; import org.elasticsearch.xpack.esql.core.plan.logical.Limit; @@ -48,53 +40,12 @@ import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; -import org.elasticsearch.xpack.esql.expression.function.aggregate.CountDistinct; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Median; -import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; -import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; -import org.elasticsearch.xpack.esql.expression.function.aggregate.TopList; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Values; -import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; -import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction; -import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; -import org.elasticsearch.xpack.esql.expression.function.scalar.ip.IpPrefix; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Atan2; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.E; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Log; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pi; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Tau; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction; import 
org.elasticsearch.xpack.esql.expression.function.scalar.spatial.BinarySpatialFunction; -import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialContains; -import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialDisjoint; -import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialIntersects; -import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialWithin; -import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StDistance; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.EndsWith; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Left; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Locate; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Repeat; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Replace; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Right; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Split; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.StartsWith; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import org.elasticsearch.xpack.esql.plan.logical.Dissect.Parser; @@ -139,7 +90,6 @@ import java.util.Map; import java.util.Set; import java.util.function.BiFunction; -import java.util.function.Function; import static java.util.Map.entry; import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.Entry.of; @@ -171,11 +121,6 @@ public static String name(Class cls) { return cls.getSimpleName(); } - static final Class QL_UNARY_SCLR_CLS = - org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction.class; - - static final Class ESQL_UNARY_SCLR_CLS = UnaryScalarFunction.class; - /** * List of named type entries that link concrete names to stream reader and writer implementations. * Entries have the form: category, name, serializer method, deserializer method. 
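Reviewer note: the hunk below is where PlanNamedTypes stops hand-maintaining per-class of(...) declarations and instead splices each family's getNamedWriteables() list into the combined entry list. A minimal, self-contained sketch of that assembly step (the Entry record and the family lists are hypothetical stand-ins for NamedWriteableRegistry.Entry and the real getNamedWriteables() calls), including the uniqueness check that keeps name-based dispatch unambiguous:

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class RegistryAssemblySketch {
    // Stand-in for NamedWriteableRegistry.Entry: a category class plus a unique name.
    record Entry(Class<?> category, String name) {}

    public static void main(String[] args) {
        // Each family contributes its own list, as BinarySpatialFunction.getNamedWriteables() does above.
        List<Entry> spatial = List.of(new Entry(Object.class, "SpatialContains"), new Entry(Object.class, "SpatialWithin"));
        List<Entry> strings = List.of(new Entry(Object.class, "EndsWith"), new Entry(Object.class, "Substring"));

        List<Entry> all = new ArrayList<>();
        Set<String> seen = new HashSet<>();
        for (List<Entry> family : List.of(spatial, strings)) {
            for (Entry e : family) {
                // Names must be unique per category, or the reader lookup would be ambiguous.
                if (seen.add(e.category().getName() + "#" + e.name()) == false) {
                    throw new IllegalArgumentException("duplicate serialization name: " + e.name());
                }
                all.add(e);
            }
        }
        System.out.println(all.size() + " entries registered"); // prints "4 entries registered"
    }
}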
@@ -211,7 +156,7 @@ public static List namedTypeEntries() { of(PhysicalPlan.class, ShowExec.class, PlanNamedTypes::writeShowExec, PlanNamedTypes::readShowExec), of(PhysicalPlan.class, TopNExec.class, PlanNamedTypes::writeTopNExec, PlanNamedTypes::readTopNExec), // Logical Plan Nodes - a subset of plans that end up being actually serialized - of(LogicalPlan.class, Aggregate.class, PlanNamedTypes::writeAggregate, PlanNamedTypes::readAggregate), + of(LogicalPlan.class, Aggregate.class, Aggregate::writeAggregate, Aggregate::new), of(LogicalPlan.class, Dissect.class, PlanNamedTypes::writeDissect, PlanNamedTypes::readDissect), of(LogicalPlan.class, EsRelation.class, PlanNamedTypes::writeEsRelation, PlanNamedTypes::readEsRelation), of(LogicalPlan.class, Eval.class, PlanNamedTypes::writeEval, PlanNamedTypes::readEval), @@ -226,61 +171,15 @@ public static List namedTypeEntries() { of(LogicalPlan.class, MvExpand.class, PlanNamedTypes::writeMvExpand, PlanNamedTypes::readMvExpand), of(LogicalPlan.class, OrderBy.class, PlanNamedTypes::writeOrderBy, PlanNamedTypes::readOrderBy), of(LogicalPlan.class, Project.class, PlanNamedTypes::writeProject, PlanNamedTypes::readProject), - of(LogicalPlan.class, TopN.class, PlanNamedTypes::writeTopN, PlanNamedTypes::readTopN), - // InComparison - of(ScalarFunction.class, In.class, PlanNamedTypes::writeInComparison, PlanNamedTypes::readInComparison), - // RegexMatch - of(RegexMatch.class, WildcardLike.class, PlanNamedTypes::writeWildcardLike, PlanNamedTypes::readWildcardLike), - of(RegexMatch.class, RLike.class, PlanNamedTypes::writeRLike, PlanNamedTypes::readRLike), - // BinaryLogic - of(BinaryLogic.class, And.class, PlanNamedTypes::writeBinaryLogic, PlanNamedTypes::readBinaryLogic), - of(BinaryLogic.class, Or.class, PlanNamedTypes::writeBinaryLogic, PlanNamedTypes::readBinaryLogic), - // ScalarFunction - of(ScalarFunction.class, Atan2.class, PlanNamedTypes::writeAtan2, PlanNamedTypes::readAtan2), - of(ScalarFunction.class, CIDRMatch.class, PlanNamedTypes::writeCIDRMatch, PlanNamedTypes::readCIDRMatch), - of(ScalarFunction.class, E.class, PlanNamedTypes::writeNoArgScalar, PlanNamedTypes::readNoArgScalar), - of(ScalarFunction.class, IpPrefix.class, (out, prefix) -> prefix.writeTo(out), IpPrefix::readFrom), - of(ScalarFunction.class, Log.class, PlanNamedTypes::writeLog, PlanNamedTypes::readLog), - of(ScalarFunction.class, Pi.class, PlanNamedTypes::writeNoArgScalar, PlanNamedTypes::readNoArgScalar), - of(ScalarFunction.class, Round.class, PlanNamedTypes::writeRound, PlanNamedTypes::readRound), - of(ScalarFunction.class, Pow.class, PlanNamedTypes::writePow, PlanNamedTypes::readPow), - of(ScalarFunction.class, StartsWith.class, PlanNamedTypes::writeStartsWith, PlanNamedTypes::readStartsWith), - of(ScalarFunction.class, EndsWith.class, PlanNamedTypes::writeEndsWith, PlanNamedTypes::readEndsWith), - of(ScalarFunction.class, SpatialIntersects.class, PlanNamedTypes::writeBinarySpatialFunction, PlanNamedTypes::readIntersects), - of(ScalarFunction.class, SpatialDisjoint.class, PlanNamedTypes::writeBinarySpatialFunction, PlanNamedTypes::readDisjoint), - of(ScalarFunction.class, SpatialContains.class, PlanNamedTypes::writeBinarySpatialFunction, PlanNamedTypes::readContains), - of(ScalarFunction.class, SpatialWithin.class, PlanNamedTypes::writeBinarySpatialFunction, PlanNamedTypes::readWithin), - of(ScalarFunction.class, StDistance.class, PlanNamedTypes::writeBinarySpatialFunction, PlanNamedTypes::readDistance), - of(ScalarFunction.class, Substring.class, 
PlanNamedTypes::writeSubstring, PlanNamedTypes::readSubstring), - of(ScalarFunction.class, Locate.class, PlanNamedTypes::writeLocate, PlanNamedTypes::readLocate), - of(ScalarFunction.class, Left.class, PlanNamedTypes::writeLeft, PlanNamedTypes::readLeft), - of(ScalarFunction.class, Repeat.class, PlanNamedTypes::writeRepeat, PlanNamedTypes::readRepeat), - of(ScalarFunction.class, Right.class, PlanNamedTypes::writeRight, PlanNamedTypes::readRight), - of(ScalarFunction.class, Split.class, PlanNamedTypes::writeSplit, PlanNamedTypes::readSplit), - of(ScalarFunction.class, Tau.class, PlanNamedTypes::writeNoArgScalar, PlanNamedTypes::readNoArgScalar), - of(ScalarFunction.class, Replace.class, PlanNamedTypes::writeReplace, PlanNamedTypes::readReplace), - // GroupingFunctions - of(GroupingFunction.class, Bucket.class, PlanNamedTypes::writeBucket, PlanNamedTypes::readBucket), - // AggregateFunctions - of(AggregateFunction.class, Avg.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, Count.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, CountDistinct.class, PlanNamedTypes::writeCountDistinct, PlanNamedTypes::readCountDistinct), - of(AggregateFunction.class, Min.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, Max.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, Median.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, MedianAbsoluteDeviation.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, Percentile.class, PlanNamedTypes::writePercentile, PlanNamedTypes::readPercentile), - of(AggregateFunction.class, SpatialCentroid.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, Sum.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, TopList.class, (out, prefix) -> prefix.writeTo(out), TopList::readFrom), - of(AggregateFunction.class, Values.class, PlanNamedTypes::writeAggFunction, PlanNamedTypes::readAggFunction), - of(AggregateFunction.class, Rate.class, Rate::writeRate, Rate::readRate) + of(LogicalPlan.class, TopN.class, PlanNamedTypes::writeTopN, PlanNamedTypes::readTopN) ); List entries = new ArrayList<>(declared); // From NamedWriteables for (List ee : List.of( AbstractMultivalueFunction.getNamedWriteables(), + AggregateFunction.getNamedWriteables(), + BinarySpatialFunction.getNamedWriteables(), EsqlArithmeticOperation.getNamedWriteables(), EsqlBinaryComparison.getNamedWriteables(), EsqlScalarFunction.getNamedWriteables(), @@ -660,23 +559,6 @@ static void writeTopNExec(PlanStreamOutput out, TopNExec topNExec) throws IOExce out.writeOptionalVInt(topNExec.estimatedRowSize()); } - // -- Logical plan nodes - static Aggregate readAggregate(PlanStreamInput in) throws IOException { - return new Aggregate( - Source.readFrom(in), - in.readLogicalPlanNode(), - in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)), - in.readNamedWriteableCollectionAsList(NamedExpression.class) - ); - } - - static void writeAggregate(PlanStreamOutput out, Aggregate aggregate) throws IOException { - Source.EMPTY.writeTo(out); - out.writeLogicalPlanNode(aggregate.child()); - out.writeCollection(aggregate.groupings(), writerFromPlanWriter(PlanStreamOutput::writeExpression)); - 
out.writeNamedWriteableCollection(aggregate.aggregates()); - } - static Dissect readDissect(PlanStreamInput in) throws IOException { return new Dissect( Source.readFrom(in), @@ -920,82 +802,6 @@ static void writeTopN(PlanStreamOutput out, TopN topN) throws IOException { out.writeExpression(topN.limit()); } - // -- InComparison - - static In readInComparison(PlanStreamInput in) throws IOException { - return new In( - Source.readFrom(in), - in.readExpression(), - in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)) - ); - } - - static void writeInComparison(PlanStreamOutput out, In in) throws IOException { - in.source().writeTo(out); - out.writeExpression(in.value()); - out.writeCollection(in.list(), writerFromPlanWriter(PlanStreamOutput::writeExpression)); - } - - // -- RegexMatch - - static WildcardLike readWildcardLike(PlanStreamInput in, String name) throws IOException { - return new WildcardLike(Source.readFrom(in), in.readExpression(), new WildcardPattern(in.readString())); - } - - static void writeWildcardLike(PlanStreamOutput out, WildcardLike like) throws IOException { - like.source().writeTo(out); - out.writeExpression(like.field()); - out.writeString(like.pattern().pattern()); - } - - static RLike readRLike(PlanStreamInput in, String name) throws IOException { - return new RLike(Source.readFrom(in), in.readExpression(), new RLikePattern(in.readString())); - } - - static void writeRLike(PlanStreamOutput out, RLike like) throws IOException { - like.source().writeTo(out); - out.writeExpression(like.field()); - out.writeString(like.pattern().asJavaRegex()); - } - - // -- BinaryLogic - - static final Map> BINARY_LOGIC_CTRS = Map.ofEntries( - entry(name(And.class), And::new), - entry(name(Or.class), Or::new) - ); - - static BinaryLogic readBinaryLogic(PlanStreamInput in, String name) throws IOException { - var source = Source.readFrom(in); - var left = in.readExpression(); - var right = in.readExpression(); - return BINARY_LOGIC_CTRS.get(name).apply(source, left, right); - } - - static void writeBinaryLogic(PlanStreamOutput out, BinaryLogic binaryLogic) throws IOException { - Source.EMPTY.writeTo(out); - out.writeExpression(binaryLogic.left()); - out.writeExpression(binaryLogic.right()); - } - - static final Map> NO_ARG_SCALAR_CTRS = Map.ofEntries( - entry(name(E.class), E::new), - entry(name(Pi.class), Pi::new), - entry(name(Tau.class), Tau::new) - ); - - static ScalarFunction readNoArgScalar(PlanStreamInput in, String name) throws IOException { - var ctr = NO_ARG_SCALAR_CTRS.get(name); - if (ctr == null) { - throw new IOException("Constructor not found:" + name); - } - return ctr.apply(Source.readFrom(in)); - } - - static void writeNoArgScalar(PlanStreamOutput out, ScalarFunction function) throws IOException { - Source.EMPTY.writeTo(out); - } - static final Map< String, BiFunction< @@ -1026,251 +832,6 @@ static void writeQLUnaryScalar( out.writeExpression(function.field()); } - // -- ScalarFunction - - static Atan2 readAtan2(PlanStreamInput in) throws IOException { - return new Atan2(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writeAtan2(PlanStreamOutput out, Atan2 atan2) throws IOException { - atan2.source().writeTo(out); - out.writeExpression(atan2.y()); - out.writeExpression(atan2.x()); - } - - static Bucket readBucket(PlanStreamInput in) throws IOException { - return new Bucket( - Source.readFrom(in), - in.readExpression(), - in.readExpression(), - in.readOptionalNamed(Expression.class), - 
in.readOptionalNamed(Expression.class) - ); - } - - static void writeBucket(PlanStreamOutput out, Bucket bucket) throws IOException { - bucket.source().writeTo(out); - out.writeExpression(bucket.field()); - out.writeExpression(bucket.buckets()); - out.writeOptionalExpression(bucket.from()); - out.writeOptionalExpression(bucket.to()); - } - - static CountDistinct readCountDistinct(PlanStreamInput in) throws IOException { - return new CountDistinct(Source.readFrom(in), in.readExpression(), in.readOptionalNamed(Expression.class)); - } - - static void writeCountDistinct(PlanStreamOutput out, CountDistinct countDistinct) throws IOException { - List fields = countDistinct.children(); - assert fields.size() == 1 || fields.size() == 2; - Source.EMPTY.writeTo(out); - out.writeExpression(fields.get(0)); - out.writeOptionalWriteable(fields.size() == 2 ? o -> out.writeExpression(fields.get(1)) : null); - } - - static SpatialIntersects readIntersects(PlanStreamInput in) throws IOException { - return new SpatialIntersects(Source.EMPTY, in.readExpression(), in.readExpression()); - } - - static SpatialDisjoint readDisjoint(PlanStreamInput in) throws IOException { - return new SpatialDisjoint(Source.EMPTY, in.readExpression(), in.readExpression()); - } - - static SpatialContains readContains(PlanStreamInput in) throws IOException { - return new SpatialContains(Source.EMPTY, in.readExpression(), in.readExpression()); - } - - static SpatialWithin readWithin(PlanStreamInput in) throws IOException { - return new SpatialWithin(Source.EMPTY, in.readExpression(), in.readExpression()); - } - - static StDistance readDistance(PlanStreamInput in) throws IOException { - return new StDistance(Source.EMPTY, in.readExpression(), in.readExpression()); - } - - static void writeBinarySpatialFunction(PlanStreamOutput out, BinarySpatialFunction binarySpatialFunction) throws IOException { - out.writeExpression(binarySpatialFunction.left()); - out.writeExpression(binarySpatialFunction.right()); - } - - static Round readRound(PlanStreamInput in) throws IOException { - return new Round(Source.readFrom(in), in.readExpression(), in.readOptionalNamed(Expression.class)); - } - - static void writeRound(PlanStreamOutput out, Round round) throws IOException { - round.source().writeTo(out); - out.writeExpression(round.field()); - out.writeOptionalExpression(round.decimals()); - } - - static Pow readPow(PlanStreamInput in) throws IOException { - return new Pow(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writePow(PlanStreamOutput out, Pow pow) throws IOException { - pow.source().writeTo(out); - out.writeExpression(pow.base()); - out.writeExpression(pow.exponent()); - } - - static Percentile readPercentile(PlanStreamInput in) throws IOException { - return new Percentile(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writePercentile(PlanStreamOutput out, Percentile percentile) throws IOException { - List fields = percentile.children(); - assert fields.size() == 2 : "percentile() aggregation must have two arguments"; - Source.EMPTY.writeTo(out); - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - } - - static StartsWith readStartsWith(PlanStreamInput in) throws IOException { - return new StartsWith(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writeStartsWith(PlanStreamOutput out, StartsWith startsWith) throws IOException { - startsWith.source().writeTo(out); - List fields = startsWith.children(); - 
assert fields.size() == 2; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - } - - static EndsWith readEndsWith(PlanStreamInput in) throws IOException { - return new EndsWith(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writeEndsWith(PlanStreamOutput out, EndsWith endsWith) throws IOException { - List fields = endsWith.children(); - assert fields.size() == 2; - Source.EMPTY.writeTo(out); - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - } - - static Substring readSubstring(PlanStreamInput in) throws IOException { - return new Substring(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readOptionalNamed(Expression.class)); - } - - static void writeSubstring(PlanStreamOutput out, Substring substring) throws IOException { - substring.source().writeTo(out); - List fields = substring.children(); - assert fields.size() == 2 || fields.size() == 3; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - out.writeOptionalWriteable(fields.size() == 3 ? o -> out.writeExpression(fields.get(2)) : null); - } - - static Locate readLocate(PlanStreamInput in) throws IOException { - return new Locate(Source.readFrom(in), in.readExpression(), in.readExpression(), in.readOptionalNamed(Expression.class)); - } - - static void writeLocate(PlanStreamOutput out, Locate locate) throws IOException { - locate.source().writeTo(out); - List fields = locate.children(); - assert fields.size() == 2 || fields.size() == 3; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - out.writeOptionalWriteable(fields.size() == 3 ? o -> out.writeExpression(fields.get(2)) : null); - } - - static Replace readReplace(PlanStreamInput in) throws IOException { - return new Replace(Source.EMPTY, in.readExpression(), in.readExpression(), in.readExpression()); - } - - static void writeReplace(PlanStreamOutput out, Replace replace) throws IOException { - List fields = replace.children(); - assert fields.size() == 3; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - out.writeExpression(fields.get(2)); - } - - static Left readLeft(PlanStreamInput in) throws IOException { - return new Left(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writeLeft(PlanStreamOutput out, Left left) throws IOException { - left.source().writeTo(out); - List fields = left.children(); - assert fields.size() == 2; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - } - - static Repeat readRepeat(PlanStreamInput in) throws IOException { - return new Repeat(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writeRepeat(PlanStreamOutput out, Repeat repeat) throws IOException { - repeat.source().writeTo(out); - List fields = repeat.children(); - assert fields.size() == 2; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - } - - static Right readRight(PlanStreamInput in) throws IOException { - return new Right(Source.readFrom(in), in.readExpression(), in.readExpression()); - } - - static void writeRight(PlanStreamOutput out, Right right) throws IOException { - right.source().writeTo(out); - List fields = right.children(); - assert fields.size() == 2; - out.writeExpression(fields.get(0)); - out.writeExpression(fields.get(1)); - } - - static Split readSplit(PlanStreamInput in) throws IOException { - return new Split(Source.readFrom(in), in.readExpression(), 
in.readExpression()); - } - - static void writeSplit(PlanStreamOutput out, Split split) throws IOException { - split.source().writeTo(out); - out.writeExpression(split.left()); - out.writeExpression(split.right()); - } - - static CIDRMatch readCIDRMatch(PlanStreamInput in) throws IOException { - return new CIDRMatch( - Source.readFrom(in), - in.readExpression(), - in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readExpression)) - ); - } - - static void writeCIDRMatch(PlanStreamOutput out, CIDRMatch cidrMatch) throws IOException { - cidrMatch.source().writeTo(out); - List children = cidrMatch.children(); - assert children.size() > 1; - out.writeExpression(children.get(0)); - out.writeCollection(children.subList(1, children.size()), writerFromPlanWriter(PlanStreamOutput::writeExpression)); - } - - // -- Aggregations - static final Map> AGG_CTRS = Map.ofEntries( - entry(name(Avg.class), Avg::new), - entry(name(Count.class), Count::new), - entry(name(Sum.class), Sum::new), - entry(name(Min.class), Min::new), - entry(name(Max.class), Max::new), - entry(name(Median.class), Median::new), - entry(name(MedianAbsoluteDeviation.class), MedianAbsoluteDeviation::new), - entry(name(SpatialCentroid.class), SpatialCentroid::new), - entry(name(Values.class), Values::new) - ); - - static AggregateFunction readAggFunction(PlanStreamInput in, String name) throws IOException { - return AGG_CTRS.get(name).apply(Source.readFrom(in), in.readExpression()); - } - - static void writeAggFunction(PlanStreamOutput out, AggregateFunction aggregateFunction) throws IOException { - Source.EMPTY.writeTo(out); - out.writeExpression(aggregateFunction.field()); - } - // -- ancillary supporting classes of plan nodes, etc static EsQueryExec.FieldSort readFieldSort(PlanStreamInput in) throws IOException { @@ -1312,16 +873,4 @@ static void writeDissectParser(PlanStreamOutput out, Parser dissectParser) throw out.writeString(dissectParser.pattern()); out.writeString(dissectParser.appendSeparator()); } - - static Log readLog(PlanStreamInput in) throws IOException { - return new Log(Source.readFrom(in), in.readExpression(), in.readOptionalNamed(Expression.class)); - } - - static void writeLog(PlanStreamOutput out, Log log) throws IOException { - log.source().writeTo(out); - List fields = log.children(); - assert fields.size() == 1 || fields.size() == 2; - out.writeExpression(fields.get(0)); - out.writeOptionalWriteable(fields.size() == 2 ? 
o -> out.writeExpression(fields.get(1)) : null); - } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java index 5eb024d410992..2a70ccdd3705c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java @@ -7,7 +7,13 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.core.Tuple; +import org.elasticsearch.geometry.Circle; +import org.elasticsearch.geometry.Geometry; +import org.elasticsearch.geometry.Point; +import org.elasticsearch.geometry.utils.WellKnownBinary; +import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.esql.VerificationException; @@ -18,6 +24,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Order; @@ -34,6 +41,7 @@ import org.elasticsearch.xpack.esql.core.querydsl.query.Query; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRuleExecutor; import org.elasticsearch.xpack.esql.core.rule.Rule; +import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.Queries; import org.elasticsearch.xpack.esql.core.util.Queries.Clause; @@ -41,8 +49,12 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialAggregateFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialIntersects; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesUtils; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StDistance; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InsensitiveBinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; @@ -64,6 +76,7 @@ import org.elasticsearch.xpack.esql.planner.EsqlTranslatorHandler; import org.elasticsearch.xpack.esql.stats.SearchStats; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; @@ -103,13 +116,14 @@ PhysicalPlan verify(PhysicalPlan plan) { protected List> rules(boolean optimizeForEsSource) { List> esSourceRules = new ArrayList<>(4); - esSourceRules.add(new ReplaceAttributeSourceWithDocId()); + 
esSourceRules.add(new ReplaceSourceAttributes()); if (optimizeForEsSource) { esSourceRules.add(new PushTopNToSource()); esSourceRules.add(new PushLimitToSource()); esSourceRules.add(new PushFiltersToSource()); esSourceRules.add(new PushStatsToSource()); + esSourceRules.add(new EnableSpatialDistancePushdown()); } // execute the rules multiple times to improve the chances of things being pushed down @@ -126,15 +140,32 @@ protected List> batches() { return rules(true); } - private static class ReplaceAttributeSourceWithDocId extends OptimizerRule { + private static class ReplaceSourceAttributes extends OptimizerRule { - ReplaceAttributeSourceWithDocId() { + ReplaceSourceAttributes() { super(UP); } @Override protected PhysicalPlan rule(EsSourceExec plan) { - return new EsQueryExec(plan.source(), plan.index(), plan.indexMode(), plan.query()); + var docId = new FieldAttribute(plan.source(), EsQueryExec.DOC_ID_FIELD.getName(), EsQueryExec.DOC_ID_FIELD); + if (plan.indexMode() == IndexMode.TIME_SERIES) { + Attribute tsid = null, timestamp = null; + for (Attribute attr : plan.output()) { + String name = attr.name(); + if (name.equals(MetadataAttribute.TSID_FIELD)) { + tsid = attr; + } else if (name.equals(MetadataAttribute.TIMESTAMP_FIELD)) { + timestamp = attr; + } + } + if (tsid == null || timestamp == null) { + throw new IllegalStateException("_tsid or @timestamp are missing from the time-series source"); + } + return new EsQueryExec(plan.source(), plan.index(), plan.indexMode(), List.of(docId, tsid, timestamp), plan.query()); + } else { + return new EsQueryExec(plan.source(), plan.index(), plan.indexMode(), List.of(docId), plan.query()); + } } } @@ -551,4 +582,114 @@ private boolean allowedForDocValues(FieldAttribute fieldAttribute, AggregateExec return spatialRelatesAttributes.size() < 2; } } + + /** + * When a spatial distance predicate can be pushed down to lucene, this is done by capturing the distance within the same function. + * In principle this is like re-writing the predicate: + *
    WHERE ST_DISTANCE(field, TO_GEOPOINT("POINT(0 0)")) <= 10000
    + * as: + *
    WHERE ST_INTERSECTS(field, TO_GEOSHAPE("CIRCLE(0,0,10000)"))
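+ * Roughly, the rewrite builds the replacement geometry as follows (a simplified sketch of the
+ * rule below, using the example's centre and radius; Math.nextDown only applies to a strict '<'):
+ *
+ *     double distance = 10000.0;                        // folded from the comparison's literal side
+ *     Circle circle = new Circle(0.0, 0.0, distance);   // centre taken from the POINT(0 0) literal
+ *     BytesRef wkb = new BytesRef(WellKnownBinary.toWKB(circle, ByteOrder.LITTLE_ENDIAN));
+ *     Literal shape = new Literal(source, wkb, DataType.GEO_SHAPE);
+ *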
+ */
+ public static class EnableSpatialDistancePushdown extends PhysicalOptimizerRules.ParameterizedOptimizerRule<
+     FilterExec,
+     LocalPhysicalOptimizerContext> {
+
+     @Override
+     protected PhysicalPlan rule(FilterExec filterExec, LocalPhysicalOptimizerContext ctx) {
+         PhysicalPlan plan = filterExec;
+         if (filterExec.child() instanceof EsQueryExec) {
+             if (filterExec.condition() instanceof EsqlBinaryComparison comparison) {
+                 ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType());
+                 if (comparison.left() instanceof StDistance dist && comparison.right().foldable()) {
+                     plan = rewriteComparison(filterExec, dist, comparison.right(), comparisonType);
+                 } else if (comparison.right() instanceof StDistance dist && comparison.left().foldable()) {
+                     // here the foldable literal is on the left, so the comparison direction must be inverted
+                     plan = rewriteComparison(filterExec, dist, comparison.left(), ComparisonType.invert(comparisonType));
+                 }
+             }
+         }
+
+         return plan;
+     }
+
+     private FilterExec rewriteComparison(FilterExec filterExec, StDistance dist, Expression literal, ComparisonType comparisonType) {
+         // We currently only support spatial distance within a minimum range
+         if (comparisonType.lt) {
+             Object value = literal.fold();
+             if (value instanceof Number number) {
+                 if (dist.right().foldable()) {
+                     return rewriteDistanceFilter(filterExec, dist.source(), dist.left(), dist.right(), number, comparisonType.eq);
+                 } else if (dist.left().foldable()) {
+                     return rewriteDistanceFilter(filterExec, dist.source(), dist.right(), dist.left(), number, comparisonType.eq);
+                 }
+             }
+         }
+         return filterExec;
+     }
+
+     private FilterExec rewriteDistanceFilter(
+         FilterExec filterExec,
+         Source source,
+         Expression spatialExpression,
+         Expression literalExpression,
+         Number number,
+         boolean inclusive
+     ) {
+         Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(literalExpression);
+         if (geometry instanceof Point point) {
+             double distance = number.doubleValue();
+             if (inclusive == false) {
+                 distance = Math.nextDown(distance);
+             }
+             var circle = new Circle(point.getX(), point.getY(), distance);
+             var wkb = WellKnownBinary.toWKB(circle, ByteOrder.LITTLE_ENDIAN);
+             var cExp = new Literal(literalExpression.source(), new BytesRef(wkb), DataType.GEO_SHAPE);
+             return new FilterExec(filterExec.source(), filterExec.child(), new SpatialIntersects(source, spatialExpression, cExp));
+         }
+         return filterExec;
+     }
+
+     /**
+      * This enum captures the key differences between the various inequalities as perceived from the spatial distance function.
+      * In particular, we need to know which direction the inequality points, with lt=true meaning the left is expected to be
+      * smaller than the right, and eq=true meaning we expect equality as well. We currently don't support Equals and NotEquals,
+      * so the third field disables those.
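+      * For example, {@code 10000 >= ST_DISTANCE(field, point)} expresses the same predicate as
+      * {@code ST_DISTANCE(field, point) <= 10000}, which is why {@link #invert} is applied when the
+      * distance function appears on the right-hand side of the comparison.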
+ */ + enum ComparisonType { + LTE(true, true, true), + LT(true, false, true), + GTE(false, true, true), + GT(false, false, true), + UNSUPPORTED(false, false, false); + + private final boolean lt; + private final boolean eq; + private final boolean supported; + + ComparisonType(boolean lt, boolean eq, boolean supported) { + this.lt = lt; + this.eq = eq; + this.supported = supported; + } + + static ComparisonType from(EsqlBinaryComparison.BinaryComparisonOperation op) { + return switch (op) { + case LT -> LT; + case LTE -> LTE; + case GT -> GT; + case GTE -> GTE; + default -> UNSUPPORTED; + }; + } + + static ComparisonType invert(ComparisonType comparisonType) { + return switch (comparisonType) { + case LT -> GT; + case LTE -> GTE; + case GT -> LT; + case GTE -> LTE; + default -> UNSUPPORTED; + }; + } + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index aaf9f8e63d795..ca4b5d17deed3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.SplitInWithFoldableValue; import org.elasticsearch.xpack.esql.optimizer.rules.SubstituteSpatialSurrogates; import org.elasticsearch.xpack.esql.optimizer.rules.SubstituteSurrogates; +import org.elasticsearch.xpack.esql.optimizer.rules.TranslateMetricsAggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; @@ -115,6 +116,9 @@ protected static Batch substitutions() { new ReplaceStatsAggExpressionWithEval(), // lastly replace surrogate functions new SubstituteSurrogates(), + // translate metric aggregates after surrogate substitution and replace nested expressions with eval (again) + new TranslateMetricsAggregate(), + new ReplaceStatsNestedExpressionWithEval(), new ReplaceRegexMatch(), new ReplaceTrivialTypeConversions(), new ReplaceAliasingEvalWithProject(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java index 940c08ffb97f1..2070139519ea0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java @@ -53,7 +53,7 @@ protected LogicalPlan rule(UnaryPlan plan) { // project can be fully removed if (newAggs != null) { var newGroups = replacePrunedAliasesUsedInGroupBy(a.groupings(), aggs, newAggs); - plan = new Aggregate(a.source(), a.child(), newGroups, newAggs); + plan = new Aggregate(a.source(), a.child(), a.aggregateType(), newGroups, newAggs); } } return plan; @@ -75,6 +75,7 @@ protected LogicalPlan rule(UnaryPlan plan) { plan = new Aggregate( a.source(), p.child(), + a.aggregateType(), combineUpperGroupingsAndLowerProjections(groupingAttrs, p.projections()), combineProjections(a.aggregates(), p.projections()) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java index cb0224c9c834d..9403e3996ec49 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java @@ -69,10 +69,22 @@ public LogicalPlan apply(LogicalPlan plan) { } else { // Aggs cannot produce pages with 0 columns, so retain one grouping. remaining = List.of(Expressions.attribute(aggregate.groupings().get(0))); - p = new Aggregate(aggregate.source(), aggregate.child(), aggregate.groupings(), remaining); + p = new Aggregate( + aggregate.source(), + aggregate.child(), + aggregate.aggregateType(), + aggregate.groupings(), + remaining + ); } } else { - p = new Aggregate(aggregate.source(), aggregate.child(), aggregate.groupings(), remaining); + p = new Aggregate( + aggregate.source(), + aggregate.child(), + aggregate.aggregateType(), + aggregate.groupings(), + remaining + ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java index cf04637e456a5..cbcde663f8b14 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java @@ -50,7 +50,7 @@ private static Aggregate removeAggDuplicates(Aggregate agg) { aggregates = removeDuplicateNames(aggregates); // replace EsqlAggregate with Aggregate - return new Aggregate(agg.source(), agg.child(), groupings, aggregates); + return new Aggregate(agg.source(), agg.child(), agg.aggregateType(), groupings, aggregates); } private static List removeDuplicateNames(List list) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java index 9a24926953947..012d6e307df6c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java @@ -138,7 +138,7 @@ protected LogicalPlan rule(Aggregate aggregate) { LogicalPlan plan = aggregate; if (changed.get()) { Source source = aggregate.source(); - plan = new Aggregate(source, aggregate.child(), aggregate.groupings(), newAggs); + plan = new Aggregate(source, aggregate.child(), aggregate.aggregateType(), aggregate.groupings(), newAggs); if (newEvals.size() > 0) { plan = new Eval(source, plan, newEvals); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java index dc7686f57f2f4..99b0c8047f2ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java @@ -73,6 +73,15 @@ protected LogicalPlan rule(Aggregate aggregate) { // if the child is a nested expression Expression child = as.child(); + 
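+ // For illustration (hypothetical query): STATS sum(a + 1) BY b becomes EVAL $e = a + 1 | STATS sum($e) BY b,
+ // so by the time the aggregates run, each aggregate function only ever sees a plain attribute.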
// do not replace nested aggregates + if (child instanceof AggregateFunction af) { + Holder foundNestedAggs = new Holder<>(Boolean.FALSE); + af.children().forEach(e -> e.forEachDown(AggregateFunction.class, unused -> foundNestedAggs.set(Boolean.TRUE))); + if (foundNestedAggs.get()) { + return as; + } + } + // shortcut for common scenario if (child instanceof AggregateFunction af && af.field() instanceof Attribute) { return as; @@ -125,7 +134,7 @@ protected LogicalPlan rule(Aggregate aggregate) { var aggregates = aggsChanged.get() ? newAggs : aggregate.aggregates(); var newEval = new Eval(aggregate.source(), aggregate.child(), evals); - aggregate = new Aggregate(aggregate.source(), newEval, groupings, aggregates); + aggregate = new Aggregate(aggregate.source(), newEval, aggregate.aggregateType(), groupings, aggregates); } return aggregate; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java index 39617b443a286..fa4049b0e5a3a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; +import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Project; @@ -70,6 +71,10 @@ protected LogicalPlan rule(Aggregate aggregate) { if (s instanceof AggregateFunction == false) { // 1. collect all aggregate functions from the expression var surrogateWithRefs = s.transformUp(AggregateFunction.class, af -> { + // TODO: more generic than this? + if (af instanceof Rate) { + return af; + } // 2. check if they are already use otherwise add them to the Aggregate with some made-up aliases // 3. replace them inside the expression using the given alias var attr = aggFuncToAttr.get(af); @@ -103,7 +108,7 @@ protected LogicalPlan rule(Aggregate aggregate) { if (changed) { var source = aggregate.source(); if (newAggs.isEmpty() == false) { - plan = new Aggregate(source, aggregate.child(), aggregate.groupings(), newAggs); + plan = new Aggregate(source, aggregate.child(), aggregate.aggregateType(), aggregate.groupings(), newAggs); } else { // All aggs actually have been surrogates for (foldable) expressions, e.g. // \_Aggregate[[],[AVG([1, 2][INTEGER]) AS s]] diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java new file mode 100644 index 0000000000000..0a62dccf80c1f --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java @@ -0,0 +1,211 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
*/
+
+package org.elasticsearch.xpack.esql.optimizer.rules;
+
+import org.elasticsearch.index.IndexMode;
+import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
+import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
+import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules;
+import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.core.util.Holder;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
+import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
+import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+
+/**
+ * Rate aggregation is special because it must be computed per time series, regardless of the grouping keys.
+ * The keys must be `_tsid` or a pair of `_tsid` and `time_bucket`. To support user-defined grouping keys,
+ * we first execute the rate aggregation using the time-series keys, then perform another aggregation with
+ * the resulting rate using the user-specified keys.
+ *

+ * This class translates the aggregates in METRICS commands to standard aggregates.
+ * This approach avoids having to introduce new plans and operators specifically for metrics aggregations.
+ *

    + * Examples: + *

    + * METRICS k8s max(rate(request))
    + *
    + * becomes
    + *
    + * METRICS k8s
    + * | STATS rate(request) BY _tsid
    + * | STATS max(`rate(request)`)
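+ * (the second STATS has no BY clause, so the per-series rates collapse into a single overall max)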
    + *
    + * METRICS k8s max(rate(request)) BY host
    + *
    + * becomes
    + *
    + * METRICS k8s
    + * | STATS rate(request), VALUES(host) BY _tsid
    + * | STATS max(`rate(request)`) BY host=`VALUES(host)`
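+ * (VALUES(host) carries each series' host value through the first pass; assuming host is a
+ * dimension of the time series, it is constant per _tsid)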
    + *
    + * METRICS k8s avg(rate(request)) BY host
    + *
    + * becomes
    + *
    + * METRICS k8s
    + * | STATS rate(request), VALUES(host) BY _tsid
+ * | STATS sum(`rate(request)`), count(`rate(request)`) BY host=`VALUES(host)`
    + * | EVAL `avg(rate(request))` = `sum(rate(request))` / `count(rate(request))`
    + * | KEEP `avg(rate(request))`, host
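+ * (avg is never computed directly: SubstituteSurrogates, which runs just before this rule,
+ * has already rewritten avg(rate(request)) as the sum/count division shown here)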
    + *
    + * METRICS k8s avg(rate(request)) BY host, bucket(@timestamp, 1minute)
    + *
    + * becomes
    + *
    + * METRICS k8s
+ * | EVAL `bucket(@timestamp, 1minute)`=datetrunc(@timestamp, 1minute)
+ * | STATS rate(request), VALUES(host) BY _tsid, `bucket(@timestamp, 1minute)`
+ * | STATS sum(`rate(request)`), count(`rate(request)`) BY host=`VALUES(host)`, `bucket(@timestamp, 1minute)`
    + * | EVAL `avg(rate(request))` = `sum(rate(request))` / `count(rate(request))`
    + * | KEEP `avg(rate(request))`, host, `bucket(@timestamp, 1minute)`
    + * 
+ * Mixing rate and non-rate aggregates will be supported later.
+ */
+public final class TranslateMetricsAggregate extends OptimizerRules.OptimizerRule<Aggregate> {
+
+    public TranslateMetricsAggregate() {
+        super(OptimizerRules.TransformDirection.UP);
+    }
+
+    @Override
+    protected LogicalPlan rule(Aggregate aggregate) {
+        if (aggregate.aggregateType() == Aggregate.AggregateType.METRICS) {
+            return translate(aggregate);
+        } else {
+            return aggregate;
+        }
+    }
+
+    LogicalPlan translate(Aggregate metrics) {
+        Map<Rate, Alias> rateAggs = new HashMap<>(); // TODO
+        List<NamedExpression> nonRateAggs = new ArrayList<>();
+        List<NamedExpression> outerRateAggs = new ArrayList<>();
+        for (NamedExpression agg : metrics.aggregates()) {
+            if (agg instanceof Alias alias) {
+                // METRICS af(rate(counter)) becomes STATS $rate_1=rate(counter) | STATS `af(rate(counter))`=af($rate_1)
+                if (alias.child() instanceof AggregateFunction outerRate) {
+                    Holder<Boolean> changed = new Holder<>(Boolean.FALSE);
+                    Expression outerAgg = outerRate.transformDown(Rate.class, rate -> {
+                        changed.set(Boolean.TRUE);
+                        Alias rateAgg = rateAggs.computeIfAbsent(rate, k -> new Alias(rate.source(), agg.name(), rate));
+                        return rateAgg.toAttribute();
+                    });
+                    if (changed.get()) {
+                        outerRateAggs.add(new Alias(alias.source(), alias.name(), null, outerAgg, agg.id()));
+                    }
+                } else {
+                    nonRateAggs.add(agg);
+                }
+            }
+        }
+        if (rateAggs.isEmpty()) {
+            return toStandardAggregate(metrics);
+        }
+        if (nonRateAggs.isEmpty() == false) {
+            // TODO: support this
+            throw new IllegalArgumentException("regular aggregates with rate aggregates are not supported yet");
+        }
+        Holder<Attribute> tsid = new Holder<>();
+        Holder<Attribute> timestamp = new Holder<>();
+        metrics.forEachDown(EsRelation.class, r -> {
+            for (Attribute attr : r.output()) {
+                if (attr.name().equals(MetadataAttribute.TSID_FIELD)) {
+                    tsid.set(attr);
+                }
+                if (attr.name().equals(MetadataAttribute.TIMESTAMP_FIELD)) {
+                    timestamp.set(attr);
+                }
+            }
+        });
+        if (tsid.get() == null || timestamp.get() == null) {
+            throw new IllegalArgumentException("_tsid or @timestamp fields are missing from the metrics source");
+        }
+        // metrics aggregates must be grouped by _tsid (and the time bucket) first, then re-grouped by the user's keys
+        List<Expression> initialGroupings = new ArrayList<>();
+        initialGroupings.add(tsid.get());
+        List<Expression> finalGroupings = new ArrayList<>();
+        Holder<NamedExpression> timeBucketRef = new Holder<>();
+        metrics.child().forEachExpressionUp(NamedExpression.class, e -> {
+            for (Expression child : e.children()) {
+                if (child instanceof Bucket bucket && bucket.field().equals(timestamp.get())) {
+                    if (timeBucketRef.get() != null) {
+                        throw new IllegalArgumentException("expected at most one time bucket");
+                    }
+                    timeBucketRef.set(e);
+                }
+            }
+        });
+        NamedExpression timeBucket = timeBucketRef.get();
+        List<NamedExpression> initialAggs = new ArrayList<>(rateAggs.values());
+        for (Expression group : metrics.groupings()) {
+            if (group instanceof Attribute == false) {
+                throw new EsqlIllegalArgumentException("expected an attribute for grouping; got " + group);
+            }
+            final Attribute g = (Attribute) group;
+            final NamedExpression newFinalGroup;
+            if (timeBucket != null && g.id().equals(timeBucket.id())) {
+                newFinalGroup = timeBucket.toAttribute();
+                initialGroupings.add(newFinalGroup);
+            } else {
+                newFinalGroup = new Alias(g.source(), g.name(), null, new Values(g.source(), g), g.id());
+                initialAggs.add(newFinalGroup);
+            }
+            finalGroupings.add(new Alias(g.source(), g.name(), null, newFinalGroup.toAttribute(), g.id()));
+        }
+        var finalAggregates = Stream.concat(outerRateAggs.stream(), nonRateAggs.stream()).toList();
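+        // Illustrative trace (hypothetical names): for METRICS k8s max(rate(request)) BY host, the inner
+        // aggregate below becomes STATS $rate = rate(request), $host = VALUES(host) BY _tsid and the outer
+        // aggregate becomes STATS max($rate) BY host = $host, as in the examples documented above.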
+ return newAggregate( + newAggregate(metrics.child(), Aggregate.AggregateType.METRICS, initialAggs, initialGroupings), + Aggregate.AggregateType.STANDARD, + finalAggregates, + finalGroupings + ); + } + + private static Aggregate toStandardAggregate(Aggregate metrics) { + final LogicalPlan child = metrics.child().transformDown(EsRelation.class, r -> { + var attributes = new ArrayList<>(new AttributeSet(metrics.inputSet())); + attributes.removeIf(a -> a.name().equals(MetadataAttribute.TSID_FIELD)); + if (attributes.stream().noneMatch(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD)) == false) { + attributes.removeIf(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD)); + } + return new EsRelation(r.source(), r.index(), new ArrayList<>(attributes), IndexMode.STANDARD); + }); + return new Aggregate(metrics.source(), child, Aggregate.AggregateType.STANDARD, metrics.groupings(), metrics.aggregates()); + } + + private static Aggregate newAggregate( + LogicalPlan child, + Aggregate.AggregateType type, + List aggregates, + List groupings + ) { + return new Aggregate( + child.source(), + child, + type, + groupings, + Stream.concat(aggregates.stream(), groupings.stream().map(Expressions::attribute)).toList() + ); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index 3bd34942448e8..353e7738fccc3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern; import org.elasticsearch.xpack.esql.parser.EsqlBaseParser.MetadataOptionContext; +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import org.elasticsearch.xpack.esql.plan.logical.Drop; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -258,7 +259,7 @@ public LogicalPlan visitFromCommand(EsqlBaseParser.FromCommandContext ctx) { @Override public PlanFactory visitStatsCommand(EsqlBaseParser.StatsCommandContext ctx) { final Stats stats = stats(source(ctx), ctx.grouping, ctx.stats); - return input -> new EsqlAggregate(source(ctx), input, stats.groupings, stats.aggregates); + return input -> new EsqlAggregate(source(ctx), input, Aggregate.AggregateType.STANDARD, stats.groupings, stats.aggregates); } private record Stats(List groupings, List aggregates) { @@ -438,12 +439,18 @@ public LogicalPlan visitMetricsCommand(EsqlBaseParser.MetricsCommandContext ctx) } Source source = source(ctx); TableIdentifier table = new TableIdentifier(source, null, visitIndexIdentifiers(ctx.indexIdentifier())); - var unresolvedRelation = new EsqlUnresolvedRelation(source, table, List.of(), IndexMode.TIME_SERIES); + if (ctx.aggregates == null && ctx.grouping == null) { - return unresolvedRelation; + return new EsqlUnresolvedRelation(source, table, List.of(), IndexMode.STANDARD); } final Stats stats = stats(source, ctx.grouping, ctx.aggregates); - return new EsqlAggregate(source, unresolvedRelation, stats.groupings, stats.aggregates); + var relation = new EsqlUnresolvedRelation( + source, + table, + List.of(new MetadataAttribute(source, MetadataAttribute.TSID_FIELD, DataType.KEYWORD, false)), + IndexMode.TIME_SERIES + ); + return new 
EsqlAggregate(source, relation, Aggregate.AggregateType.METRICS, stats.groupings, stats.aggregates); } @Override @@ -453,11 +460,17 @@ public PlanFactory visitLookupCommand(EsqlBaseParser.LookupCommandContext ctx) { } var source = source(ctx); - List matchFields = visitQualifiedNamePatterns(ctx.qualifiedNamePatterns(), ne -> { + @SuppressWarnings("unchecked") + List matchFields = (List) (List) visitQualifiedNamePatterns(ctx.qualifiedNamePatterns(), ne -> { if (ne instanceof UnresolvedNamePattern || ne instanceof UnresolvedStar) { var src = ne.source(); throw new ParsingException(src, "Using wildcards [*] in LOOKUP ON is not allowed yet [{}]", src.text()); } + if ((ne instanceof UnresolvedAttribute) == false) { + throw new IllegalStateException( + "visitQualifiedNamePatterns can only return UnresolvedNamePattern, UnresolvedStar or UnresolvedAttribute" + ); + } }); Literal tableName = new Literal(source, ctx.tableName.getText(), DataType.KEYWORD); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java index 8827e843939b6..5a44c36a81b2d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java @@ -6,6 +6,9 @@ */ package org.elasticsearch.xpack.esql.plan.logical; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -15,29 +18,87 @@ import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import java.io.IOException; import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanReader.readerFromPlanReader; +import static org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanWriter.writerFromPlanWriter; + public class Aggregate extends UnaryPlan { + public enum AggregateType { + STANDARD, + // include metrics aggregates such as rates + METRICS; + + static void writeType(StreamOutput out, AggregateType type) throws IOException { + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ADD_AGGREGATE_TYPE)) { + out.writeString(type.name()); + } else if (type != STANDARD) { + throw new IllegalStateException("cluster is not ready to support aggregate type [" + type + "]"); + } + } + + static AggregateType readType(StreamInput in) throws IOException { + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ADD_AGGREGATE_TYPE)) { + return AggregateType.valueOf(in.readString()); + } else { + return STANDARD; + } + } + } + + private final AggregateType aggregateType; private final List groupings; private final List aggregates; - public Aggregate(Source source, LogicalPlan child, List groupings, List aggregates) { + public Aggregate( + Source source, + LogicalPlan child, + AggregateType aggregateType, + List groupings, + List aggregates + ) { super(source, child); + this.aggregateType = aggregateType; 
this.groupings = groupings; this.aggregates = aggregates; } + public Aggregate(PlanStreamInput in) throws IOException { + this( + Source.readFrom(in), + in.readLogicalPlanNode(), + AggregateType.readType(in), + in.readCollectionAsList(readerFromPlanReader(org.elasticsearch.xpack.esql.io.stream.PlanStreamInput::readExpression)), + in.readNamedWriteableCollectionAsList(NamedExpression.class) + ); + } + + public static void writeAggregate(PlanStreamOutput out, Aggregate aggregate) throws IOException { + Source.EMPTY.writeTo(out); + out.writeLogicalPlanNode(aggregate.child()); + AggregateType.writeType(out, aggregate.aggregateType()); + out.writeCollection(aggregate.groupings(), writerFromPlanWriter(PlanStreamOutput::writeExpression)); + out.writeNamedWriteableCollection(aggregate.aggregates()); + } + @Override protected NodeInfo info() { - return NodeInfo.create(this, Aggregate::new, child(), groupings, aggregates); + return NodeInfo.create(this, Aggregate::new, child(), aggregateType, groupings, aggregates); } @Override public Aggregate replaceChild(LogicalPlan newChild) { - return new Aggregate(source(), newChild, groupings, aggregates); + return new Aggregate(source(), newChild, aggregateType, groupings, aggregates); + } + + public AggregateType aggregateType() { + return aggregateType; } public List groupings() { @@ -60,7 +121,7 @@ public List output() { @Override public int hashCode() { - return Objects.hash(groupings, aggregates, child()); + return Objects.hash(aggregateType, groupings, aggregates, child()); } @Override @@ -74,7 +135,8 @@ public boolean equals(Object obj) { } Aggregate other = (Aggregate) obj; - return Objects.equals(groupings, other.groupings) + return aggregateType == other.aggregateType + && Objects.equals(groupings, other.groupings) && Objects.equals(aggregates, other.aggregates) && Objects.equals(child(), other.child()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java index 6cda14600e840..7f16ecd24dc1a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java @@ -34,8 +34,14 @@ public class EsqlAggregate extends Aggregate { private List lazyOutput; - public EsqlAggregate(Source source, LogicalPlan child, List groupings, List aggregates) { - super(source, child, groupings, aggregates); + public EsqlAggregate( + Source source, + LogicalPlan child, + AggregateType aggregateType, + List groupings, + List aggregates + ) { + super(source, child, aggregateType, groupings, aggregates); } @Override @@ -49,11 +55,11 @@ public List output() { @Override protected NodeInfo info() { - return NodeInfo.create(this, EsqlAggregate::new, child(), groupings(), aggregates()); + return NodeInfo.create(this, EsqlAggregate::new, child(), aggregateType(), groupings(), aggregates()); } @Override public EsqlAggregate replaceChild(LogicalPlan newChild) { - return new EsqlAggregate(source(), newChild, groupings(), aggregates()); + return new EsqlAggregate(source(), newChild, aggregateType(), groupings(), aggregates()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlUnresolvedRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlUnresolvedRelation.java index ffc4818b6ceb5..5b9cabd3807a2 100644 --- 
a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlUnresolvedRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlUnresolvedRelation.java @@ -9,6 +9,9 @@ import org.elasticsearch.index.IndexMode; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeSet; +import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; +import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.plan.TableIdentifier; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -44,6 +47,16 @@ public IndexMode indexMode() { return indexMode; } + @Override + public AttributeSet references() { + AttributeSet refs = super.references(); + if (indexMode == IndexMode.TIME_SERIES) { + refs = new AttributeSet(refs); + refs.add(new UnresolvedAttribute(source(), MetadataAttribute.TIMESTAMP_FIELD)); + } + return refs; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, EsqlUnresolvedRelation::new, table(), metadataFields(), indexMode(), unresolvedMessage()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java index 690e4595f64e5..36bff408e3199 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java @@ -11,7 +11,6 @@ import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; @@ -39,7 +38,7 @@ public class Lookup extends UnaryPlan { /** * References to the input fields to match against the {@link #localRelation}. */ - private final List matchFields; + private final List matchFields; // initialized during the analysis phase for output and validation // afterward, it is converted into a Join (BinaryPlan) hence why here it is not a child private final LocalRelation localRelation; @@ -49,7 +48,7 @@ public Lookup( Source source, LogicalPlan child, Expression tableName, - List matchFields, + List matchFields, @Nullable LocalRelation localRelation ) { super(source, child); @@ -61,7 +60,7 @@ public Lookup( public Lookup(PlanStreamInput in) throws IOException { super(Source.readFrom(in), in.readLogicalPlanNode()); this.tableName = in.readExpression(); - this.matchFields = in.readNamedWriteableCollectionAsList(NamedExpression.class); + this.matchFields = in.readNamedWriteableCollectionAsList(Attribute.class); this.localRelation = in.readBoolean() ? new LocalRelation(in) : null; } @@ -82,7 +81,7 @@ public Expression tableName() { return tableName; } - public List matchFields() { + public List matchFields() { return matchFields; } @@ -122,10 +121,10 @@ protected NodeInfo info() { @Override public List output() { if (lazyOutput == null) { - List rightSide = localRelation != null - ? 
Join.makeNullable(Join.makeReference(localRelation.output())) - : Expressions.asAttributes(matchFields); - lazyOutput = Join.mergeOutput(child().output(), rightSide, matchFields); + if (localRelation == null) { + throw new IllegalStateException("Cannot determine output of LOOKUP with unresolved table"); + } + lazyOutput = Join.computeOutput(child().output(), localRelation.output(), joinConfig()); } return lazyOutput; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java index 81ec67a28bbfd..57a52ad1a1cf8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java @@ -8,7 +8,8 @@ package org.elasticsearch.xpack.esql.plan.logical.join; import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.expression.AttributeSet; +import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.plan.logical.BinaryPlan; @@ -20,8 +21,12 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; + +import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes; public class Join extends BinaryPlan { @@ -68,47 +73,41 @@ public Join replaceChildren(LogicalPlan left, LogicalPlan right) { @Override public List output() { if (lazyOutput == null) { - lazyOutput = computeOutput(); + lazyOutput = computeOutput(left().output(), right().output(), config); } return lazyOutput; } - private List computeOutput() { - List right = makeReference(right().output()); + /** + * Merge output fields. + * Currently only implemented for LEFT JOINs; the rightOutput shadows the leftOutput, except for any attributes that + * occur in the join's matchFields. + */ + public static List computeOutput(List leftOutput, List rightOutput, JoinConfig config) { + AttributeSet matchFieldSet = new AttributeSet(config.matchFields()); + Set matchFieldNames = new HashSet<>(Expressions.names(config.matchFields())); return switch (config.type()) { - case LEFT -> // right side becomes nullable - mergeOutput(left().output(), makeNullable(right), config.matchFields()); - case RIGHT -> // left side becomes nullable - mergeOutput(makeNullable(left().output()), right, config.matchFields()); - case FULL -> // both sides become nullable - mergeOutput(makeNullable(left().output()), makeNullable(right), config.matchFields()); - default -> // neither side becomes nullable - mergeOutput(left().output(), right, config.matchFields()); + case LEFT -> { + // Right side becomes nullable. + List fieldsAddedFromRight = removeCollisionsWithMatchFields(rightOutput, matchFieldSet, matchFieldNames); + yield mergeOutputAttributes(makeNullable(makeReference(fieldsAddedFromRight)), leftOutput); + } + default -> throw new UnsupportedOperationException("Other JOINs than LEFT not supported"); }; } - /** - * Merge output fields, left hand side wins in name conflicts except - * for fields defined in {@link JoinConfig#matchFields()}. 
- */ - public static List mergeOutput( - List lhs, - List rhs, - List matchFields + private static List removeCollisionsWithMatchFields( + List attributes, + AttributeSet matchFields, + Set matchFieldNames ) { - List results = new ArrayList<>(lhs.size() + rhs.size()); - - for (Attribute a : lhs) { - if (rhs.contains(a) == false || matchFields.stream().anyMatch(m -> m.name().equals(a.name()))) { - results.add(a); - } - } - for (Attribute a : rhs) { - if (false == matchFields.stream().anyMatch(m -> m.name().equals(a.name()))) { - results.add(a); + List result = new ArrayList<>(); + for (Attribute attr : attributes) { + if ((matchFields.contains(attr) || matchFieldNames.contains(attr.name())) == false) { + result.add(attr); } } - return results; + return result; } /** @@ -125,7 +124,7 @@ public static List mergeOutput( public static List makeReference(List output) { List out = new ArrayList<>(output.size()); for (Attribute a : output) { - if (a.resolved()) { + if (a.resolved() && a instanceof ReferenceAttribute == false) { out.add(new ReferenceAttribute(a.source(), a.name(), a.dataType(), a.qualifier(), a.nullable(), a.id(), a.synthetic())); } else { out.add(a); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/JoinConfig.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/JoinConfig.java index b5cf5d9234c6b..6b603709b3972 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/JoinConfig.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/JoinConfig.java @@ -11,8 +11,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; @@ -24,11 +24,11 @@ * @param matchFields fields that are merged from the left and right relations * @param conditions when these conditions are true the rows are joined */ -public record JoinConfig(JoinType type, List matchFields, List conditions) implements Writeable { +public record JoinConfig(JoinType type, List matchFields, List conditions) implements Writeable { public JoinConfig(StreamInput in) throws IOException { this( JoinType.readFrom(in), - in.readNamedWriteableCollectionAsList(NamedExpression.class), + in.readNamedWriteableCollectionAsList(Attribute.class), in.readCollectionAsList(i -> ((PlanStreamInput) i).readExpression()) ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java index 13773ca61f8d8..b8f96709a583f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java @@ -28,9 +28,7 @@ import java.util.Objects; public class EsQueryExec extends LeafExec implements EstimatesRowSize { - static final EsField DOC_ID_FIELD = new EsField("_doc", DataType.DOC_DATA_TYPE, Map.of(), false); - static final EsField TSID_FIELD = new EsField("_tsid", DataType.TSID_DATA_TYPE, Map.of(), 
true); - static final EsField TIMESTAMP_FIELD = new EsField("@timestamp", DataType.DATETIME, Map.of(), true); + public static final EsField DOC_ID_FIELD = new EsField("_doc", DataType.DOC_DATA_TYPE, Map.of(), false); private final EsIndex index; private final IndexMode indexMode; @@ -55,8 +53,8 @@ public FieldSortBuilder fieldSortBuilder() { } } - public EsQueryExec(Source source, EsIndex index, IndexMode indexMode, QueryBuilder query) { - this(source, index, indexMode, sourceAttributes(source, indexMode), query, null, null, null); + public EsQueryExec(Source source, EsIndex index, IndexMode indexMode, List attributes, QueryBuilder query) { + this(source, index, indexMode, attributes, query, null, null, null); } public EsQueryExec( @@ -79,17 +77,6 @@ public EsQueryExec( this.estimatedRowSize = estimatedRowSize; } - private static List sourceAttributes(Source source, IndexMode indexMode) { - return switch (indexMode) { - case STANDARD, LOGS -> List.of(new FieldAttribute(source, DOC_ID_FIELD.getName(), DOC_ID_FIELD)); - case TIME_SERIES -> List.of( - new FieldAttribute(source, DOC_ID_FIELD.getName(), DOC_ID_FIELD), - new FieldAttribute(source, TSID_FIELD.getName(), TSID_FIELD), - new FieldAttribute(source, TIMESTAMP_FIELD.getName(), TIMESTAMP_FIELD) - ); - }; - } - public static boolean isSourceAttribute(Attribute attr) { return DOC_ID_FIELD.getName().equals(attr.name()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java index dff0a6f0eade3..29cf079f317be 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/HashJoinExec.java @@ -9,7 +9,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; @@ -24,7 +23,7 @@ public class HashJoinExec extends UnaryExec implements EstimatesRowSize { private final LocalSourceExec joinData; - private final List matchFields; + private final List matchFields; /** * Conditions that must match for rows to be joined. 
The {@link Equals#left()} * is always from the child and the {@link Equals#right()} is always from the @@ -38,7 +37,7 @@ public HashJoinExec( Source source, PhysicalPlan child, LocalSourceExec hashData, - List matchFields, + List matchFields, List conditions, List output ) { @@ -52,7 +51,7 @@ public HashJoinExec( public HashJoinExec(PlanStreamInput in) throws IOException { super(Source.readFrom(in), in.readPhysicalPlanNode()); this.joinData = new LocalSourceExec(in); - this.matchFields = in.readNamedWriteableCollectionAsList(NamedExpression.class); + this.matchFields = in.readNamedWriteableCollectionAsList(Attribute.class); this.conditions = in.readCollectionAsList(i -> (Equals) EsqlBinaryComparison.readFrom(in)); this.output = in.readNamedWriteableCollectionAsList(Attribute.class); } @@ -70,7 +69,7 @@ public LocalSourceExec joinData() { return joinData; } - public List matchFields() { + public List matchFields() { return matchFields; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 825057c20d0e0..9e1e1a50fe8f0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -59,7 +59,6 @@ import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.DriverParallelism; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import org.elasticsearch.xpack.esql.type.MultiTypeEsField; import java.io.IOException; @@ -118,7 +117,7 @@ public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fi MappedFieldType.FieldExtractPreference fieldExtractPreference = PlannerUtils.extractPreference(docValuesAttrs.contains(attr)); ElementType elementType = PlannerUtils.toElementType(dataType, fieldExtractPreference); String fieldName = attr.name(); - boolean isUnsupported = EsqlDataTypes.isUnsupported(dataType); + boolean isUnsupported = dataType == DataType.UNSUPPORTED; IntFunction loader = s -> getBlockLoaderFor(s, fieldName, isUnsupported, fieldExtractPreference, unionTypes); fields.add(new ValuesSourceReaderOperator.FieldInfo(fieldName, elementType, loader)); } @@ -233,7 +232,7 @@ public final Operator.OperatorFactory ordinalGroupingOperatorFactory( .toList(); // The grouping-by values are ready, let's group on them directly. // Costin: why are they ready and not already exposed in the layout? 
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 825057c20d0e0..9e1e1a50fe8f0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -59,7 +59,6 @@ import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.DriverParallelism; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import org.elasticsearch.xpack.esql.type.MultiTypeEsField; import java.io.IOException; @@ -118,7 +117,7 @@ public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fi MappedFieldType.FieldExtractPreference fieldExtractPreference = PlannerUtils.extractPreference(docValuesAttrs.contains(attr)); ElementType elementType = PlannerUtils.toElementType(dataType, fieldExtractPreference); String fieldName = attr.name(); - boolean isUnsupported = EsqlDataTypes.isUnsupported(dataType); + boolean isUnsupported = dataType == DataType.UNSUPPORTED; IntFunction<BlockLoader> loader = s -> getBlockLoaderFor(s, fieldName, isUnsupported, fieldExtractPreference, unionTypes); fields.add(new ValuesSourceReaderOperator.FieldInfo(fieldName, elementType, loader)); } @@ -233,7 +232,7 @@ public final Operator.OperatorFactory ordinalGroupingOperatorFactory( .toList(); // The grouping-by values are ready, let's group on them directly. // Costin: why are they ready and not already exposed in the layout? - boolean isUnsupported = EsqlDataTypes.isUnsupported(attrSource.dataType()); + boolean isUnsupported = attrSource.dataType() == DataType.UNSUPPORTED; return new OrdinalsGroupingOperator.OrdinalsGroupingOperatorFactory( shardIdx -> shardContexts.get(shardIdx).blockLoader(attrSource.name(), isUnsupported, NONE), vsShardContexts, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java index cc28839fd6575..20138f34c1041 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java @@ -60,6 +60,7 @@ import static java.util.Arrays.asList; import static org.elasticsearch.index.mapper.MappedFieldType.FieldExtractPreference.DOC_VALUES; +import static org.elasticsearch.index.mapper.MappedFieldType.FieldExtractPreference.NONE; import static org.elasticsearch.xpack.esql.core.util.Queries.Clause.FILTER; import static org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer.PushFiltersToSource.canPushToSource; import static org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer.TRANSLATOR_HANDLER; @@ -230,7 +231,7 @@ public static ElementType toSortableElementType(DataType dataType) { * Map QL's {@link DataType} to the compute engine's {@link ElementType}. */ public static ElementType toElementType(DataType dataType) { - return toElementType(dataType, MappedFieldType.FieldExtractPreference.NONE); + return toElementType(dataType, NONE); } /** @@ -239,47 +240,22 @@ public static ElementType toElementType(DataType dataType) { * For example, spatial types can be extracted into doc-values under specific conditions, otherwise they extract as BytesRef. */ public static ElementType toElementType(DataType dataType, MappedFieldType.FieldExtractPreference fieldExtractPreference) { - if (dataType == DataType.LONG - || dataType == DataType.DATETIME - || dataType == DataType.UNSIGNED_LONG - || dataType == DataType.COUNTER_LONG) { - return ElementType.LONG; - } - if (dataType == DataType.INTEGER || dataType == DataType.COUNTER_INTEGER) { - return ElementType.INT; - } - if (dataType == DataType.DOUBLE || dataType == DataType.COUNTER_DOUBLE) { - return ElementType.DOUBLE; - } - // unsupported fields are passed through as a BytesRef - if (dataType == DataType.KEYWORD - || dataType == DataType.TEXT - || dataType == DataType.IP - || dataType == DataType.SOURCE - || dataType == DataType.VERSION - || dataType == DataType.UNSUPPORTED) { - return ElementType.BYTES_REF; - } - if (dataType == DataType.NULL) { - return ElementType.NULL; - } - if (dataType == DataType.BOOLEAN) { - return ElementType.BOOLEAN; - } - if (dataType == DataType.DOC_DATA_TYPE) { - return ElementType.DOC; - } - if (dataType == DataType.TSID_DATA_TYPE) { - return ElementType.BYTES_REF; - } - if (EsqlDataTypes.isSpatialPoint(dataType)) { - return fieldExtractPreference == DOC_VALUES ?
ElementType.LONG : ElementType.BYTES_REF; - } - if (EsqlDataTypes.isSpatial(dataType)) { - // TODO: support forStats for shape aggregations, like st_centroid - return ElementType.BYTES_REF; - } - throw EsqlIllegalArgumentException.illegalDataType(dataType); + + return switch (dataType) { + case LONG, DATETIME, UNSIGNED_LONG, COUNTER_LONG -> ElementType.LONG; + case INTEGER, COUNTER_INTEGER -> ElementType.INT; + case DOUBLE, COUNTER_DOUBLE -> ElementType.DOUBLE; + // unsupported fields are passed through as a BytesRef + case KEYWORD, TEXT, IP, SOURCE, VERSION, UNSUPPORTED -> ElementType.BYTES_REF; + case NULL -> ElementType.NULL; + case BOOLEAN -> ElementType.BOOLEAN; + case DOC_DATA_TYPE -> ElementType.DOC; + case TSID_DATA_TYPE -> ElementType.BYTES_REF; + case GEO_POINT, CARTESIAN_POINT -> fieldExtractPreference == DOC_VALUES ? ElementType.LONG : ElementType.BYTES_REF; + case GEO_SHAPE, CARTESIAN_SHAPE -> ElementType.BYTES_REF; + case SHORT, BYTE, DATE_PERIOD, TIME_DURATION, OBJECT, NESTED, FLOAT, HALF_FLOAT, SCALED_FLOAT -> + throw EsqlIllegalArgumentException.illegalDataType(dataType); + }; } /** @@ -296,6 +272,6 @@ public static ElementType toElementType(DataType dataType, MappedFieldType.Field * Returns DOC_VALUES if the given boolean is set. */ public static MappedFieldType.FieldExtractPreference extractPreference(boolean hasPreference) { - return hasPreference ? MappedFieldType.FieldExtractPreference.DOC_VALUES : MappedFieldType.FieldExtractPreference.NONE; + return hasPreference ? DOC_VALUES : NONE; } }
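Rewriting the if-chain as an exhaustive switch turns a runtime throw for unmapped types into a compile-time error whenever a new DataType constant is added. The only preference-sensitive arm is the spatial one; a small sketch of the resulting behavior, written against the switch above with the statically imported DOC_VALUES and NONE preferences:

    // Points extract as encoded longs when doc-values are preferred, otherwise
    // as BytesRef; shapes map to BytesRef regardless of the preference.
    assert PlannerUtils.toElementType(DataType.GEO_POINT, DOC_VALUES) == ElementType.LONG;
    assert PlannerUtils.toElementType(DataType.GEO_POINT, NONE) == ElementType.BYTES_REF;
    assert PlannerUtils.toElementType(DataType.GEO_SHAPE, DOC_VALUES) == ElementType.BYTES_REF;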
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 28191a394e69c..c83840b384dbd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -38,7 +38,6 @@ import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver; import org.elasticsearch.xpack.esql.execution.PlanExecutor; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import java.io.IOException; import java.time.ZoneOffset; @@ -172,7 +171,7 @@ private void innerExecute(Task task, EsqlQueryRequest request, ActionListener<EsqlQueryResponse> listener) { List<ColumnInfo> columns = physicalPlan.output() .stream() - .map(c -> new ColumnInfo(c.qualifiedName(), EsqlDataTypes.outputType(c.dataType()))) + .map(c -> new ColumnInfo(c.qualifiedName(), c.dataType().outputType())) .toList(); EsqlQueryResponse.Profile profile = configuration.profile() ? new EsqlQueryResponse.Profile(result.profiles()) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java index 5d5e2b82e4b7b..4cd51b676fe89 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java @@ -436,7 +436,8 @@ private Scorer scorer(Scorer nextScorer, LeafFieldData lfd) { @Override public boolean isCacheable(LeafReaderContext ctx) { - return next.isCacheable(ctx); + // we cannot cache this query because we lose the ability to emit warnings + return false; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeRegistry.java index 87d35b728ca8c..4ddef25584eea 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeRegistry.java @@ -35,7 +35,13 @@ public Collection dataTypes() { @Override public DataType fromEs(String typeName, TimeSeriesParams.MetricType metricType) { - DataType type = EsqlDataTypes.fromName(typeName); + DataType type = DataType.fromEs(typeName); + /* + * If we're handling a time series COUNTER type field then convert it + * into its counter. But *first* we have to widen it because we only + * have time series counters for `double`, `long` and `int`, not `float` + * and `half_float`, etc. + */ return metricType == TimeSeriesParams.MetricType.COUNTER ? type.widenSmallNumeric().counter() : type; }
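The widen-then-convert order matters because counter variants exist only for int, long, and double. A sketch of the path a float-mapped TSDB counter takes through fromEs; the intermediate results are assumptions based on the comment above:

    // Hypothetical walk-through for an ES `float` field whose metric type is COUNTER:
    DataType base = DataType.fromEs("float");        // FLOAT
    DataType widened = base.widenSmallNumeric();     // DOUBLE: no counter exists for float
    DataType counter = widened.counter();            // COUNTER_DOUBLE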
@@ -46,7 +52,7 @@ public DataType fromJava(Object value) { @Override public boolean isUnsupported(DataType type) { - return EsqlDataTypes.isUnsupported(type); + return type == DataType.UNSUPPORTED; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java index be1123a292409..a7d7bb66a4818 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypes.java @@ -8,12 +8,8 @@ import org.elasticsearch.xpack.esql.core.type.DataType; -import java.util.Collections; import java.util.Locale; -import java.util.Map; -import static java.util.stream.Collectors.toMap; -import static java.util.stream.Collectors.toUnmodifiableMap; import static org.elasticsearch.xpack.esql.core.type.DataType.BYTE; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD; import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; @@ -32,40 +28,10 @@ public final class EsqlDataTypes { - private static final Map<String, DataType> NAME_TO_TYPE = DataType.types() - .stream() - .collect(toUnmodifiableMap(DataType::typeName, t -> t)); - - private static final Map<String, DataType> ES_TO_TYPE; - - static { - Map<String, DataType> map = DataType.types().stream().filter(e -> e.esType() != null).collect(toMap(DataType::esType, t -> t)); - // ES calls this 'point', but ESQL calls it 'cartesian_point' - map.put("point", DataType.CARTESIAN_POINT); - map.put("shape", DataType.CARTESIAN_SHAPE); - ES_TO_TYPE = Collections.unmodifiableMap(map); - } - private EsqlDataTypes() {} public static DataType fromTypeName(String name) { - return NAME_TO_TYPE.get(name.toLowerCase(Locale.ROOT)); - } - - public static DataType fromName(String name) { - DataType type = ES_TO_TYPE.get(name); return type != null ?
type : UNSUPPORTED; - } - - public static boolean isUnsupported(DataType type) { - return DataType.isUnsupported(type); - } - - public static String outputType(DataType type) { - if (type != null && type.esType() != null) { - return type.esType(); - } - return "unsupported"; + return DataType.fromTypeName(name.toLowerCase(Locale.ROOT)); } public static boolean isString(DataType t) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index fd161c8d63871..ebeb62ee02df6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -86,7 +86,6 @@ import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; import org.elasticsearch.xpack.esql.stats.DisabledSearchStats; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import org.junit.After; import org.junit.Before; import org.mockito.Mockito; @@ -417,7 +416,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { List dataTypes = new ArrayList<>(columnNames.size()); List columnTypes = coordinatorPlan.output() .stream() - .peek(o -> dataTypes.add(EsqlDataTypes.outputType(o.dataType()))) + .peek(o -> dataTypes.add(o.dataType().outputType())) .map(o -> Type.asType(o.dataType().nameUpper())) .toList(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java index a5c305ca77a45..ead0eb9ee0635 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponseTests.java @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.planner.PlannerUtils; -import org.elasticsearch.xpack.esql.type.EsqlDataTypes; import org.elasticsearch.xpack.versionfield.Version; import org.junit.After; import org.junit.Before; @@ -139,7 +138,7 @@ private EsqlQueryResponse.Profile randomProfile() { private Page randomPage(List columns) { return new Page(columns.stream().map(c -> { - Block.Builder builder = PlannerUtils.toElementType(EsqlDataTypes.fromName(c.type())).newBlockBuilder(1, blockFactory); + Block.Builder builder = PlannerUtils.toElementType(DataType.fromEs(c.type())).newBlockBuilder(1, blockFactory); switch (c.type()) { case "unsigned_long", "long", "counter_long" -> ((LongBlock.Builder) builder).appendLong(randomLong()); case "integer", "counter_integer" -> ((IntBlock.Builder) builder).appendInt(randomInt()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 4482cb1a210b2..28d2046a0ea36 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; -import 
org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; @@ -44,7 +43,6 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -60,7 +58,6 @@ import java.io.IOException; import java.io.InputStream; -import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -2071,50 +2068,18 @@ public void testImplicitCasting() { assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]")); } - public void testRate() { + public void testRateRequiresCounterTypes() { assumeTrue("rate requires snapshot builds", Build.current().isSnapshot()); Analyzer analyzer = analyzer(tsdbIndexResolution()); - { - var query = "FROM test | STATS rate(network.bytes_in)"; - LogicalPlan plan = analyze(query, analyzer); - var limit = as(plan, Limit.class); - var stats = as(limit.child(), Aggregate.class); - var rate = as(as(stats.aggregates().get(0), Alias.class).child(), Rate.class); - FieldAttribute field = as(rate.field(), FieldAttribute.class); - assertThat(field.name(), equalTo("network.bytes_in")); - assertThat(rate.parameters(), hasSize(1)); - FieldAttribute timestamp = as(rate.parameters().get(0), FieldAttribute.class); - assertThat(timestamp.name(), equalTo("@timestamp")); - assertThat(rate.typeResolved(), equalTo(Expression.TypeResolution.TYPE_RESOLVED)); - } - { - var query = "FROM test | STATS rate(network.bytes_out, 1minute)"; - LogicalPlan plan = analyze(query, analyzer); - var limit = as(plan, Limit.class); - var stats = as(limit.child(), Aggregate.class); - var rate = as(as(stats.aggregates().get(0), Alias.class).child(), Rate.class); - FieldAttribute field = as(rate.field(), FieldAttribute.class); - assertThat(field.name(), equalTo("network.bytes_out")); - assertThat(rate.parameters(), hasSize(2)); - FieldAttribute timestamp = as(rate.parameters().get(0), FieldAttribute.class); - assertThat(timestamp.name(), equalTo("@timestamp")); - Expression unit = as(rate.parameters().get(1), Expression.class); - assertTrue(unit.foldable()); - Duration duration = as(unit.fold(), Duration.class); - assertThat(duration.toMillis(), equalTo(60 * 1000L)); - assertThat(rate.typeResolved(), equalTo(Expression.TypeResolution.TYPE_RESOLVED)); - } - { - var query = "FROM test | STATS rate(network.connections)"; - VerificationException error = expectThrows(VerificationException.class, () -> analyze(query, analyzer)); - assertThat( - error.getMessage(), - containsString( - "first argument of [rate(network.connections)] must be" - + " [counter_long, counter_integer or counter_double], found value [network.connections] type [long]" - ) - ); - } + var query = "METRICS test avg(rate(network.connections))"; + VerificationException error = expectThrows(VerificationException.class, () -> analyze(query, analyzer)); + assertThat( + error.getMessage(), + containsString( + "first argument of [rate(network.connections)] must be" + + " [counter_long, 
counter_integer or counter_double], found value [network.connections] type [long]" + ) + ); } private void verifyUnsupported(String query, String errorMessage) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 8eef05bd9687b..beb85268425be 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.analysis; +import org.elasticsearch.Build; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -548,6 +549,52 @@ public void testAggsResolutionWithUnresolvedGroupings() { ); } + public void testNotAllowRateOutsideMetrics() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + assertThat( + error("FROM tests | STATS avg(rate(network.bytes_in))", tsdb), + equalTo("1:24: the rate aggregate[rate(network.bytes_in)] can only be used within the metrics command") + ); + assertThat( + error("METRICS tests | STATS sum(rate(network.bytes_in))", tsdb), + equalTo("1:27: the rate aggregate[rate(network.bytes_in)] can only be used within the metrics command") + ); + assertThat( + error("FROM tests | STATS rate(network.bytes_in)", tsdb), + equalTo("1:20: the rate aggregate[rate(network.bytes_in)] can only be used within the metrics command") + ); + assertThat( + error("FROM tests | EVAL r = rate(network.bytes_in)", tsdb), + equalTo("1:23: aggregate function [rate(network.bytes_in)] not allowed outside METRICS command") + ); + } + + public void testRateNotEnclosedInAggregate() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + assertThat( + error("METRICS tests rate(network.bytes_in)", tsdb), + equalTo( + "1:15: the rate aggregate [rate(network.bytes_in)] can only be used within the metrics command and inside another aggregate" + ) + ); + assertThat( + error("METRICS tests avg(rate(network.bytes_in)), rate(network.bytes_in)", tsdb), + equalTo( + "1:44: the rate aggregate [rate(network.bytes_in)] can only be used within the metrics command and inside another aggregate" + ) + ); + assertThat(error("METRICS tests max(avg(rate(network.bytes_in)))", tsdb), equalTo(""" + 1:19: nested aggregations [avg(rate(network.bytes_in))] not allowed inside other aggregations\ + [max(avg(rate(network.bytes_in)))] + line 1:23: the rate aggregate [rate(network.bytes_in)] can only be used within the metrics command\ + and inside another aggregate""")); + } + private String error(String query) { return error(query, defaultAnalyzer); }
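Taken together, these assertions pin down the accepted shape by elimination: rate has to sit directly inside another aggregate of a METRICS command, e.g. METRICS tests max(rate(network.bytes_in)). That positive case is an inference from the error messages above; this diff only asserts the failing variants.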
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistryTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistryTests.java index 4d50069d2f830..74ace6f4ceb9c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistryTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistryTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.expression.function; +import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.ParsingException; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; @@ -15,11 +16,18 @@ import org.elasticsearch.xpack.esql.core.expression.function.FunctionRegistry; import org.elasticsearch.xpack.esql.core.expression.function.FunctionRegistryTests; import org.elasticsearch.xpack.esql.core.expression.function.FunctionResolutionStrategy; +import org.elasticsearch.xpack.esql.core.expression.function.OptionalArgument; import org.elasticsearch.xpack.esql.core.expression.function.UnresolvedFunction; +import org.elasticsearch.xpack.esql.core.session.Configuration; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.tree.SourceTests; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlConfigurationFunction; import java.util.Arrays; +import java.util.List; +import java.util.function.Function; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.core.expression.function.FunctionRegistry.def; @@ -78,6 +86,19 @@ public void testUnaryFunction() { assertThat(e.getMessage(), endsWith("expects exactly one argument")); } + public void testConfigurationOptionalFunction() { + UnresolvedFunction ur = uf(DEFAULT, mock(Expression.class)); + FunctionDefinition def; + FunctionRegistry r = new EsqlFunctionRegistry( + EsqlFunctionRegistry.def(DummyConfigurationOptionalArgumentFunction.class, (Source l, Expression e, Configuration c) -> { + assertSame(e, ur.children().get(0)); + return new DummyConfigurationOptionalArgumentFunction(l, List.of(ur), c); + }, "dummy") + ); + def = r.resolveFunction(r.resolveAlias("DUMMY")); + assertEquals(ur.source(), ur.buildResolved(randomConfiguration(), def).source()); + } + private static UnresolvedFunction uf(FunctionResolutionStrategy resolutionStrategy, Expression...
children) { return new UnresolvedFunction(SourceTests.randomSource(), "dummyFunction", resolutionStrategy, Arrays.asList(children)); } @@ -104,4 +125,33 @@ private String randomCapitalizedString(String input) { } return output.toString(); } + + public static class DummyConfigurationOptionalArgumentFunction extends EsqlConfigurationFunction implements OptionalArgument { + + public DummyConfigurationOptionalArgumentFunction(Source source, List<Expression> fields, Configuration configuration) { + super(source, fields, configuration); + } + + @Override + public DataType dataType() { + return null; + } + + @Override + public Expression replaceChildren(List<Expression> newChildren) { + return new DummyConfigurationOptionalArgumentFunction(source(), newChildren, configuration()); + } + + @Override + protected NodeInfo<? extends Expression> info() { + return NodeInfo.create(this, DummyConfigurationOptionalArgumentFunction::new, children(), configuration()); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory toEvaluator( + Function<Expression, EvalOperator.ExpressionEvaluator.Factory> toEvaluator + ) { + return null; + } + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java new file mode 100644 index 0000000000000..335473c4f3eff --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class AvgSerializationTests extends AbstractExpressionSerializationTests<Avg> { + @Override + protected Avg createTestInstance() { + return new Avg(randomSource(), randomChild()); + } + + @Override + protected Avg mutateInstance(Avg instance) throws IOException { + return new Avg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } + + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +}
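The serialization suites that follow this one all share the same harness: createTestInstance() builds an expression, the base class round-trips it through a registry assembled from getNamedWriteables(), and mutateInstance() feeds the equals/hashCode checks. A condensed sketch of that round-trip; the base-class internals are not part of this diff, so the exact mechanics are assumptions:

    // Hypothetical core of the harness: serialize with the declared registry
    // entries, deserialize, and require the copy to equal the original.
    BytesStreamOutput out = new BytesStreamOutput();
    out.writeNamedWriteable(original);
    NamedWriteableRegistry registry = new NamedWriteableRegistry(getNamedWriteables());
    StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry);
    Expression copy = in.readNamedWriteable(Expression.class);
    assertEquals(original, copy);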
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinctSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinctSerializationTests.java new file mode 100644 index 0000000000000..c7166adb0b62f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinctSerializationTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class CountDistinctSerializationTests extends AbstractExpressionSerializationTests<CountDistinct> { + @Override + protected CountDistinct createTestInstance() { + Source source = randomSource(); + Expression field = randomChild(); + Expression precision = randomBoolean() ? null : randomChild(); + return new CountDistinct(source, field, precision); + } + + @Override + protected CountDistinct mutateInstance(CountDistinct instance) throws IOException { + Source source = randomSource(); + Expression field = randomChild(); + Expression precision = randomBoolean() ? null : randomChild(); + if (randomBoolean()) { + field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + } else { + precision = randomValueOtherThan(precision, () -> randomBoolean() ? null : randomChild()); + } + return new CountDistinct(source, field, precision); + } + + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountSerializationTests.java new file mode 100644 index 0000000000000..1c588b26abad8 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountSerializationTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class CountSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected Count createTestInstance() { + return new Count(randomSource(), randomChild()); + } + + @Override + protected Count mutateInstance(Count instance) throws IOException { + return new Count(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxSerializationTests.java new file mode 100644 index 0000000000000..a50cba3e9e9cd --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxSerializationTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class MaxSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected Max createTestInstance() { + return new Max(randomSource(), randomChild()); + } + + @Override + protected Max mutateInstance(Max instance) throws IOException { + return new Max(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviationSerializationTests.java new file mode 100644 index 0000000000000..a57c45da07ba3 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianAbsoluteDeviationSerializationTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class MedianAbsoluteDeviationSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected MedianAbsoluteDeviation createTestInstance() { + return new MedianAbsoluteDeviation(randomSource(), randomChild()); + } + + @Override + protected MedianAbsoluteDeviation mutateInstance(MedianAbsoluteDeviation instance) throws IOException { + return new MedianAbsoluteDeviation( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild) + ); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianSerializationTests.java new file mode 100644 index 0000000000000..56943e9ef41c3 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MedianSerializationTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class MedianSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected Median createTestInstance() { + return new Median(randomSource(), randomChild()); + } + + @Override + protected Median mutateInstance(Median instance) throws IOException { + return new Median(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinSerializationTests.java new file mode 100644 index 0000000000000..bd0d8088ef857 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinSerializationTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class MinSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected Min createTestInstance() { + return new Min(randomSource(), randomChild()); + } + + @Override + protected Min mutateInstance(Min instance) throws IOException { + return new Min(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PercentileSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PercentileSerializationTests.java new file mode 100644 index 0000000000000..88e063058c9f2 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/PercentileSerializationTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class PercentileSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected Percentile createTestInstance() { + Source source = randomSource(); + Expression field = randomChild(); + Expression percentile = randomChild(); + return new Percentile(source, field, percentile); + } + + @Override + protected Percentile mutateInstance(Percentile instance) throws IOException { + Source source = instance.source(); + Expression field = instance.field(); + Expression percentile = instance.percentile(); + if (randomBoolean()) { + field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + } else { + percentile = randomValueOtherThan(percentile, AbstractExpressionSerializationTests::randomChild); + } + return new Percentile(source, field, percentile); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidSerializationTests.java new file mode 100644 index 0000000000000..9adf7d1e00361 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SpatialCentroidSerializationTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class SpatialCentroidSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected SpatialCentroid createTestInstance() { + return new SpatialCentroid(randomSource(), randomChild()); + } + + @Override + protected SpatialCentroid mutateInstance(SpatialCentroid instance) throws IOException { + return new SpatialCentroid( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild) + ); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java new file mode 100644 index 0000000000000..9c7ee0e8348b7 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class SumSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected Sum createTestInstance() { + return new Sum(randomSource(), randomChild()); + } + + @Override + protected Sum mutateInstance(Sum instance) throws IOException { + return new Sum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopListSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopListSerializationTests.java new file mode 100644 index 0000000000000..605d240512e65 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopListSerializationTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class TopListSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected TopList createTestInstance() { + Source source = randomSource(); + Expression field = randomChild(); + Expression limit = randomChild(); + Expression order = randomChild(); + return new TopList(source, field, limit, order); + } + + @Override + protected TopList mutateInstance(TopList instance) throws IOException { + Source source = instance.source(); + Expression field = instance.field(); + Expression limit = instance.limitField(); + Expression order = instance.orderField(); + switch (between(0, 2)) { + case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + case 1 -> limit = randomValueOtherThan(limit, AbstractExpressionSerializationTests::randomChild); + case 2 -> order = randomValueOtherThan(order, AbstractExpressionSerializationTests::randomChild); + } + return new TopList(source, field, limit, order); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesSerializationTests.java new file mode 100644 index 0000000000000..2471e6a8218b3 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ValuesSerializationTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class ValuesSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected Values createTestInstance() { + return new Values(randomSource(), randomChild()); + } + + @Override + protected Values mutateInstance(Values instance) throws IOException { + return new Values(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } + + @Override + protected List getNamedWriteables() { + return AggregateFunction.getNamedWriteables(); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketSerializationTests.java new file mode 100644 index 0000000000000..8250cad0c85e8 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketSerializationTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.grouping; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class BucketSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected List getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Bucket createTestInstance() { + Source source = randomSource(); + Expression field = randomChild(); + Expression buckets = randomChild(); + Expression from = randomChild(); + Expression to = randomChild(); + return new Bucket(source, field, buckets, from, to); + } + + @Override + protected Bucket mutateInstance(Bucket instance) throws IOException { + Source source = instance.source(); + Expression field = instance.field(); + Expression buckets = instance.buckets(); + Expression from = instance.from(); + Expression to = instance.to(); + switch (between(0, 3)) { + case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + case 1 -> buckets = randomValueOtherThan(buckets, AbstractExpressionSerializationTests::randomChild); + case 2 -> from = randomValueOtherThan(from, AbstractExpressionSerializationTests::randomChild); + case 3 -> to = randomValueOtherThan(to, AbstractExpressionSerializationTests::randomChild); + } + return new Bucket(source, field, buckets, from, to); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/BucketTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java similarity index 98% rename from x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/BucketTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java index c4e614be94438..aaa6fe7d45c83 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/BucketTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/BucketTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.expression.function.scalar.math; +package org.elasticsearch.xpack.esql.expression.function.grouping; import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; -import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.hamcrest.Matcher; import java.time.Duration; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AndSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AndSerializationTests.java new file mode 100644 index 0000000000000..ffeae4465eac6 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AndSerializationTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class AndSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected List getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected And createTestInstance() { + Source source = randomSource(); + Expression left = randomChild(); + Expression right = randomChild(); + return new And(source, left, right); + } + + @Override + protected And mutateInstance(And instance) throws IOException { + Source source = instance.source(); + Expression left = instance.left(); + Expression right = instance.right(); + if (randomBoolean()) { + left = randomValueOtherThan(left, AbstractExpressionSerializationTests::randomChild); + } else { + right = randomValueOtherThan(right, AbstractExpressionSerializationTests::randomChild); + } + return new And(source, left, right); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/OrSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/OrSerializationTests.java new file mode 100644 index 0000000000000..1755ba1fac026 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/OrSerializationTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class OrSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected List getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Or createTestInstance() { + Source source = randomSource(); + Expression left = randomChild(); + Expression right = randomChild(); + return new Or(source, left, right); + } + + @Override + protected Or mutateInstance(Or instance) throws IOException { + Source source = instance.source(); + Expression left = instance.left(); + Expression right = instance.right(); + if (randomBoolean()) { + left = randomValueOtherThan(left, AbstractExpressionSerializationTests::randomChild); + } else { + right = randomValueOtherThan(right, AbstractExpressionSerializationTests::randomChild); + } + return new Or(source, left, right); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/CIDRMatchSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/CIDRMatchSerializationTests.java new file mode 100644 index 0000000000000..e20f9e03f09b6 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/CIDRMatchSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.ip; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class CIDRMatchSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected List getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected CIDRMatch createTestInstance() { + Source source = randomSource(); + Expression ipField = randomChild(); + List matches = randomList(1, 10, AbstractExpressionSerializationTests::randomChild); + return new CIDRMatch(source, ipField, matches); + } + + @Override + protected CIDRMatch mutateInstance(CIDRMatch instance) throws IOException { + Source source = instance.source(); + Expression ipField = instance.ipField(); + List matches = instance.matches(); + if (randomBoolean()) { + ipField = randomValueOtherThan(ipField, AbstractExpressionSerializationTests::randomChild); + } else { + matches = randomValueOtherThan(matches, () -> randomList(1, 10, AbstractExpressionSerializationTests::randomChild)); + } + return new CIDRMatch(source, ipField, matches); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefixSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefixSerializationTests.java new file mode 100644 index 0000000000000..8393dad31b2e2 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefixSerializationTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefixSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefixSerializationTests.java new file mode 100644 index 0000000000000..8393dad31b2e2 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/ip/IpPrefixSerializationTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.ip; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class IpPrefixSerializationTests extends AbstractExpressionSerializationTests<IpPrefix> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected IpPrefix createTestInstance() { + Source source = randomSource(); + Expression ipField = randomChild(); + Expression prefixLengthV4Field = randomChild(); + Expression prefixLengthV6Field = randomChild(); + return new IpPrefix(source, ipField, prefixLengthV4Field, prefixLengthV6Field); + } + + @Override + protected IpPrefix mutateInstance(IpPrefix instance) throws IOException { + Source source = instance.source(); + Expression ipField = instance.ipField(); + Expression prefixLengthV4Field = instance.prefixLengthV4Field(); + Expression prefixLengthV6Field = instance.prefixLengthV6Field(); + switch (between(0, 2)) { + case 0 -> ipField = randomValueOtherThan(ipField, AbstractExpressionSerializationTests::randomChild); + case 1 -> prefixLengthV4Field = randomValueOtherThan(prefixLengthV4Field, AbstractExpressionSerializationTests::randomChild); + case 2 -> prefixLengthV6Field = randomValueOtherThan(prefixLengthV6Field, AbstractExpressionSerializationTests::randomChild); + } + return new IpPrefix(source, ipField, prefixLengthV4Field, prefixLengthV6Field); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Atan2SerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Atan2SerializationTests.java new file mode 100644 index 0000000000000..11986adf2bc24 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Atan2SerializationTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.AbstractUnaryScalarSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class Atan2SerializationTests extends AbstractExpressionSerializationTests<Atan2> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Atan2 createTestInstance() { + Source source = randomSource(); + Expression y = randomChild(); + Expression x = randomChild(); + return new Atan2(source, y, x); + } + + @Override + protected Atan2 mutateInstance(Atan2 instance) throws IOException { + Source source = instance.source(); + Expression y = instance.y(); + Expression x = instance.x(); + if (randomBoolean()) { + y = randomValueOtherThan(y, AbstractUnaryScalarSerializationTests::randomChild); + } else { + x = randomValueOtherThan(x, AbstractUnaryScalarSerializationTests::randomChild); + } + return new Atan2(source, y, x); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ESerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ESerializationTests.java new file mode 100644 index 0000000000000..ff8f1563df94e --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ESerializationTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class ESerializationTests extends AbstractExpressionSerializationTests<E> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected E createTestInstance() { + return new E(randomSource()); + } + + @Override + protected E mutateInstance(E instance) throws IOException { + return null; + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +}
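E above (and Pi and Tau below) is a niladic constant: the node carries nothing but a Source, so there is no field a mutation could change. Returning null from mutateInstance presumably tells the harness to skip the inequality half of the check, and alwaysEmptySource() signals that Source is not round-tripped for these nodes. A sketch of how a test base might honor that convention (an assumption about the harness, not its actual code):

    // hypothetical branch inside the base test
    T mutated = mutateInstance(original);
    if (mutated != null) {
        assertNotEquals(original, mutated); // a real mutation must change equality
    }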
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogSerializationTests.java new file mode 100644 index 0000000000000..bb33516900dd7 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class LogSerializationTests extends AbstractExpressionSerializationTests<Log> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Log createTestInstance() { + Source source = randomSource(); + Expression value = randomChild(); + Expression base = randomBoolean() ? null : randomChild(); + return new Log(source, value, base); + } + + @Override + protected Log mutateInstance(Log instance) throws IOException { + Source source = instance.source(); + Expression value = instance.value(); + Expression base = instance.base(); + if (randomBoolean()) { + value = randomValueOtherThan(value, AbstractExpressionSerializationTests::randomChild); + } else { + base = randomValueOtherThan(base, () -> randomBoolean() ? null : randomChild()); + } + return new Log(source, value, base); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PiSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PiSerializationTests.java new file mode 100644 index 0000000000000..4768ab292be10 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PiSerializationTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class PiSerializationTests extends AbstractExpressionSerializationTests<Pi> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Pi createTestInstance() { + return new Pi(randomSource()); + } + + @Override + protected Pi mutateInstance(Pi instance) throws IOException { + return null; + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowSerializationTests.java new file mode 100644 index 0000000000000..b47ec608cccab --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class PowSerializationTests extends AbstractExpressionSerializationTests<Pow> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Pow createTestInstance() { + Source source = randomSource(); + Expression base = randomChild(); + Expression exponent = randomChild(); + return new Pow(source, base, exponent); + } + + @Override + protected Pow mutateInstance(Pow instance) throws IOException { + Source source = instance.source(); + Expression base = instance.base(); + Expression exponent = instance.exponent(); + if (randomBoolean()) { + base = randomValueOtherThan(base, AbstractExpressionSerializationTests::randomChild); + } else { + exponent = randomValueOtherThan(exponent, AbstractExpressionSerializationTests::randomChild); + } + return new Pow(source, base, exponent); + } +}
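Log above and Round below both take an optional second argument, so createTestInstance leaves it null about half the time, and mutateInstance may flip it between null and a real child; that exercises the null branch of the wire format. A hypothetical writeTo/read pair for such an optional child, assuming Expression is a NamedWriteable and glossing over Source handling (illustrative only, not the actual Round implementation):

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeNamedWriteable(field);             // required child
        out.writeOptionalNamedWriteable(decimals);  // nullable child: one boolean, then maybe a value
    }

    static Round readFrom(StreamInput in) throws IOException {
        return new Round(Source.EMPTY,
            in.readNamedWriteable(Expression.class),
            in.readOptionalNamedWriteable(Expression.class));
    }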
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/RoundSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/RoundSerializationTests.java new file mode 100644 index 0000000000000..8146aea8d5c9f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/RoundSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class RoundSerializationTests extends AbstractExpressionSerializationTests<Round> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Round createTestInstance() { + Source source = randomSource(); + Expression field = randomChild(); + Expression decimals = randomBoolean() ? null : randomChild(); + return new Round(source, field, decimals); + } + + @Override + protected Round mutateInstance(Round instance) throws IOException { + Source source = instance.source(); + Expression field = instance.field(); + Expression decimals = instance.decimals(); + if (randomBoolean()) { + field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + } else { + decimals = randomValueOtherThan(decimals, () -> randomBoolean() ? null : randomChild()); + } + return new Round(source, field, decimals); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SignumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SignumSerializationTests.java new file mode 100644 index 0000000000000..98738aa8c64f6 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SignumSerializationTests.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractUnaryScalarSerializationTests; + +public class SignumSerializationTests extends AbstractUnaryScalarSerializationTests<Signum> { + @Override + protected Signum create(Source source, Expression child) { + return new Signum(source, child); + } +}
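SignumSerializationTests shows the third pattern in this change: for single-child functions, AbstractUnaryScalarSerializationTests leaves only a factory method to implement. That base class is not shown in this diff; its shape is presumably along these lines (a guess for orientation, assuming UnaryScalarFunction exposes field() and its own getNamedWriteables(), both of which the StX/StY tests below rely on):

    public abstract class AbstractUnaryScalarSerializationTests<T extends UnaryScalarFunction> extends AbstractExpressionSerializationTests<T> {
        protected abstract T create(Source source, Expression child);

        @Override
        protected final List<NamedWriteableRegistry.Entry> getNamedWriteables() {
            return UnaryScalarFunction.getNamedWriteables();
        }

        @Override
        protected final T createTestInstance() {
            return create(randomSource(), randomChild());
        }

        @Override
        protected final T mutateInstance(T instance) throws IOException {
            // swap the only child for a different random one
            Expression child = randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild);
            return create(instance.source(), child);
        }
    }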
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/TauSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/TauSerializationTests.java new file mode 100644 index 0000000000000..3320dcf0a180c --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/TauSerializationTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.math; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class TauSerializationTests extends AbstractExpressionSerializationTests<Tau> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Tau createTestInstance() { + return new Tau(randomSource()); + } + + @Override + protected Tau mutateInstance(Tau instance) throws IOException { + return null; + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/AbstractBinarySpatialFunctionSerializationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/AbstractBinarySpatialFunctionSerializationTestCase.java new file mode 100644 index 0000000000000..d304c474feac3 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/AbstractBinarySpatialFunctionSerializationTestCase.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; +import java.util.List; + +public abstract class AbstractBinarySpatialFunctionSerializationTestCase<T extends BinarySpatialFunction> extends + AbstractExpressionSerializationTests<T> { + + protected abstract T build(Source source, Expression left, Expression right); + + @Override + protected final List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return BinarySpatialFunction.getNamedWriteables(); + } + + @Override + protected final T createTestInstance() { + Source source = randomSource(); + Expression left = randomChild(); + Expression right = randomChild(); + return build(source, left, right); + } + + @Override + protected final T mutateInstance(T instance) throws IOException { + Source source = instance.source(); + Expression left = instance.left(); + Expression right = instance.right(); + if (randomBoolean()) { + left = randomValueOtherThan(left, AbstractExpressionSerializationTests::randomChild); + } else { + right = randomValueOtherThan(right, AbstractExpressionSerializationTests::randomChild); + } + return build(source, left, right); + } + + @Override + protected final boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContainsSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContainsSerializationTests.java new file mode 100644 index 0000000000000..5c707f54ac9d0 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContainsSerializationTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; + +public class SpatialContainsSerializationTests extends AbstractBinarySpatialFunctionSerializationTestCase<SpatialContains> { + @Override + protected SpatialContains build(Source source, Expression left, Expression right) { + return new SpatialContains(source, left, right); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjointSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjointSerializationTests.java new file mode 100644 index 0000000000000..a16e7ffdb2d17 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjointSerializationTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; + +public class SpatialDisjointSerializationTests extends AbstractBinarySpatialFunctionSerializationTestCase<SpatialDisjoint> { + @Override + protected SpatialDisjoint build(Source source, Expression left, Expression right) { + return new SpatialDisjoint(source, left, right); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersectsSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersectsSerializationTests.java new file mode 100644 index 0000000000000..35a85926101f5 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersectsSerializationTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; + +public class SpatialIntersectsSerializationTests extends AbstractBinarySpatialFunctionSerializationTestCase<SpatialIntersects> { + @Override + protected SpatialIntersects build(Source source, Expression left, Expression right) { + return new SpatialIntersects(source, left, right); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithinSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithinSerializationTests.java new file mode 100644 index 0000000000000..74fe752b59eaf --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithinSerializationTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; + +public class SpatialWithinSerializationTests extends AbstractBinarySpatialFunctionSerializationTestCase<SpatialWithin> { + @Override + protected SpatialWithin build(Source source, Expression left, Expression right) { + return new SpatialWithin(source, left, right); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistanceSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistanceSerializationTests.java new file mode 100644 index 0000000000000..c9cdccdaf0ca2 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistanceSerializationTests.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; + +public class StDistanceSerializationTests extends AbstractBinarySpatialFunctionSerializationTestCase<StDistance> { + @Override + protected StDistance build(Source source, Expression left, Expression right) { + return new StDistance(source, left, right); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StXSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StXSerializationTests.java new file mode 100644 index 0000000000000..56ddd039cb87f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StXSerializationTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class StXSerializationTests extends AbstractExpressionSerializationTests<StX> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return UnaryScalarFunction.getNamedWriteables(); + } + + @Override + protected StX createTestInstance() { + return new StX(randomSource(), randomChild()); + } + + @Override + protected StX mutateInstance(StX instance) throws IOException { + return new StX(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StYSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StYSerializationTests.java new file mode 100644 index 0000000000000..f44b49d38d34c --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StYSerializationTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.spatial; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class StYSerializationTests extends AbstractExpressionSerializationTests<StY> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return UnaryScalarFunction.getNamedWriteables(); + } + + @Override + protected StY createTestInstance() { + return new StY(randomSource(), randomChild()); + } + + @Override + protected StY mutateInstance(StY instance) throws IOException { + return new StY(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithSerializationTests.java new file mode 100644 index 0000000000000..2f734d585ab52 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/EndsWithSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class EndsWithSerializationTests extends AbstractExpressionSerializationTests<EndsWith> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected EndsWith createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression suffix = randomChild(); + return new EndsWith(source, str, suffix); + } + + @Override + protected EndsWith mutateInstance(EndsWith instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression suffix = instance.suffix(); + if (randomBoolean()) { + str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + } else { + suffix = randomValueOtherThan(suffix, AbstractExpressionSerializationTests::randomChild); + } + return new EndsWith(source, str, suffix); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LTrimSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LTrimSerializationTests.java new file mode 100644 index 0000000000000..e3cac6caf130d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LTrimSerializationTests.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractUnaryScalarSerializationTests; + +public class LTrimSerializationTests extends AbstractUnaryScalarSerializationTests<LTrim> { + @Override + protected LTrim create(Source source, Expression child) { + return new LTrim(source, child); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LeftSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LeftSerializationTests.java new file mode 100644 index 0000000000000..2162044d2e29f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LeftSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class LeftSerializationTests extends AbstractExpressionSerializationTests<Left> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Left createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression length = randomChild(); + return new Left(source, str, length); + } + + @Override + protected Left mutateInstance(Left instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression length = instance.length(); + if (randomBoolean()) { + str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + } else { + length = randomValueOtherThan(length, AbstractExpressionSerializationTests::randomChild); + } + return new Left(source, str, length); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LengthSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LengthSerializationTests.java new file mode 100644 index 0000000000000..07b8cb722096b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LengthSerializationTests.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractUnaryScalarSerializationTests; + +public class LengthSerializationTests extends AbstractUnaryScalarSerializationTests<Length> { + @Override + protected Length create(Source source, Expression child) { + return new Length(source, child); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LocateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LocateSerializationTests.java new file mode 100644 index 0000000000000..d705b4a6167ec --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/LocateSerializationTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class LocateSerializationTests extends AbstractExpressionSerializationTests<Locate> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Locate createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression substr = randomChild(); + Expression start = randomChild(); + return new Locate(source, str, substr, start); + } + + @Override + protected Locate mutateInstance(Locate instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression substr = instance.substr(); + Expression start = instance.start(); + switch (between(0, 2)) { + case 0 -> str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + case 1 -> substr = randomValueOtherThan(substr, AbstractExpressionSerializationTests::randomChild); + case 2 -> start = randomValueOtherThan(start, AbstractExpressionSerializationTests::randomChild); + } + return new Locate(source, str, substr, start); + } +}
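Locate above, and Replace and Substring below, have three children, so mutateInstance switches on between(0, 2) to change exactly one of them; randomValueOtherThan guarantees the replacement really differs, so every mutant is unequal to the original in precisely one argument. The same idea as a generic helper (hypothetical, not part of the test framework; assumes the ESTestCase statics are in scope):

    static List<Expression> mutateOneChild(List<Expression> args) {
        List<Expression> mutated = new ArrayList<>(args);
        int i = between(0, args.size() - 1); // pick one position uniformly
        mutated.set(i, randomValueOtherThan(args.get(i), AbstractExpressionSerializationTests::randomChild));
        return mutated;
    }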
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeSerializationTests.java new file mode 100644 index 0000000000000..6be60b7163e3b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeSerializationTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class RLikeSerializationTests extends AbstractExpressionSerializationTests<RLike> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return UnaryScalarFunction.getNamedWriteables(); + } + + @Override + protected RLike createTestInstance() { + Source source = randomSource(); + Expression child = randomChild(); + RLikePattern pattern = new RLikePattern(randomAlphaOfLength(4)); + return new RLike(source, child, pattern); + } + + @Override + protected RLike mutateInstance(RLike instance) throws IOException { + Source source = instance.source(); + Expression child = instance.field(); + RLikePattern pattern = instance.pattern(); + if (randomBoolean()) { + child = randomValueOtherThan(child, AbstractExpressionSerializationTests::randomChild); + } else { + pattern = randomValueOtherThan(pattern, () -> new RLikePattern(randomAlphaOfLength(4))); + } + return new RLike(source, child, pattern); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java index e1bcc519840be..e673be2ad5290 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java @@ -38,7 +38,7 @@ public RLikeTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCase @ParametersFactory public static Iterable<Object[]> parameters() { return parameters(str -> { - for (String syntax : new String[] { "\\", ".", "?", "+", "*", "|", "{", "}", "[", "]", "(", ")", "\"", "<", ">", "#" }) { + for (String syntax : new String[] { "\\", ".", "?", "+", "*", "|", "{", "}", "[", "]", "(", ")", "\"", "<", ">", "#", "&" }) { str = str.replace(syntax, "\\" + syntax); } return str;
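The RLikeTests hunk above adds '&' to the set of escaped metacharacters: in Lucene's regular-expression syntax, which is where RLIKE patterns end up as far as these tests are concerned, '&' can optionally act as an intersection operator, so randomly generated strings that should match literally must escape it like the other specials. The escaping loop, shown standalone:

    static String escapeForRLike(String str) {
        // prefix every metacharacter (now including '&') with a backslash
        for (String syntax : new String[] { "\\", ".", "?", "+", "*", "|", "{", "}", "[", "]", "(", ")", "\"", "<", ">", "#", "&" }) {
            str = str.replace(syntax, "\\" + syntax);
        }
        return str;
    }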
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RTrimSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RTrimSerializationTests.java new file mode 100644 index 0000000000000..e52be87c41af0 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RTrimSerializationTests.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractUnaryScalarSerializationTests; + +public class RTrimSerializationTests extends AbstractUnaryScalarSerializationTests<RTrim> { + @Override + protected RTrim create(Source source, Expression child) { + return new RTrim(source, child); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RepeatSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RepeatSerializationTests.java new file mode 100644 index 0000000000000..d246b28ddb0d6 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RepeatSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class RepeatSerializationTests extends AbstractExpressionSerializationTests<Repeat> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Repeat createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression number = randomChild(); + return new Repeat(source, str, number); + } + + @Override + protected Repeat mutateInstance(Repeat instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression number = instance.number(); + if (randomBoolean()) { + str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + } else { + number = randomValueOtherThan(number, AbstractExpressionSerializationTests::randomChild); + } + return new Repeat(source, str, number); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ReplaceSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ReplaceSerializationTests.java new file mode 100644 index 0000000000000..555f210e6b0c0 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ReplaceSerializationTests.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class ReplaceSerializationTests extends AbstractExpressionSerializationTests<Replace> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Replace createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression regex = randomChild(); + Expression newStr = randomChild(); + return new Replace(source, str, regex, newStr); + } + + @Override + protected Replace mutateInstance(Replace instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression regex = instance.regex(); + Expression newStr = instance.newStr(); + switch (between(0, 2)) { + case 0 -> str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + case 1 -> regex = randomValueOtherThan(regex, AbstractExpressionSerializationTests::randomChild); + case 2 -> newStr = randomValueOtherThan(newStr, AbstractExpressionSerializationTests::randomChild); + } + return new Replace(source, str, regex, newStr); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RightSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RightSerializationTests.java new file mode 100644 index 0000000000000..17ab41cc467db --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RightSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class RightSerializationTests extends AbstractExpressionSerializationTests<Right> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Right createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression length = randomChild(); + return new Right(source, str, length); + } + + @Override + protected Right mutateInstance(Right instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression length = instance.length(); + if (randomBoolean()) { + str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + } else { + length = randomValueOtherThan(length, AbstractExpressionSerializationTests::randomChild); + } + return new Right(source, str, length); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/SplitSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/SplitSerializationTests.java new file mode 100644 index 0000000000000..4e38ea9a57d7f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/SplitSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class SplitSerializationTests extends AbstractExpressionSerializationTests<Split> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Split createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression delim = randomChild(); + return new Split(source, str, delim); + } + + @Override + protected Split mutateInstance(Split instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression delim = instance.delim(); + if (randomBoolean()) { + str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + } else { + delim = randomValueOtherThan(delim, AbstractExpressionSerializationTests::randomChild); + } + return new Split(source, str, delim); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithSerializationTests.java new file mode 100644 index 0000000000000..4cff6f3441510 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/StartsWithSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class StartsWithSerializationTests extends AbstractExpressionSerializationTests<StartsWith> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected StartsWith createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression prefix = randomChild(); + return new StartsWith(source, str, prefix); + } + + @Override + protected StartsWith mutateInstance(StartsWith instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression prefix = instance.prefix(); + if (randomBoolean()) { + str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + } else { + prefix = randomValueOtherThan(prefix, AbstractExpressionSerializationTests::randomChild); + } + return new StartsWith(source, str, prefix); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/SubstringSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/SubstringSerializationTests.java new file mode 100644 index 0000000000000..d5f8fe498902d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/SubstringSerializationTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class SubstringSerializationTests extends AbstractExpressionSerializationTests<Substring> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected Substring createTestInstance() { + Source source = randomSource(); + Expression str = randomChild(); + Expression start = randomChild(); + Expression length = randomChild(); + return new Substring(source, str, start, length); + } + + @Override + protected Substring mutateInstance(Substring instance) throws IOException { + Source source = instance.source(); + Expression str = instance.str(); + Expression start = instance.start(); + Expression length = instance.length(); + switch (between(0, 2)) { + case 0 -> str = randomValueOtherThan(str, AbstractExpressionSerializationTests::randomChild); + case 1 -> start = randomValueOtherThan(start, AbstractExpressionSerializationTests::randomChild); + case 2 -> length = randomValueOtherThan(length, AbstractExpressionSerializationTests::randomChild); + } + return new Substring(source, str, start, length); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/TrimSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/TrimSerializationTests.java new file mode 100644 index 0000000000000..a49e07fd7065c --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/TrimSerializationTests.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractUnaryScalarSerializationTests; + +public class TrimSerializationTests extends AbstractUnaryScalarSerializationTests<Trim> { + @Override + protected Trim create(Source source, Expression child) { + return new Trim(source, child); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeSerializationTests.java new file mode 100644 index 0000000000000..99b566b1e8584 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeSerializationTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.string; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.UnaryScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class WildcardLikeSerializationTests extends AbstractExpressionSerializationTests<WildcardLike> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return UnaryScalarFunction.getNamedWriteables(); + } + + @Override + protected WildcardLike createTestInstance() { + Source source = randomSource(); + Expression child = randomChild(); + WildcardPattern pattern = new WildcardPattern(randomAlphaOfLength(4)); + return new WildcardLike(source, child, pattern); + } + + @Override + protected WildcardLike mutateInstance(WildcardLike instance) throws IOException { + Source source = instance.source(); + Expression child = instance.field(); + WildcardPattern pattern = instance.pattern(); + if (randomBoolean()) { + child = randomValueOtherThan(child, AbstractExpressionSerializationTests::randomChild); + } else { + pattern = randomValueOtherThan(pattern, () -> new WildcardPattern(randomAlphaOfLength(4))); + } + return new WildcardLike(source, child, pattern); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InSerializationTests.java new file mode 100644 index 0000000000000..a92921050ab18 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InSerializationTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; + +import java.io.IOException; +import java.util.List; + +public class InSerializationTests extends AbstractExpressionSerializationTests<In> { + @Override + protected List<NamedWriteableRegistry.Entry> getNamedWriteables() { + return EsqlScalarFunction.getNamedWriteables(); + } + + @Override + protected In createTestInstance() { + Source source = randomSource(); + Expression value = randomChild(); + List<Expression> list = randomList(10, AbstractExpressionSerializationTests::randomChild); + return new In(source, value, list); + } + + @Override + protected In mutateInstance(In instance) throws IOException { + Source source = instance.source(); + Expression value = instance.value(); + List<Expression> list = instance.list(); + if (randomBoolean()) { + value = randomValueOtherThan(value, AbstractExpressionSerializationTests::randomChild); + } else { + list = randomValueOtherThan(list, () -> randomList(10, AbstractExpressionSerializationTests::randomChild)); + } + return new In(source, value, list); + } +}
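The PlanNamedTypesTests hunk below is the flip side of the new files above: expressions that now round-trip as NamedWriteables with dedicated serialization tests (the aggregate functions, Substring, StartsWith, Round, Pow, and several EsField variants) lose their hand-written PlanNamedTypes tests and imports. The tests that survive keep the older PlanStreamOutput shape, along these lines (adapted from the remaining testArithmeticOperationSimple and the removed testAggFunctionSimple):

    var orig = new Add(Source.EMPTY, field("foo", DataType.LONG), field("bar", DataType.LONG));
    BytesStreamOutput bso = new BytesStreamOutput();
    PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null);
    out.writeNamed(Expression.class, orig);
    var deser = (Add) planStreamInput(bso).readNamed(Expression.class);
    // the lambda ignores its argument and hands back the deserialized copy
    EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser);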
org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Min; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile; -import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid; -import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow; -import org.elasticsearch.xpack.esql.expression.function.scalar.math.Round; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.StartsWith; -import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; @@ -243,19 +223,6 @@ public void testBinComparison() { Stream.generate(PlanNamedTypesTests::randomBinaryComparison).limit(100).forEach(obj -> assertNamedType(Expression.class, obj)); } - public void testAggFunctionSimple() throws IOException { - var orig = new Avg(Source.EMPTY, field("foo_val", DataType.DOUBLE)); - BytesStreamOutput bso = new BytesStreamOutput(); - PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null); - out.writeNamed(AggregateFunction.class, orig); - var deser = (Avg) planStreamInput(bso).readNamed(AggregateFunction.class); - EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); - } - - public void testAggFunction() { - Stream.generate(PlanNamedTypesTests::randomAggFunction).limit(100).forEach(obj -> assertNamedType(AggregateFunction.class, obj)); - } - public void testArithmeticOperationSimple() throws IOException { var orig = new Add(Source.EMPTY, field("foo", DataType.LONG), field("bar", DataType.LONG)); BytesStreamOutput bso = new BytesStreamOutput(); @@ -269,42 +236,6 @@ public void testArithmeticOperation() { Stream.generate(PlanNamedTypesTests::randomArithmeticOperation).limit(100).forEach(obj -> assertNamedType(Expression.class, obj)); } - public void testSubStringSimple() throws IOException { - var orig = new Substring(Source.EMPTY, field("foo", DataType.KEYWORD), new Literal(Source.EMPTY, 1, DataType.INTEGER), null); - BytesStreamOutput bso = new BytesStreamOutput(); - PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null); - PlanNamedTypes.writeSubstring(out, orig); - var deser = PlanNamedTypes.readSubstring(planStreamInput(bso)); - EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); - } - - public void testStartsWithSimple() throws IOException { - var orig = new StartsWith(Source.EMPTY, field("foo", DataType.KEYWORD), new Literal(Source.EMPTY, "fo", DataType.KEYWORD)); - BytesStreamOutput bso = new BytesStreamOutput(); - PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null); - PlanNamedTypes.writeStartsWith(out, orig); - var deser = PlanNamedTypes.readStartsWith(planStreamInput(bso)); - EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); - } - - public void testRoundSimple() throws IOException { - var orig = new Round(Source.EMPTY, field("value", DataType.DOUBLE), new Literal(Source.EMPTY, 1, DataType.INTEGER)); - BytesStreamOutput bso = new BytesStreamOutput(); - PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null); - PlanNamedTypes.writeRound(out, orig); - var deser = 
PlanNamedTypes.readRound(planStreamInput(bso));
-        EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser);
-    }
-
-    public void testPowSimple() throws IOException {
-        var orig = new Pow(Source.EMPTY, field("value", DataType.DOUBLE), new Literal(Source.EMPTY, 1, DataType.INTEGER));
-        BytesStreamOutput bso = new BytesStreamOutput();
-        PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null);
-        PlanNamedTypes.writePow(out, orig);
-        var deser = PlanNamedTypes.readPow(planStreamInput(bso));
-        EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser);
-    }
-
     public void testFieldSortSimple() throws IOException {
         var orig = new EsQueryExec.FieldSort(field("val", DataType.LONG), Order.OrderDirection.ASC, Order.NullsPosition.FIRST);
         BytesStreamOutput bso = new BytesStreamOutput();
@@ -394,16 +325,6 @@ static EsIndex randomEsIndex() {
         );
     }
 
-    static UnsupportedAttribute randomUnsupportedAttribute() {
-        return new UnsupportedAttribute(
-            Source.EMPTY,
-            randomAlphaOfLength(randomIntBetween(1, 25)), // name
-            randomUnsupportedEsField(), // field
-            randomStringOrNull(), // customMessage
-            nameIdOrNull()
-        );
-    }
-
     static FieldAttribute randomFieldAttributeOrNull() {
         return randomBoolean() ? randomFieldAttribute() : null;
     }
@@ -433,22 +354,6 @@ static KeywordEsField randomKeywordEsField() {
         );
     }
 
-    static TextEsField randomTextEsField() {
-        return new TextEsField(
-            randomAlphaOfLength(randomIntBetween(1, 25)), // name
-            randomProperties(),
-            randomBoolean(), // hasDocValues
-            randomBoolean() // alias
-        );
-    }
-
-    static InvalidMappedField randomInvalidMappedField() {
-        return new InvalidMappedField(
-            randomAlphaOfLength(randomIntBetween(1, 25)), // name
-            randomAlphaOfLength(randomIntBetween(1, 25)) // error message
-        );
-    }
-
     static EsqlBinaryComparison randomBinaryComparison() {
         int v = randomIntBetween(0, 5);
         var left = field(randomName(), randomDataType());
@@ -464,25 +369,6 @@ static EsqlBinaryComparison randomBinaryComparison() {
         };
     }
 
-    static AggregateFunction randomAggFunction() {
-        int v = randomIntBetween(0, 8);
-        var field = field(randomName(), randomDataType());
-        var right = field(randomName(), randomDataType());
-        return switch (v) {
-            case 0 -> new Avg(Source.EMPTY, field);
-            case 1 -> new Count(Source.EMPTY, field);
-            case 2 -> new Sum(Source.EMPTY, field);
-            case 3 -> new Min(Source.EMPTY, field);
-            case 4 -> new Max(Source.EMPTY, field);
-            case 5 -> new Median(Source.EMPTY, field);
-            case 6 -> new MedianAbsoluteDeviation(Source.EMPTY, field);
-            case 7 -> new CountDistinct(Source.EMPTY, field, right);
-            case 8 -> new Percentile(Source.EMPTY, field, right);
-            case 9 -> new SpatialCentroid(Source.EMPTY, field);
-            default -> throw new AssertionError(v);
-        };
-    }
-
     static ArithmeticOperation randomArithmeticOperation() {
         int v = randomIntBetween(0, 4);
         var left = field(randomName(), randomDataType());
@@ -525,15 +411,6 @@ static EsField randomEsField(int depth) {
         );
     }
 
-    static UnsupportedEsField randomUnsupportedEsField() {
-        return new UnsupportedEsField(
-            randomAlphaOfLength(randomIntBetween(1, 25)), // name
-            randomAlphaOfLength(randomIntBetween(1, 25)), // originalType
-            randomAlphaOfLength(randomIntBetween(1, 25)), // inherited
-            randomProperties()
-        );
-    }
-
     static Map<String, EsField> randomProperties() {
         return randomProperties(0);
     }
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index 74bdcf824ba80..3ce778d038875 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.esql.optimizer;
 
+import org.elasticsearch.Build;
 import org.elasticsearch.common.logging.LoggerMessageFormat;
 import org.elasticsearch.common.lucene.BytesRefs;
 import org.elasticsearch.compute.aggregation.QuantileStates;
@@ -62,8 +63,11 @@
 import org.elasticsearch.xpack.esql.expression.function.aggregate.MedianAbsoluteDeviation;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
@@ -205,6 +209,9 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
     private static EnrichResolution enrichResolution;
     private static final LiteralsOnTheRight LITERALS_ON_THE_RIGHT = new LiteralsOnTheRight();
 
+    private static Map<String, EsField> metricMapping;
+    private static Analyzer metricsAnalyzer;
+
     private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer {
         static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
 
@@ -260,6 +267,13 @@ public static void init() {
             new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResultExtra, enrichResolution),
             TEST_VERIFIER
         );
+
+        metricMapping = loadMapping("k8s-mappings.json");
+        var metricsIndex = IndexResolution.valid(new EsIndex("k8s", metricMapping, Set.of("k8s")));
+        metricsAnalyzer = new Analyzer(
+            new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), metricsIndex, enrichResolution),
+            TEST_VERIFIER
+        );
     }
 
     public void testEmptyProjections() {
@@ -802,7 +816,13 @@ public void testSelectivelyPushDownFilterPastFunctionAgg() {
         Filter fa = new Filter(EMPTY, relation, conditionA);
         // invalid aggregate but that's fine cause its properties are not used by this rule
-        Aggregate aggregate = new Aggregate(EMPTY, fa, singletonList(getFieldAttribute("b")), emptyList());
+        Aggregate aggregate = new Aggregate(
+            EMPTY,
+            fa,
+            Aggregate.AggregateType.STANDARD,
+            singletonList(getFieldAttribute("b")),
+            emptyList()
+        );
         Filter fb = new Filter(EMPTY, aggregate, new And(EMPTY, aggregateCondition, conditionB));
 
         // expected
@@ -811,6 +831,7 @@
             new Aggregate(
                 EMPTY,
                 new Filter(EMPTY, relation, new And(EMPTY, conditionA, conditionB)),
+                Aggregate.AggregateType.STANDARD,
                 singletonList(getFieldAttribute("b")),
                 emptyList()
             ),
@@ -5121,6 +5142,212 @@ public void testLookupStats() {
         );
     }
 
+    public void testTranslateMetricsWithoutGrouping() {
+        assumeTrue("requires snapshot builds", Build.current().isSnapshot());
+        var query = "METRICS k8s max(rate(network.total_bytes_in))";
+        var plan =
logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + Limit limit = as(plan, Limit.class); + Aggregate finalAggs = as(limit.child(), Aggregate.class); + Aggregate aggsByTsid = as(finalAggs.child(), Aggregate.class); + as(aggsByTsid.child(), EsRelation.class); + + assertThat(finalAggs.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + assertThat(finalAggs.aggregates(), hasSize(1)); + Max max = as(Alias.unwrap(finalAggs.aggregates().get(0)), Max.class); + assertThat(Expressions.attribute(max.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + assertThat(finalAggs.groupings(), empty()); + + assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS)); + assertThat(aggsByTsid.aggregates(), hasSize(1)); // _tsid is dropped + Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); + assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); + } + + public void testTranslateMetricsGroupedByOneDimension() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + var query = "METRICS k8s sum(rate(network.total_bytes_in)) BY cluster | SORT cluster | LIMIT 10"; + var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + TopN topN = as(plan, TopN.class); + Aggregate aggsByCluster = as(topN.child(), Aggregate.class); + assertThat(aggsByCluster.aggregates(), hasSize(2)); + Aggregate aggsByTsid = as(aggsByCluster.child(), Aggregate.class); + assertThat(aggsByTsid.aggregates(), hasSize(2)); // _tsid is dropped + as(aggsByTsid.child(), EsRelation.class); + + assertThat(aggsByCluster.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + Sum sum = as(Alias.unwrap(aggsByCluster.aggregates().get(0)), Sum.class); + assertThat(Expressions.attribute(sum.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + assertThat(aggsByCluster.groupings(), hasSize(1)); + assertThat(Expressions.attribute(aggsByCluster.groupings().get(0)).id(), equalTo(aggsByTsid.aggregates().get(1).id())); + + assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS)); + Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); + assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); + Values values = as(Alias.unwrap(aggsByTsid.aggregates().get(1)), Values.class); + assertThat(Expressions.attribute(values.field()).name(), equalTo("cluster")); + } + + public void testTranslateMetricsGroupedByTwoDimension() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + var query = "METRICS k8s avg(rate(network.total_bytes_in)) BY cluster, pod"; + var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + Project project = as(plan, Project.class); + Eval eval = as(project.child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + Limit limit = as(eval.child(), Limit.class); + Aggregate finalAggs = as(limit.child(), Aggregate.class); + assertThat(finalAggs.aggregates(), hasSize(4)); + Aggregate aggsByTsid = as(finalAggs.child(), Aggregate.class); + assertThat(aggsByTsid.aggregates(), hasSize(3)); // _tsid is dropped + as(aggsByTsid.child(), EsRelation.class); + + Div div = as(Alias.unwrap(eval.fields().get(0)), Div.class); + assertThat(Expressions.attribute(div.left()).id(), equalTo(finalAggs.aggregates().get(0).id())); + assertThat(Expressions.attribute(div.right()).id(), equalTo(finalAggs.aggregates().get(1).id())); + + 
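// avg(rate(...)) has been decomposed: the STANDARD layer below aggregates Sum and Count
+        // over the per-_tsid Rate values, and the Div checked above recombines them into the average.
+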
assertThat(finalAggs.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + Sum sum = as(Alias.unwrap(finalAggs.aggregates().get(0)), Sum.class); + assertThat(Expressions.attribute(sum.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + Count count = as(Alias.unwrap(finalAggs.aggregates().get(1)), Count.class); + assertThat(Expressions.attribute(count.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + assertThat(finalAggs.groupings(), hasSize(2)); + assertThat(Expressions.attribute(finalAggs.groupings().get(0)).id(), equalTo(aggsByTsid.aggregates().get(1).id())); + assertThat(Expressions.attribute(finalAggs.groupings().get(1)).id(), equalTo(aggsByTsid.aggregates().get(2).id())); + + assertThat(finalAggs.groupings(), hasSize(2)); + + assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS)); + assertThat(aggsByTsid.aggregates(), hasSize(3)); // rates, values(cluster), values(pod) + Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); + assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); + Values values1 = as(Alias.unwrap(aggsByTsid.aggregates().get(1)), Values.class); + assertThat(Expressions.attribute(values1.field()).name(), equalTo("cluster")); + Values values2 = as(Alias.unwrap(aggsByTsid.aggregates().get(2)), Values.class); + assertThat(Expressions.attribute(values2.field()).name(), equalTo("pod")); + } + + public void testTranslateMetricsGroupedByTimeBucket() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + var query = "METRICS k8s sum(rate(network.total_bytes_in)) BY bucket(@timestamp, 1h)"; + var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + Limit limit = as(plan, Limit.class); + Aggregate finalAgg = as(limit.child(), Aggregate.class); + assertThat(finalAgg.aggregates(), hasSize(2)); + Aggregate aggsByTsid = as(finalAgg.child(), Aggregate.class); + assertThat(aggsByTsid.aggregates(), hasSize(2)); // _tsid is dropped + Eval eval = as(aggsByTsid.child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + as(eval.child(), EsRelation.class); + + assertThat(finalAgg.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + Sum sum = as(Alias.unwrap(finalAgg.aggregates().get(0)), Sum.class); + assertThat(Expressions.attribute(sum.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + assertThat(finalAgg.groupings(), hasSize(1)); + assertThat(Expressions.attribute(finalAgg.groupings().get(0)).id(), equalTo(aggsByTsid.aggregates().get(1).id())); + + assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS)); + Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); + assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); + assertThat(Expressions.attribute(aggsByTsid.groupings().get(1)).id(), equalTo(eval.fields().get(0).id())); + Bucket bucket = as(Alias.unwrap(eval.fields().get(0)), Bucket.class); + assertThat(Expressions.attribute(bucket.field()).name(), equalTo("@timestamp")); + } + + public void testTranslateMetricsGroupedByTimeBucketAndDimensions() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + var query = """ + METRICS k8s avg(rate(network.total_bytes_in)) BY pod, bucket(@timestamp, 5 minute), cluster + | SORT cluster + | LIMIT 10 + """; + var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + Project project = as(plan, Project.class); + 
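// expected plan: Project <- TopN <- Eval(Div) <- Aggregate(STANDARD) <- Aggregate(METRICS, by _tsid) <- Eval(Bucket) <- EsRelation
+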
TopN topN = as(project.child(), TopN.class); + Eval eval = as(topN.child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + Div div = as(Alias.unwrap(eval.fields().get(0)), Div.class); + Aggregate finalAgg = as(eval.child(), Aggregate.class); + Aggregate aggsByTsid = as(finalAgg.child(), Aggregate.class); + Eval bucket = as(aggsByTsid.child(), Eval.class); + as(bucket.child(), EsRelation.class); + assertThat(Expressions.attribute(div.left()).id(), equalTo(finalAgg.aggregates().get(0).id())); + assertThat(Expressions.attribute(div.right()).id(), equalTo(finalAgg.aggregates().get(1).id())); + + assertThat(finalAgg.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + assertThat(finalAgg.aggregates(), hasSize(5)); // sum, count, pod, bucket, cluster + Sum sum = as(Alias.unwrap(finalAgg.aggregates().get(0)), Sum.class); + Count count = as(Alias.unwrap(finalAgg.aggregates().get(1)), Count.class); + assertThat(Expressions.attribute(sum.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + assertThat(Expressions.attribute(count.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + assertThat(finalAgg.groupings(), hasSize(3)); + assertThat(Expressions.attribute(finalAgg.groupings().get(0)).id(), equalTo(aggsByTsid.aggregates().get(1).id())); + + assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS)); + assertThat(aggsByTsid.aggregates(), hasSize(4)); // rate, values(pod), values(cluster), bucket + Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); + assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); + Values podValues = as(Alias.unwrap(aggsByTsid.aggregates().get(1)), Values.class); + assertThat(Expressions.attribute(podValues.field()).name(), equalTo("pod")); + Values clusterValues = as(Alias.unwrap(aggsByTsid.aggregates().get(3)), Values.class); + assertThat(Expressions.attribute(clusterValues.field()).name(), equalTo("cluster")); + } + + public void testAdjustMetricsRateBeforeFinalAgg() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + var query = """ + METRICS k8s avg(round(1.05 * rate(network.total_bytes_in))) BY bucket(@timestamp, 1 minute), cluster + | SORT cluster + | LIMIT 10 + """; + var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + Project project = as(plan, Project.class); + TopN topN = as(project.child(), TopN.class); + Eval evalDiv = as(topN.child(), Eval.class); + assertThat(evalDiv.fields(), hasSize(1)); + Div div = as(Alias.unwrap(evalDiv.fields().get(0)), Div.class); + + Aggregate finalAgg = as(evalDiv.child(), Aggregate.class); + assertThat(finalAgg.aggregates(), hasSize(4)); // sum, count, bucket, cluster + assertThat(finalAgg.groupings(), hasSize(2)); + + Eval evalRound = as(finalAgg.child(), Eval.class); + Round round = as(Alias.unwrap(evalRound.fields().get(0)), Round.class); + Mul mul = as(round.field(), Mul.class); + + Aggregate aggsByTsid = as(evalRound.child(), Aggregate.class); + assertThat(aggsByTsid.aggregates(), hasSize(3)); // rate, cluster, bucket + assertThat(aggsByTsid.groupings(), hasSize(2)); + + Eval evalBucket = as(aggsByTsid.child(), Eval.class); + assertThat(evalBucket.fields(), hasSize(1)); + Bucket bucket = as(Alias.unwrap(evalBucket.fields().get(0)), Bucket.class); + as(evalBucket.child(), EsRelation.class); + + assertThat(Expressions.attribute(div.left()).id(), equalTo(finalAgg.aggregates().get(0).id())); + assertThat(Expressions.attribute(div.right()).id(), 
equalTo(finalAgg.aggregates().get(1).id())); + + assertThat(finalAgg.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + + Sum sum = as(Alias.unwrap(finalAgg.aggregates().get(0)), Sum.class); + Count count = as(Alias.unwrap(finalAgg.aggregates().get(1)), Count.class); + assertThat(Expressions.attribute(sum.field()).id(), equalTo(evalRound.fields().get(0).id())); + assertThat(Expressions.attribute(count.field()).id(), equalTo(evalRound.fields().get(0).id())); + + assertThat( + Expressions.attribute(finalAgg.groupings().get(0)).id(), + equalTo(Expressions.attribute(aggsByTsid.groupings().get(1)).id()) + ); + assertThat(Expressions.attribute(finalAgg.groupings().get(1)).id(), equalTo(aggsByTsid.aggregates().get(1).id())); + + assertThat(Expressions.attribute(mul.left()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); + assertThat(mul.right().fold(), equalTo(1.05)); + assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS)); + Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); + assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); + Values values = as(Alias.unwrap(aggsByTsid.aggregates().get(1)), Values.class); + assertThat(Expressions.attribute(values.field()).name(), equalTo("cluster")); + } + private Literal nullOf(DataType dataType) { return new Literal(Source.EMPTY, null, dataType); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index c139b28f1fdaa..43e806b9a55cb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.Tuple; +import org.elasticsearch.geometry.Circle; import org.elasticsearch.geometry.Polygon; import org.elasticsearch.geometry.ShapeType; import org.elasticsearch.index.IndexMode; @@ -63,7 +64,9 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialIntersects; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialWithin; +import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.StDistance; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; @@ -3475,6 +3478,55 @@ public void testPushSpatialIntersectsShapeToSource() { } } + public void testPushSpatialDistanceToSource() { + for (String distanceFunction : new String[] { + "ST_DISTANCE(location, TO_GEOPOINT(\"POINT(12.565 55.673)\"))", + "ST_DISTANCE(TO_GEOPOINT(\"POINT(12.565 55.673)\"), location)" }) { + + for (String op : new String[] { "<", "<=", ">", ">=" }) { + var eq = op.contains("="); + var lt = op.contains("<"); + var predicate = lt ? 
distanceFunction + " " + op + " 600000" : "600000 " + op + " " + distanceFunction; + var query = "FROM airports | WHERE " + predicate + " AND scalerank > 1"; + var plan = this.physicalPlan(query, airports); + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var fragment = as(exchange.child(), FragmentExec.class); + var limit2 = as(fragment.fragment(), Limit.class); + var filter = as(limit2.child(), Filter.class); + var and = as(filter.condition(), And.class); + var comp = as(and.left(), EsqlBinaryComparison.class); + var expectedComp = eq ? LessThanOrEqual.class : LessThan.class; // normalized to less than + assertThat("filter contains expected binary comparison for " + predicate, comp, instanceOf(expectedComp)); + assertThat("filter contains ST_DISTANCE", comp.left(), instanceOf(StDistance.class)); + + var optimized = optimizedPlan(plan); + var topLimit = as(optimized, LimitExec.class); + exchange = as(topLimit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var fieldExtract = as(project.child(), FieldExtractExec.class); + var source = source(fieldExtract.child()); + // TODO: bring back SingleValueQuery once it can handle LeafShapeFieldData + // var condition = as(sv(source.query(), "location"), AbstractGeometryQueryBuilder.class); + var bool = as(source.query(), BoolQueryBuilder.class); + var rangeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SingleValueQuery.Builder).toList(); + assertThat("Expected one range query builder", rangeQueryBuilders.size(), equalTo(1)); + assertThat(((SingleValueQuery.Builder) rangeQueryBuilders.get(0)).field(), equalTo("scalerank")); + var shapeQueryBuilders = bool.filter().stream().filter(p -> p instanceof SpatialRelatesQuery.ShapeQueryBuilder).toList(); + assertThat("Expected one shape query builder", shapeQueryBuilders.size(), equalTo(1)); + var condition = as(shapeQueryBuilders.get(0), SpatialRelatesQuery.ShapeQueryBuilder.class); + assertThat("Geometry field name", condition.fieldName(), equalTo("location")); + assertThat("Spatial relationship", condition.relation(), equalTo(ShapeRelation.INTERSECTS)); + assertThat("Geometry is Circle", condition.shape().type(), equalTo(ShapeType.CIRCLE)); + var circle = as(condition.shape(), Circle.class); + assertThat("Circle center-x", circle.getX(), equalTo(12.565)); + assertThat("Circle center-y", circle.getY(), equalTo(55.673)); + var expected = eq ? 
600000.0 : Math.nextDown(600000.0); + assertThat("Circle radius", circle.getRadiusMeters(), equalTo(expected)); + } + } + } + public void testPushCartesianSpatialIntersectsToSource() { for (String query : new String[] { """ FROM airports_web diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java index b3685ffe746a0..1268ffb64a848 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsqlAggregate; @@ -237,6 +238,7 @@ public void testStatsWithGroups() { new EsqlAggregate( EMPTY, PROCESSING_CMD_INPUT, + Aggregate.AggregateType.STANDARD, List.of(attribute("c"), attribute("d.e")), List.of( new Alias(EMPTY, "b", new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(attribute("a")))), @@ -253,6 +255,7 @@ public void testStatsWithoutGroups() { new EsqlAggregate( EMPTY, PROCESSING_CMD_INPUT, + Aggregate.AggregateType.STANDARD, List.of(), List.of( new Alias(EMPTY, "min(a)", new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(attribute("a")))), @@ -265,7 +268,13 @@ public void testStatsWithoutGroups() { public void testStatsWithoutAggs() throws Exception { assertEquals( - new EsqlAggregate(EMPTY, PROCESSING_CMD_INPUT, List.of(attribute("a")), List.of(attribute("a"))), + new EsqlAggregate( + EMPTY, + PROCESSING_CMD_INPUT, + Aggregate.AggregateType.STANDARD, + List.of(attribute("a")), + List.of(attribute("a")) + ), processingCommand("stats by a") ); } @@ -1285,6 +1294,7 @@ public void testSimpleMetricsWithStats() { new EsqlAggregate( EMPTY, new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo"), List.of(), IndexMode.TIME_SERIES), + Aggregate.AggregateType.METRICS, List.of(attribute("ts")), List.of(new Alias(EMPTY, "load", new UnresolvedFunction(EMPTY, "avg", DEFAULT, List.of(attribute("cpu")))), attribute("ts")) ) @@ -1294,6 +1304,7 @@ public void testSimpleMetricsWithStats() { new EsqlAggregate( EMPTY, new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo,bar"), List.of(), IndexMode.TIME_SERIES), + Aggregate.AggregateType.METRICS, List.of(attribute("ts")), List.of(new Alias(EMPTY, "load", new UnresolvedFunction(EMPTY, "avg", DEFAULT, List.of(attribute("cpu")))), attribute("ts")) ) @@ -1303,6 +1314,7 @@ public void testSimpleMetricsWithStats() { new EsqlAggregate( EMPTY, new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo,bar"), List.of(), IndexMode.TIME_SERIES), + Aggregate.AggregateType.METRICS, List.of(attribute("ts")), List.of( new Alias(EMPTY, "load", new UnresolvedFunction(EMPTY, "avg", DEFAULT, List.of(attribute("cpu")))), @@ -1325,6 +1337,7 @@ public void testSimpleMetricsWithStats() { new EsqlAggregate( EMPTY, new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo*"), List.of(), IndexMode.TIME_SERIES), + 
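// unlike STATS, a METRICS command parses to an aggregate tagged AggregateType.METRICS
+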
Aggregate.AggregateType.METRICS,
                 List.of(),
                 List.of(new Alias(EMPTY, "count(errors)", new UnresolvedFunction(EMPTY, "count", DEFAULT, List.of(attribute("errors")))))
             )
@@ -1334,6 +1347,7 @@ public void testSimpleMetricsWithStats() {
             new EsqlAggregate(
                 EMPTY,
                 new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo*"), List.of(), IndexMode.TIME_SERIES),
+                Aggregate.AggregateType.METRICS,
                 List.of(),
                 List.of(new Alias(EMPTY, "a(b)", new UnresolvedFunction(EMPTY, "a", DEFAULT, List.of(attribute("b")))))
             )
@@ -1343,6 +1357,7 @@ public void testSimpleMetricsWithStats() {
             new EsqlAggregate(
                 EMPTY,
                 new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo*"), List.of(), IndexMode.TIME_SERIES),
+                Aggregate.AggregateType.METRICS,
                 List.of(),
                 List.of(new Alias(EMPTY, "a(b)", new UnresolvedFunction(EMPTY, "a", DEFAULT, List.of(attribute("b")))))
             )
@@ -1352,6 +1367,7 @@ public void testSimpleMetricsWithStats() {
             new EsqlAggregate(
                 EMPTY,
                 new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo*"), List.of(), IndexMode.TIME_SERIES),
+                Aggregate.AggregateType.METRICS,
                 List.of(),
                 List.of(new Alias(EMPTY, "a1(b2)", new UnresolvedFunction(EMPTY, "a1", DEFAULT, List.of(attribute("b2")))))
             )
@@ -1361,6 +1377,7 @@ public void testSimpleMetricsWithStats() {
             new EsqlAggregate(
                 EMPTY,
                 new EsqlUnresolvedRelation(EMPTY, new TableIdentifier(EMPTY, null, "foo*,bar*"), List.of(), IndexMode.TIME_SERIES),
+                Aggregate.AggregateType.METRICS,
                 List.of(attribute("c"), attribute("d.e")),
                 List.of(
                     new Alias(EMPTY, "b", new UnresolvedFunction(EMPTY, "min", DEFAULT, List.of(attribute("a")))),
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
index e50ba59a31b2d..c14245d212cf0 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java
@@ -12,10 +12,10 @@
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.dissect.DissectParser;
 import org.elasticsearch.xpack.esql.core.capabilities.UnresolvedException;
+import org.elasticsearch.xpack.esql.core.expression.Attribute;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
-import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.expression.Order;
 import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
 import org.elasticsearch.xpack.esql.core.expression.UnresolvedNamedExpression;
@@ -88,8 +88,8 @@ protected Object pluggableMakeArg(Class<? extends Node<?>> toBuildClass, Class<?> argClass)
-            randomList(0, 10, () -> (NamedExpression) makeArg(NamedExpression.class)),
+            JoinType.LEFT,
+            randomList(0, 10, () -> (Attribute) makeArg(Attribute.class)),
             randomList(0, 10, () -> (Expression) makeArg(Expression.class))
         );
     }
diff --git a/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexIT.java b/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexIT.java
index e378ce06611c6..36d4751423113 100644
--- a/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexIT.java
+++
b/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexIT.java @@ -112,17 +112,15 @@ public void testTimestampRangeRecalculatedOnStalePrimaryAllocation() throws IOEx ensureYellowAndNoInitializingShards("index"); - final IndexLongFieldRange timestampFieldRange = clusterAdmin().prepareState() - .get() - .getState() - .metadata() - .index("index") - .getTimestampRange(); + IndexMetadata indexMetadata = clusterAdmin().prepareState().get().getState().metadata().index("index"); + final IndexLongFieldRange timestampFieldRange = indexMetadata.getTimestampRange(); assertThat(timestampFieldRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); assertThat(timestampFieldRange, not(sameInstance(IndexLongFieldRange.EMPTY))); assertTrue(timestampFieldRange.isComplete()); assertThat(timestampFieldRange.getMin(), equalTo(Instant.parse("2010-01-06T02:03:04.567Z").toEpochMilli())); assertThat(timestampFieldRange.getMax(), equalTo(Instant.parse("2010-01-06T02:03:04.567Z").toEpochMilli())); + + assertThat(indexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); } public void testTimestampFieldTypeExposedByAllIndicesServices() throws Exception { @@ -155,6 +153,11 @@ public void testTimestampFieldTypeExposedByAllIndicesServices() throws Exception jsonBuilder().startObject() .startObject("_doc") .startObject("properties") + .startObject(IndexMetadata.EVENT_INGESTED_FIELD_NAME) + .field("type", "date") + .field("format", "dd LLL yyyy HH:mm:ssX") + .field("locale", locale) + .endObject() .startObject(DataStream.TIMESTAMP_FIELD_NAME) .field("type", "date") .field("format", "dd LLL yyyy HH:mm:ssX") diff --git a/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java b/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java index 92d042a98b16e..ccb917c9dbda5 100644 --- a/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java +++ b/x-pack/plugin/frozen-indices/src/internalClusterTest/java/org/elasticsearch/index/engine/frozen/FrozenIndexTests.java @@ -228,6 +228,7 @@ public void testSearchAndGetAPIsAreThrottled() throws IOException { public void testFreezeAndUnfreeze() { final IndexService originalIndexService = createIndex("index", Settings.builder().put("index.number_of_shards", 2).build()); assertThat(originalIndexService.getMetadata().getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(originalIndexService.getMetadata().getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); prepareIndex("index").setId("2").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); @@ -250,6 +251,7 @@ public void testFreezeAndUnfreeze() { IndexShard shard = indexService.getShard(0); assertEquals(0, shard.refreshStats().getTotal()); assertThat(indexService.getMetadata().getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(indexService.getMetadata().getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); } assertAcked( client().execute( @@ -268,6 +270,7 @@ public void testFreezeAndUnfreeze() { Engine engine = IndexShardTestCase.getEngine(shard); assertThat(engine, Matchers.instanceOf(InternalEngine.class)); assertThat(indexService.getMetadata().getTimestampRange(), 
sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(indexService.getMetadata().getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); } prepareIndex("index").setId("4").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); } @@ -671,17 +674,15 @@ public void testComputesTimestampRangeFromMilliseconds() { client().execute(FreezeIndexAction.INSTANCE, new FreezeRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "index")).actionGet() ); - final IndexLongFieldRange timestampFieldRange = clusterAdmin().prepareState() - .get() - .getState() - .metadata() - .index("index") - .getTimestampRange(); + IndexMetadata indexMetadata = clusterAdmin().prepareState().get().getState().metadata().index("index"); + final IndexLongFieldRange timestampFieldRange = indexMetadata.getTimestampRange(); assertThat(timestampFieldRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); assertThat(timestampFieldRange, not(sameInstance(IndexLongFieldRange.EMPTY))); assertTrue(timestampFieldRange.isComplete()); assertThat(timestampFieldRange.getMin(), equalTo(Instant.parse("2010-01-05T01:02:03.456Z").toEpochMilli())); assertThat(timestampFieldRange.getMax(), equalTo(Instant.parse("2010-01-06T02:03:04.567Z").toEpochMilli())); + + assertThat(indexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); } public void testComputesTimestampRangeFromNanoseconds() throws IOException { @@ -705,18 +706,98 @@ public void testComputesTimestampRangeFromNanoseconds() throws IOException { client().execute(FreezeIndexAction.INSTANCE, new FreezeRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "index")).actionGet() ); - final IndexLongFieldRange timestampFieldRange = clusterAdmin().prepareState() - .get() - .getState() - .metadata() - .index("index") - .getTimestampRange(); + IndexMetadata indexMetadata = clusterAdmin().prepareState().get().getState().metadata().index("index"); + final IndexLongFieldRange timestampFieldRange = indexMetadata.getTimestampRange(); assertThat(timestampFieldRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); assertThat(timestampFieldRange, not(sameInstance(IndexLongFieldRange.EMPTY))); assertTrue(timestampFieldRange.isComplete()); final DateFieldMapper.Resolution resolution = DateFieldMapper.Resolution.NANOSECONDS; assertThat(timestampFieldRange.getMin(), equalTo(resolution.convert(Instant.parse("2010-01-05T01:02:03.456789012Z")))); assertThat(timestampFieldRange.getMax(), equalTo(resolution.convert(Instant.parse("2010-01-06T02:03:04.567890123Z")))); + + assertThat(indexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); } + public void testComputesEventIngestedRangeFromMilliseconds() { + final int shardCount = between(1, 3); + createIndex("index", Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shardCount).build()); + prepareIndex("index").setSource(IndexMetadata.EVENT_INGESTED_FIELD_NAME, "2010-01-05T01:02:03.456Z").get(); + prepareIndex("index").setSource(IndexMetadata.EVENT_INGESTED_FIELD_NAME, "2010-01-06T02:03:04.567Z").get(); + + assertAcked( + client().execute(FreezeIndexAction.INSTANCE, new FreezeRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "index")).actionGet() + ); + + IndexMetadata indexMetadata = clusterAdmin().prepareState().get().getState().metadata().index("index"); + final IndexLongFieldRange eventIngestedRange = indexMetadata.getEventIngestedRange(); + assertThat(eventIngestedRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); + assertThat(eventIngestedRange, 
not(sameInstance(IndexLongFieldRange.EMPTY))); + assertTrue(eventIngestedRange.isComplete()); + assertThat(eventIngestedRange.getMin(), equalTo(Instant.parse("2010-01-05T01:02:03.456Z").toEpochMilli())); + assertThat(eventIngestedRange.getMax(), equalTo(Instant.parse("2010-01-06T02:03:04.567Z").toEpochMilli())); + + assertThat(indexMetadata.getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + } + + public void testComputesEventIngestedRangeFromNanoseconds() throws IOException { + + final XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(IndexMetadata.EVENT_INGESTED_FIELD_NAME) + .field("type", "date_nanos") + .field("format", "strict_date_optional_time_nanos") + .endObject() + .endObject() + .endObject(); + + final int shardCount = between(1, 3); + createIndex("index", Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shardCount).build(), mapping); + prepareIndex("index").setSource(IndexMetadata.EVENT_INGESTED_FIELD_NAME, "2010-01-05T01:02:03.456789012Z").get(); + prepareIndex("index").setSource(IndexMetadata.EVENT_INGESTED_FIELD_NAME, "2010-01-06T02:03:04.567890123Z").get(); + + assertAcked( + client().execute(FreezeIndexAction.INSTANCE, new FreezeRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "index")).actionGet() + ); + + IndexMetadata indexMetadata = clusterAdmin().prepareState().get().getState().metadata().index("index"); + final IndexLongFieldRange eventIngestedRange = indexMetadata.getEventIngestedRange(); + assertThat(eventIngestedRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); + assertThat(eventIngestedRange, not(sameInstance(IndexLongFieldRange.EMPTY))); + assertTrue(eventIngestedRange.isComplete()); + final DateFieldMapper.Resolution resolution = DateFieldMapper.Resolution.NANOSECONDS; + assertThat(eventIngestedRange.getMin(), equalTo(resolution.convert(Instant.parse("2010-01-05T01:02:03.456789012Z")))); + assertThat(eventIngestedRange.getMax(), equalTo(resolution.convert(Instant.parse("2010-01-06T02:03:04.567890123Z")))); + + assertThat(indexMetadata.getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + } + + public void testComputesEventIngestedAndTimestampRangesWhenBothPresent() { + final int shardCount = between(1, 3); + createIndex("index", Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shardCount).build()); + prepareIndex("index").setSource(IndexMetadata.EVENT_INGESTED_FIELD_NAME, "2010-01-05T01:02:03.456Z").get(); + prepareIndex("index").setSource(IndexMetadata.EVENT_INGESTED_FIELD_NAME, "2010-01-06T02:03:04.567Z").get(); + prepareIndex("index").setSource(DataStream.TIMESTAMP_FIELD_NAME, "2010-01-05T01:55:03.456Z").get(); + prepareIndex("index").setSource(DataStream.TIMESTAMP_FIELD_NAME, "2010-01-06T02:55:04.567Z").get(); + + assertAcked( + client().execute(FreezeIndexAction.INSTANCE, new FreezeRequest(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "index")).actionGet() + ); + + IndexMetadata indexMetadata = clusterAdmin().prepareState().get().getState().metadata().index("index"); + + final IndexLongFieldRange eventIngestedRange = indexMetadata.getEventIngestedRange(); + assertThat(eventIngestedRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); + assertThat(eventIngestedRange, not(sameInstance(IndexLongFieldRange.EMPTY))); + assertTrue(eventIngestedRange.isComplete()); + assertThat(eventIngestedRange.getMin(), equalTo(Instant.parse("2010-01-05T01:02:03.456Z").toEpochMilli())); + assertThat(eventIngestedRange.getMax(), 
equalTo(Instant.parse("2010-01-06T02:03:04.567Z").toEpochMilli()));
+
+        final IndexLongFieldRange timestampRange = indexMetadata.getTimestampRange();
+        assertThat(timestampRange, not(sameInstance(IndexLongFieldRange.UNKNOWN)));
+        assertThat(timestampRange, not(sameInstance(IndexLongFieldRange.EMPTY)));
+        assertTrue(timestampRange.isComplete());
+        assertThat(timestampRange.getMin(), equalTo(Instant.parse("2010-01-05T01:55:03.456Z").toEpochMilli()));
+        assertThat(timestampRange.getMax(), equalTo(Instant.parse("2010-01-06T02:55:04.567Z").toEpochMilli()));
+    }
 }
diff --git a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java
index ac975365c01aa..0ac8484abdf92 100644
--- a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java
+++ b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java
@@ -56,19 +56,23 @@
 import org.elasticsearch.xpack.core.ilm.TimeseriesLifecycleType;
 import org.elasticsearch.xpack.core.ilm.UnfollowAction;
 import org.elasticsearch.xpack.core.ilm.WaitForSnapshotAction;
+import org.junit.Assert;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.doAnswer;
@@ -84,35 +88,34 @@ public void testDependencies() {
 
     protected NamedXContentRegistry xContentRegistry() {
         List<NamedXContentRegistry.Entry> entries = new ArrayList<>(ClusterModule.getNamedXWriteables());
-        entries.addAll(
-            Arrays.asList(
-                new NamedXContentRegistry.Entry(
-                    LifecycleType.class,
-                    new ParseField(TimeseriesLifecycleType.TYPE),
-                    (p) -> TimeseriesLifecycleType.INSTANCE
-                ),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(AllocateAction.NAME), AllocateAction::parse),
-                new NamedXContentRegistry.Entry(
-                    LifecycleAction.class,
-                    new ParseField(WaitForSnapshotAction.NAME),
-                    WaitForSnapshotAction::parse
-                ),
-                new NamedXContentRegistry.Entry(
-                    LifecycleAction.class,
-                    new ParseField(SearchableSnapshotAction.NAME),
-                    SearchableSnapshotAction::parse
-                ),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(DeleteAction.NAME), DeleteAction::parse),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(ForceMergeAction.NAME), ForceMergeAction::parse),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(ReadOnlyAction.NAME), ReadOnlyAction::parse),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(RolloverAction.NAME), RolloverAction::parse),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(ShrinkAction.NAME), ShrinkAction::parse),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(FreezeAction.NAME), FreezeAction::parse),
-                new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(SetPriorityAction.NAME),
SetPriorityAction::parse), - new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(MigrateAction.NAME), MigrateAction::parse), - new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(UnfollowAction.NAME), UnfollowAction::parse), - new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(DownsampleAction.NAME), DownsampleAction::parse) - ) + Collections.addAll( + entries, + new NamedXContentRegistry.Entry( + LifecycleType.class, + new ParseField(TimeseriesLifecycleType.TYPE), + (p) -> TimeseriesLifecycleType.INSTANCE + ), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(AllocateAction.NAME), AllocateAction::parse), + new NamedXContentRegistry.Entry( + LifecycleAction.class, + new ParseField(WaitForSnapshotAction.NAME), + WaitForSnapshotAction::parse + ), + new NamedXContentRegistry.Entry( + LifecycleAction.class, + new ParseField(SearchableSnapshotAction.NAME), + SearchableSnapshotAction::parse + ), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(DeleteAction.NAME), DeleteAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(ForceMergeAction.NAME), ForceMergeAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(ReadOnlyAction.NAME), ReadOnlyAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(RolloverAction.NAME), RolloverAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(ShrinkAction.NAME), ShrinkAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(FreezeAction.NAME), FreezeAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(SetPriorityAction.NAME), SetPriorityAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(MigrateAction.NAME), MigrateAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(UnfollowAction.NAME), UnfollowAction::parse), + new NamedXContentRegistry.Entry(LifecycleAction.class, new ParseField(DownsampleAction.NAME), DownsampleAction::parse) ); return new NamedXContentRegistry(entries); } @@ -130,7 +133,7 @@ public void testValidationFails() { ClusterState state = ClusterState.builder(clusterName).build(); ReservedLifecycleAction action = new ReservedLifecycleAction(xContentRegistry(), client, mock(XPackLicenseState.class)); - TransformState prevState = new TransformState(state, Collections.emptySet()); + TransformState prevState = new TransformState(state, Set.of()); String badPolicyJSON = """ { @@ -145,9 +148,9 @@ public void testValidationFails() { } }"""; - assertEquals( - "[1:2] [lifecycle_policy] unknown field [phase] did you mean [phases]?", - expectThrows(XContentParseException.class, () -> processJSON(action, prevState, badPolicyJSON)).getMessage() + assertThat( + expectThrows(XContentParseException.class, () -> processJSON(action, prevState, badPolicyJSON)).getMessage(), + is("[1:2] [lifecycle_policy] unknown field [phase] did you mean [phases]?") ); } @@ -162,10 +165,10 @@ public void testActionAddRemove() throws Exception { String emptyJSON = ""; - TransformState prevState = new TransformState(state, Collections.emptySet()); + TransformState prevState = new TransformState(state, Set.of()); TransformState updatedState = processJSON(action, prevState, emptyJSON); - assertEquals(0, updatedState.keys().size()); + assertThat(updatedState.keys(), empty()); assertEquals(prevState.state(), 
updatedState.state());
 
         String twoPoliciesJSON = """
@@ -359,9 +362,9 @@ public void testOperatorControllerFromJSONContent() throws IOException {
         AtomicReference<Exception> x = new AtomicReference<>();
 
         try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, testJSON)) {
-            controller.process("operator", parser, (e) -> x.set(e));
+            controller.process("operator", parser, x::set);
 
-            assertTrue(x.get() instanceof IllegalStateException);
+            assertThat(x.get(), instanceOf(IllegalStateException.class));
             assertThat(x.get().getMessage(), containsString("Error processing state change request for operator"));
         }
@@ -380,11 +383,7 @@ public void testOperatorControllerFromJSONContent() throws IOException {
         );
 
         try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, testJSON)) {
-            controller.process("operator", parser, (e) -> {
-                if (e != null) {
-                    fail("Should not fail");
-                }
-            });
+            controller.process("operator", parser, Assert::assertNull);
         }
     }
@@ -411,9 +410,9 @@ public void testOperatorControllerWithPluginPackage() {
                         "my_timeseries_lifecycle",
                         Map.of(
                             "warm",
-                            new Phase("warm", new TimeValue(10, TimeUnit.SECONDS), Collections.emptyMap()),
+                            new Phase("warm", new TimeValue(10, TimeUnit.SECONDS), Map.of()),
                             "delete",
-                            new Phase("delete", new TimeValue(30, TimeUnit.SECONDS), Collections.emptyMap())
+                            new Phase("delete", new TimeValue(30, TimeUnit.SECONDS), Map.of())
                         )
                     )
                 )
@@ -421,9 +420,9 @@ public void testOperatorControllerWithPluginPackage() {
             new ReservedStateVersion(123L, Version.CURRENT)
         );
 
-        controller.process("operator", pack, (e) -> x.set(e));
+        controller.process("operator", pack, x::set);
 
-        assertTrue(x.get() instanceof IllegalStateException);
+        assertThat(x.get(), instanceOf(IllegalStateException.class));
         assertThat(x.get().getMessage(), containsString("Error processing state change request for operator"));
 
         Client client = mock(Client.class);
@@ -440,10 +439,6 @@ public void testOperatorControllerWithPluginPackage() {
             )
         );
 
-        controller.process("operator", pack, (e) -> {
-            if (e != null) {
-                fail("Should not fail");
-            }
-        });
+        controller.process("operator", pack, Assert::assertNull);
     }
 }
diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle
index f4378d8ab5b7c..92afa3faa51e3 100644
--- a/x-pack/plugin/inference/build.gradle
+++ b/x-pack/plugin/inference/build.gradle
@@ -38,6 +38,169 @@ dependencies {
   clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
 
   api "com.ibm.icu:icu4j:${versions.icu4j}"
+
+  runtimeOnly 'com.google.guava:guava:32.0.1-jre'
+  implementation 'com.google.code.gson:gson:2.10'
+  implementation "com.google.protobuf:protobuf-java-util:${versions.protobuf}"
+  implementation "com.google.protobuf:protobuf-java:${versions.protobuf}"
+  implementation 'com.google.api.grpc:proto-google-iam-v1:1.6.2'
+  implementation 'com.google.auth:google-auth-library-credentials:1.11.0'
+  implementation 'com.google.auth:google-auth-library-oauth2-http:1.11.0'
+  implementation "com.google.oauth-client:google-oauth-client:${versions.google_oauth_client}"
+  implementation 'com.google.api-client:google-api-client:2.1.1'
+  implementation 'com.google.http-client:google-http-client:1.42.3'
+  implementation 'com.google.http-client:google-http-client-gson:1.42.3'
+  implementation 'com.google.http-client:google-http-client-appengine:1.42.3'
+  implementation 'com.google.http-client:google-http-client-jackson2:1.42.3'
+  implementation
"com.fasterxml.jackson.core:jackson-core:${versions.jackson}" + implementation 'com.google.api:gax-httpjson:0.105.1' + implementation 'io.grpc:grpc-context:1.49.2' + implementation 'io.opencensus:opencensus-api:0.31.1' + implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1' +} + +tasks.named("dependencyLicenses").configure { + mapping from: /google-auth-.*/, to: 'google-auth' + mapping from: /google-http-.*/, to: 'google-http' + mapping from: /opencensus.*/, to: 'opencensus' + mapping from: /protobuf.*/, to: 'protobuf' + mapping from: /proto-google.*/, to: 'proto-google' + mapping from: /jackson.*/, to: 'jackson' +} + +tasks.named("thirdPartyAudit").configure { + ignoreViolations( + // uses internal java api: sun.misc.Unsafe + 'com.google.protobuf.UnsafeUtil', + 'com.google.protobuf.UnsafeUtil$1', + 'com.google.protobuf.UnsafeUtil$JvmMemoryAccessor', + 'com.google.protobuf.UnsafeUtil$MemoryAccessor', + 'com.google.protobuf.MessageSchema', + 'com.google.protobuf.UnsafeUtil$Android32MemoryAccessor', + 'com.google.protobuf.UnsafeUtil$Android64MemoryAccessor', + 'com.google.common.cache.Striped64', + 'com.google.common.cache.Striped64$1', + 'com.google.common.cache.Striped64$Cell', + 'com.google.common.hash.Striped64', + 'com.google.common.hash.Striped64$1', + 'com.google.common.hash.Striped64$Cell', + 'com.google.common.hash.LittleEndianByteArray$UnsafeByteArray', + 'com.google.common.hash.LittleEndianByteArray$UnsafeByteArray$1', + 'com.google.common.hash.LittleEndianByteArray$UnsafeByteArray$2', + 'com.google.common.util.concurrent.AbstractFuture$UnsafeAtomicHelper', + 'com.google.common.util.concurrent.AbstractFuture$UnsafeAtomicHelper$1', + 'com.google.common.hash.LittleEndianByteArray$UnsafeByteArray', + 'com.google.common.primitives.UnsignedBytes$LexicographicalComparatorHolder$UnsafeComparator', + 'com.google.common.primitives.UnsignedBytes$LexicographicalComparatorHolder$UnsafeComparator$1', + ) + + ignoreMissingClasses( + 'com.google.api.AnnotationsProto', + 'com.google.api.ClientProto', + 'com.google.api.FieldBehaviorProto', + 'com.google.api.ResourceProto', + 'com.google.api.core.AbstractApiFuture', + 'com.google.api.core.ApiFunction', + 'com.google.api.core.ApiFuture', + 'com.google.api.core.ApiFutureCallback', + 'com.google.api.core.ApiFutures', + 'com.google.api.gax.core.BackgroundResource', + 'com.google.api.gax.core.ExecutorProvider', + 'com.google.api.gax.core.GaxProperties', + 'com.google.api.gax.core.GoogleCredentialsProvider', + 'com.google.api.gax.core.GoogleCredentialsProvider$Builder', + 'com.google.api.gax.core.InstantiatingExecutorProvider', + 'com.google.api.gax.core.InstantiatingExecutorProvider$Builder', + 'com.google.api.gax.longrunning.OperationSnapshot', + 'com.google.api.gax.paging.AbstractFixedSizeCollection', + 'com.google.api.gax.paging.AbstractPage', + 'com.google.api.gax.paging.AbstractPagedListResponse', + 'com.google.api.gax.retrying.RetrySettings', + 'com.google.api.gax.retrying.RetrySettings$Builder', + 'com.google.api.gax.rpc.ApiCallContext', + 'com.google.api.gax.rpc.ApiCallContext$Key', + 'com.google.api.gax.rpc.ApiClientHeaderProvider', + 'com.google.api.gax.rpc.ApiClientHeaderProvider$Builder', + 'com.google.api.gax.rpc.ApiException', + 'com.google.api.gax.rpc.ApiExceptionFactory', + 'com.google.api.gax.rpc.BatchingCallSettings', + 'com.google.api.gax.rpc.Callables', + 'com.google.api.gax.rpc.ClientContext', + 'com.google.api.gax.rpc.ClientSettings', + 'com.google.api.gax.rpc.ClientSettings$Builder', + 
'com.google.api.gax.rpc.FixedHeaderProvider', + 'com.google.api.gax.rpc.HeaderProvider', + 'com.google.api.gax.rpc.LongRunningClient', + 'com.google.api.gax.rpc.OperationCallSettings', + 'com.google.api.gax.rpc.OperationCallable', + 'com.google.api.gax.rpc.PageContext', + 'com.google.api.gax.rpc.PagedCallSettings', + 'com.google.api.gax.rpc.PagedCallSettings$Builder', + 'com.google.api.gax.rpc.PagedListDescriptor', + 'com.google.api.gax.rpc.PagedListResponseFactory', + 'com.google.api.gax.rpc.ResponseObserver', + 'com.google.api.gax.rpc.ServerStreamingCallSettings', + 'com.google.api.gax.rpc.ServerStreamingCallable', + 'com.google.api.gax.rpc.StateCheckingResponseObserver', + 'com.google.api.gax.rpc.StatusCode', + 'com.google.api.gax.rpc.StatusCode$Code', + 'com.google.api.gax.rpc.StreamController', + 'com.google.api.gax.rpc.StubSettings', + 'com.google.api.gax.rpc.StubSettings$Builder', + 'com.google.api.gax.rpc.TranslatingUnaryCallable', + 'com.google.api.gax.rpc.TransportChannel', + 'com.google.api.gax.rpc.TransportChannelProvider', + 'com.google.api.gax.rpc.UnaryCallSettings', + 'com.google.api.gax.rpc.UnaryCallSettings$Builder', + 'com.google.api.gax.rpc.UnaryCallable', + 'com.google.api.gax.rpc.internal.ApiCallContextOptions', + 'com.google.api.gax.rpc.internal.Headers', + 'com.google.api.gax.rpc.mtls.MtlsProvider', + 'com.google.api.gax.tracing.ApiTracer', + 'com.google.api.gax.tracing.BaseApiTracer', + 'com.google.api.gax.tracing.SpanName', + 'com.google.api.pathtemplate.PathTemplate', + 'com.google.common.util.concurrent.internal.InternalFutureFailureAccess', + 'com.google.common.util.concurrent.internal.InternalFutures', + 'com.google.longrunning.CancelOperationRequest', + 'com.google.longrunning.CancelOperationRequest$Builder', + 'com.google.longrunning.DeleteOperationRequest', + 'com.google.longrunning.DeleteOperationRequest$Builder', + 'com.google.longrunning.GetOperationRequest', + 'com.google.longrunning.GetOperationRequest$Builder', + 'com.google.longrunning.ListOperationsRequest', + 'com.google.longrunning.ListOperationsRequest$Builder', + 'com.google.longrunning.ListOperationsResponse', + 'com.google.longrunning.Operation', + 'com.google.rpc.Code', + 'com.google.rpc.Status', + 'com.google.type.Expr', + 'com.google.type.Expr$Builder', + 'com.google.type.ExprOrBuilder', + 'com.google.type.ExprProto', + 'org.threeten.bp.Duration', + 'org.threeten.bp.Instant', + 'com.google.api.client.http.apache.v2.ApacheHttpTransport', + 'com.google.appengine.api.datastore.Blob', + 'com.google.appengine.api.datastore.DatastoreService', + 'com.google.appengine.api.datastore.DatastoreServiceFactory', + 'com.google.appengine.api.datastore.Entity', + 'com.google.appengine.api.datastore.Key', + 'com.google.appengine.api.datastore.KeyFactory', + 'com.google.appengine.api.datastore.PreparedQuery', + 'com.google.appengine.api.datastore.Query', + 'com.google.appengine.api.memcache.Expiration', + 'com.google.appengine.api.memcache.MemcacheService', + 'com.google.appengine.api.memcache.MemcacheServiceFactory', + 'com.google.appengine.api.urlfetch.FetchOptions$Builder', + 'com.google.appengine.api.urlfetch.FetchOptions', + 'com.google.appengine.api.urlfetch.HTTPHeader', + 'com.google.appengine.api.urlfetch.HTTPMethod', + 'com.google.appengine.api.urlfetch.HTTPRequest', + 'com.google.appengine.api.urlfetch.HTTPResponse', + 'com.google.appengine.api.urlfetch.URLFetchService', + 'com.google.appengine.api.urlfetch.URLFetchServiceFactory' + ) } if (BuildParams.isSnapshotBuild() == false) { diff --git 
a/x-pack/plugin/inference/licenses/gax-httpjson-LICENSE.txt b/x-pack/plugin/inference/licenses/gax-httpjson-LICENSE.txt new file mode 100644 index 0000000000000..267561bb386de --- /dev/null +++ b/x-pack/plugin/inference/licenses/gax-httpjson-LICENSE.txt @@ -0,0 +1,27 @@ +Copyright 2016, Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/x-pack/plugin/inference/licenses/gax-httpjson-NOTICE.txt b/x-pack/plugin/inference/licenses/gax-httpjson-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/google-api-client-LICENSE.txt b/x-pack/plugin/inference/licenses/google-api-client-LICENSE.txt new file mode 100644 index 0000000000000..4eedc0116add7 --- /dev/null +++ b/x-pack/plugin/inference/licenses/google-api-client-LICENSE.txt @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/inference/licenses/google-api-client-NOTICE.txt b/x-pack/plugin/inference/licenses/google-api-client-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/google-auth-LICENSE.txt b/x-pack/plugin/inference/licenses/google-auth-LICENSE.txt new file mode 100644 index 0000000000000..12edf23c6711f --- /dev/null +++ b/x-pack/plugin/inference/licenses/google-auth-LICENSE.txt @@ -0,0 +1,28 @@ +Copyright 2014, Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. 
+ + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/x-pack/plugin/inference/licenses/google-auth-NOTICE.txt b/x-pack/plugin/inference/licenses/google-auth-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/google-http-LICENSE.txt b/x-pack/plugin/inference/licenses/google-http-LICENSE.txt new file mode 100644 index 0000000000000..980a15ac24eeb --- /dev/null +++ b/x-pack/plugin/inference/licenses/google-http-LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/inference/licenses/google-http-NOTICE.txt b/x-pack/plugin/inference/licenses/google-http-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/google-oauth-client-LICENSE.txt b/x-pack/plugin/inference/licenses/google-oauth-client-LICENSE.txt new file mode 100644 index 0000000000000..12edf23c6711f --- /dev/null +++ b/x-pack/plugin/inference/licenses/google-oauth-client-LICENSE.txt @@ -0,0 +1,28 @@ +Copyright 2014, Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. 
+ * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/x-pack/plugin/inference/licenses/google-oauth-client-NOTICE.txt b/x-pack/plugin/inference/licenses/google-oauth-client-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/grpc-context-LICENSE.txt b/x-pack/plugin/inference/licenses/grpc-context-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/inference/licenses/grpc-context-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/inference/licenses/grpc-context-NOTICE.txt b/x-pack/plugin/inference/licenses/grpc-context-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/gson-LICENSE.txt b/x-pack/plugin/inference/licenses/gson-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/inference/licenses/gson-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/x-pack/plugin/inference/licenses/gson-NOTICE.txt b/x-pack/plugin/inference/licenses/gson-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/guava-LICENSE.txt b/x-pack/plugin/inference/licenses/guava-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/inference/licenses/guava-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/inference/licenses/guava-NOTICE.txt b/x-pack/plugin/inference/licenses/guava-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/jackson-LICENSE.txt b/x-pack/plugin/inference/licenses/jackson-LICENSE.txt new file mode 100644 index 0000000000000..f5f45d26a49d6 --- /dev/null +++ b/x-pack/plugin/inference/licenses/jackson-LICENSE.txt @@ -0,0 +1,8 @@ +This copy of Jackson JSON processor streaming parser/generator is licensed under the +Apache (Software) License, version 2.0 ("the License"). +See the License for details about distribution rights, and the +specific rights regarding derivate works. + +You may obtain a copy of the License at: + +http://www.apache.org/licenses/LICENSE-2.0 diff --git a/x-pack/plugin/inference/licenses/jackson-NOTICE.txt b/x-pack/plugin/inference/licenses/jackson-NOTICE.txt new file mode 100644 index 0000000000000..4c976b7b4cc58 --- /dev/null +++ b/x-pack/plugin/inference/licenses/jackson-NOTICE.txt @@ -0,0 +1,20 @@ +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. diff --git a/x-pack/plugin/inference/licenses/opencensus-LICENSE.txt b/x-pack/plugin/inference/licenses/opencensus-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/inference/licenses/opencensus-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/x-pack/plugin/inference/licenses/opencensus-NOTICE.txt b/x-pack/plugin/inference/licenses/opencensus-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/proto-google-LICENSE.txt b/x-pack/plugin/inference/licenses/proto-google-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/inference/licenses/proto-google-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/inference/licenses/proto-google-NOTICE.txt b/x-pack/plugin/inference/licenses/proto-google-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/x-pack/plugin/inference/licenses/protobuf-LICENSE.txt b/x-pack/plugin/inference/licenses/protobuf-LICENSE.txt new file mode 100644 index 0000000000000..19b305b00060a --- /dev/null +++ b/x-pack/plugin/inference/licenses/protobuf-LICENSE.txt @@ -0,0 +1,32 @@ +Copyright 2008 Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Code generated by the Protocol Buffer compiler is owned by the owner +of the input file used when generating it. This code is not +standalone and requires a support library to be linked with it. This +support library is itself covered by the above license. diff --git a/x-pack/plugin/inference/licenses/protobuf-NOTICE.txt b/x-pack/plugin/inference/licenses/protobuf-NOTICE.txt new file mode 100644 index 0000000000000..19b305b00060a --- /dev/null +++ b/x-pack/plugin/inference/licenses/protobuf-NOTICE.txt @@ -0,0 +1,32 @@ +Copyright 2008 Google Inc. All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Code generated by the Protocol Buffer compiler is owned by the owner +of the input file used when generating it. This code is not +standalone and requires a support library to be linked with it. This +support library is itself covered by the above license. diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index c67c6f29d69c5..183d41bf730fe 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -19,6 +19,9 @@ requires org.apache.lucene.core; requires org.apache.lucene.join; requires com.ibm.icu; + requires com.google.auth.oauth2; + requires com.google.api.client; + requires com.google.gson; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 14980df2f8789..b75c44731df06 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -24,6 +24,8 @@ import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings; import 
org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
@@ -48,6 +50,9 @@ import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;
 import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
@@ -111,8 +116,10 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         addAzureOpenAiNamedWriteables(namedWriteables);
         addAzureAiStudioNamedWriteables(namedWriteables);
         addGoogleAiStudioNamedWritables(namedWriteables);
+        addGoogleVertexAiNamedWriteables(namedWriteables);
         addMistralNamedWriteables(namedWriteables);
         addCustomElandWriteables(namedWriteables);
+        addAnthropicNamedWritables(namedWriteables);
         return namedWriteables;
     }
@@ -287,6 +294,28 @@ private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+    private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(SecretSettings.class, GoogleVertexAiSecretSettings.NAME, GoogleVertexAiSecretSettings::new)
+        );
+
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                GoogleVertexAiEmbeddingsServiceSettings.NAME,
+                GoogleVertexAiEmbeddingsServiceSettings::new
+            )
+        );
+
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                TaskSettings.class,
+                GoogleVertexAiEmbeddingsTaskSettings.NAME,
+                GoogleVertexAiEmbeddingsTaskSettings::new
+            )
+        );
+    }
+
     private static void addInternalElserNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new)
@@ -372,4 +401,21 @@ private static void addCustomElandWriteables(final List<NamedWriteableRegistry.Entry> namedWriteables) {
+    private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                AnthropicChatCompletionServiceSettings.NAME,
+                AnthropicChatCompletionServiceSettings::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                TaskSettings.class,
+                AnthropicChatCompletionTaskSettings.NAME,
+                AnthropicChatCompletionTaskSettings::new
+            )
+        );
+    }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index 1e0f715e3f3e9..c8fb7e94a19ab 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -66,12 +66,14 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
 import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService;
 import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
 import org.elasticsearch.xpack.inference.services.cohere.CohereService;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
 import org.elasticsearch.xpack.inference.services.elser.ElserInternalService;
 import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
 import org.elasticsearch.xpack.inference.services.mistral.MistralService;
@@ -199,7 +201,9 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
         context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
         context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
         context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
+        context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()),
         context -> new MistralService(httpFactory.get(), serviceComponents.get()),
+        context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
         ElasticsearchInternalService::new
     );
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java
index 01a345909c6b1..0e8928c3a2391 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java
@@ -46,6 +46,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.ElementType elementType) {
         return switch (elementType) {
             case BYTE -> EmbeddingType.BYTE;
             case FLOAT -> EmbeddingType.FLOAT;
+            case BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
         };
     }
 };
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SizeLimitInputStream.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SizeLimitInputStream.java
index 78e7b5cbbd95e..cbef4e39fae54 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SizeLimitInputStream.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/SizeLimitInputStream.java
@@ -21,6 +21,13 @@
  */
 public final class SizeLimitInputStream extends FilterInputStream {
 
+    public static class InputStreamTooLargeException extends IOException {
+
+        public InputStreamTooLargeException(String message) {
+            super(message);
+        }
+    }
+
     private final long maxByteSize;
     private final AtomicLong byteCounter = new AtomicLong(0);
@@ -73,9 +80,9 @@ public boolean markSupported() {
         return false;
     }
 
-    private void checkMaximumLengthReached() throws IOException {
+    private void checkMaximumLengthReached() throws InputStreamTooLargeException {
         if (byteCounter.get() > maxByteSize) {
-            throw new IOException("Maximum limit of [" + maxByteSize + "] bytes reached");
+            throw new InputStreamTooLargeException("Maximum limit of [" + maxByteSize + "] bytes reached");
         }
     }
 }
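The new InputStreamTooLargeException subtype above lets callers tell "the configured size cap was hit" apart from any other IOException raised while reading. A minimal caller-side sketch of that distinction, assuming a (limit, wrapped stream) constructor shape and readAllBytes() draining for illustration only; neither detail is part of this change:

import java.io.IOException;
import java.io.InputStream;

// Assumes the same package as SizeLimitInputStream.
class CappedBodyReader {

    // Drains a stream through SizeLimitInputStream and maps the size-cap
    // failure to a clearer error, while other IO failures propagate as-is.
    static byte[] readCapped(InputStream body, long maxBytes) throws IOException {
        // Assumed constructor shape: (limit, wrapped stream).
        try (InputStream limited = new SizeLimitInputStream(maxBytes, body)) {
            return limited.readAllBytes();
        } catch (SizeLimitInputStream.InputStreamTooLargeException e) {
            // Only the size-cap case is intercepted here.
            throw new IOException("response body exceeded " + maxBytes + " bytes", e);
        }
    }
}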
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreator.java
new file mode 100644
index 0000000000000..fa386c80643b0
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreator.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.anthropic;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the Anthropic model type.
+ */
+public class AnthropicActionCreator implements AnthropicActionVisitor {
+    private final Sender sender;
+    private final ServiceComponents serviceComponents;
+
+    public AnthropicActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        this.sender = Objects.requireNonNull(sender);
+        this.serviceComponents = Objects.requireNonNull(serviceComponents);
+    }
+
+    @Override
+    public ExecutableAction create(AnthropicChatCompletionModel model, Map<String, Object> taskSettings) {
+        var overriddenModel = AnthropicChatCompletionModel.of(model, taskSettings);
+
+        return new AnthropicChatCompletionAction(sender, overriddenModel, serviceComponents);
+    }
+}
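The Javadoc above calls this the visitor pattern: the inference service hands an AnthropicActionVisitor to the model, and the concrete model type picks the matching create(...) overload, so no instanceof dispatch is needed as new Anthropic task types appear. A self-contained sketch of that double dispatch; the names mirror this change, but the accept(...) override lives on the model classes, which are outside this hunk, so treat it as illustrative:

import java.util.Map;

interface ExecutableAction {}

interface AnthropicActionVisitor {
    ExecutableAction create(AnthropicChatCompletionModel model, Map<String, Object> taskSettings);
}

abstract class AnthropicModel {
    abstract ExecutableAction accept(AnthropicActionVisitor creator, Map<String, Object> taskSettings);
}

class AnthropicChatCompletionModel extends AnthropicModel {
    @Override
    ExecutableAction accept(AnthropicActionVisitor creator, Map<String, Object> taskSettings) {
        // Double dispatch: the model knows its concrete type, so the visitor
        // receives the statically typed create(...) overload; adding a task
        // type means adding an overload rather than growing a switch.
        return creator.create(this, taskSettings);
    }
}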
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionVisitor.java
new file mode 100644
index 0000000000000..d2727c0e9b20c
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionVisitor.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.anthropic;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
+
+import java.util.Map;
+
+public interface AnthropicActionVisitor {
+
+    ExecutableAction create(AnthropicChatCompletionModel model, Map<String, Object> taskSettings);
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionAction.java
new file mode 100644
index 0000000000000..9891d671764a4
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionAction.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.anthropic;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.AnthropicCompletionRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+public class AnthropicChatCompletionAction implements ExecutableAction {
+
+    private final String errorMessage;
+    private final AnthropicCompletionRequestManager requestCreator;
+
+    private final Sender sender;
+
+    public AnthropicChatCompletionAction(Sender sender, AnthropicChatCompletionModel model, ServiceComponents serviceComponents) {
+        Objects.requireNonNull(serviceComponents);
+        Objects.requireNonNull(model);
+        this.sender = Objects.requireNonNull(sender);
+        this.requestCreator = AnthropicCompletionRequestManager.of(model, serviceComponents.threadPool());
+        this.errorMessage = constructFailedToSendRequestMessage(model.getUri(), "Anthropic chat completions");
+    }
+
+    @Override
+    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
+        if (inferenceInputs instanceof DocumentsOnlyInput == false) {
+            listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR));
+            return;
+        }
+
+        var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
+        if (docsOnlyInput.getInputs().size() > 1) {
+            listener.onFailure(new ElasticsearchStatusException("Anthropic completions only accepts 1 input", RestStatus.BAD_REQUEST));
+            return;
+        }
+
+        try {
+            ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
+
+            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
+        } catch (ElasticsearchException e) {
+            listener.onFailure(e);
+        } catch (Exception e) {
+            listener.onFailure(createInternalServerError(e, errorMessage));
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java
new file mode 100644
index 0000000000000..32254432d3ee2
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.googlevertexai;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
+
+import java.util.Map;
+import java.util.Objects;
+
+public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor {
+
+    private final Sender sender;
+
+    private final ServiceComponents serviceComponents;
+
+    public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        this.sender = Objects.requireNonNull(sender);
+        this.serviceComponents = Objects.requireNonNull(serviceComponents);
+    }
+
+    @Override
+    public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings) {
+        return new GoogleVertexAiEmbeddingsAction(sender, model, serviceComponents);
+    }
+}
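One detail worth calling out from AnthropicChatCompletionAction#execute above: a request must carry exactly one document, and a second document fails fast with a 400 before any HTTP call is made. A hedged sketch of a caller; the prompt, timeout, and print-based handlers are illustrative, and sender, model, and serviceComponents would come from the service wiring elsewhere in this change:

import java.util.List;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;

class AnthropicCompletionExample {

    static void completeOnce(Sender sender, AnthropicChatCompletionModel model, ServiceComponents serviceComponents) {
        var action = new AnthropicChatCompletionAction(sender, model, serviceComponents);
        action.execute(
            new DocumentsOnlyInput(List.of("Why does the sky appear blue?")), // exactly one document
            TimeValue.timeValueSeconds(30),
            ActionListener.<InferenceServiceResults>wrap(
                results -> System.out.println("completion results: " + results),
                e -> System.err.println("completion failed: " + e.getMessage())
            )
        );
    }
}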
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java
new file mode 100644
index 0000000000000..8d885749fee09
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.googlevertexai;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
+
+import java.util.Map;
+
+public interface GoogleVertexAiActionVisitor {
+
+    ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings);
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsAction.java
new file mode 100644
index 0000000000000..f9814224c101a
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsAction.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.googlevertexai;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.GoogleVertexAiEmbeddingsRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+public class GoogleVertexAiEmbeddingsAction implements ExecutableAction {
+
+    private final String failedToSendRequestErrorMessage;
+
+    private final GoogleVertexAiEmbeddingsRequestManager requestManager;
+
+    private final Sender sender;
+
+    public GoogleVertexAiEmbeddingsAction(Sender sender, GoogleVertexAiEmbeddingsModel model, ServiceComponents serviceComponents) {
+        Objects.requireNonNull(serviceComponents);
+        Objects.requireNonNull(model);
+        this.sender = Objects.requireNonNull(sender);
+        this.requestManager = new GoogleVertexAiEmbeddingsRequestManager(
+            model,
+            serviceComponents.truncator(),
+            serviceComponents.threadPool()
+        );
+        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google Vertex AI embeddings");
+    }
+
+    @Override
+    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
+        try {
+            ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
+                failedToSendRequestErrorMessage,
+                listener
+            );
+
+            sender.send(requestManager, inferenceInputs, timeout, wrappedListener);
+        } catch (ElasticsearchException e) {
+            listener.onFailure(e);
+        } catch (Exception e)
{ + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicAccount.java new file mode 100644 index 0000000000000..fb74188b10995 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicAccount.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.anthropic; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicModel; + +import java.net.URI; +import java.util.Objects; + +public record AnthropicAccount(URI uri, SecureString apiKey) { + + public static AnthropicAccount of(AnthropicModel model) { + return new AnthropicAccount(model.getUri(), model.apiKey()); + } + + public AnthropicAccount { + Objects.requireNonNull(uri); + Objects.requireNonNull(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java new file mode 100644 index 0000000000000..cab2c655b9ffb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.anthropic; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; +import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown; + +public class AnthropicResponseHandler extends BaseResponseHandler { + /** + * Rate limit headers taken from https://docs.anthropic.com/en/api/rate-limits#response-headers + */ + // The maximum number of requests allowed within the rate limit window. + static final String REQUESTS_LIMIT = "anthropic-ratelimit-requests-limit"; + // The number of requests remaining within the current rate limit window. + static final String REMAINING_REQUESTS = "anthropic-ratelimit-requests-remaining"; + // The time when the request rate limit window will reset, provided in RFC 3339 format. 
+    static final String REQUEST_RESET = "anthropic-ratelimit-requests-reset";
+    // The maximum number of tokens allowed within the rate limit window.
+    static final String TOKENS_LIMIT = "anthropic-ratelimit-tokens-limit";
+    // The number of tokens remaining, rounded to the nearest thousand, within the current rate limit window.
+    static final String REMAINING_TOKENS = "anthropic-ratelimit-tokens-remaining";
+    // The time when the token rate limit window will reset, provided in RFC 3339 format.
+    static final String TOKENS_RESET = "anthropic-ratelimit-tokens-reset";
+    // The number of seconds until the rate limit window resets.
+    static final String RETRY_AFTER = "retry-after";
+
+    static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code";
+
+    public AnthropicResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
+    }
+
+    @Override
+    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+        throws RetryException {
+        checkForFailureStatusCode(request, result);
+        checkForEmptyBody(throttlerManager, logger, request, result);
+    }
+
+    /**
+     * Validates the status code, throwing a RetryException if it is not in the range [200, 300).
+     *
+     * The Anthropic API error codes are documented in the Anthropic API error docs.
+     * @param request The originating request
+     * @param result The http response and body
+     * @throws RetryException Thrown if the status code is {@code >= 300 or < 200}
+     */
+    void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
+        int statusCode = result.response().getStatusLine().getStatusCode();
+        if (statusCode >= 200 && statusCode < 300) {
+            return;
+        }
+
+        // handle error codes
+        if (statusCode == 500) {
+            throw new RetryException(true, buildError(SERVER_ERROR, request, result));
+        } else if (statusCode == 529) {
+            throw new RetryException(true, buildError(SERVER_BUSY, request, result));
+        } else if (statusCode > 500) {
+            throw new RetryException(false, buildError(SERVER_ERROR, request, result));
+        } else if (statusCode == 429) {
+            throw new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
+        } else if (statusCode == 403) {
+            throw new RetryException(false, buildError(PERMISSION_DENIED, request, result));
+        } else if (statusCode == 401) {
+            throw new RetryException(false, buildError(AUTHENTICATION, request, result));
+        } else if (statusCode >= 300 && statusCode < 400) {
+            throw new RetryException(false, buildError(REDIRECTION, request, result));
+        } else {
+            throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
+        }
+    }
+
+    static String buildRateLimitErrorMessage(HttpResult result) {
+        var response = result.response();
+        var tokenLimit = getFirstHeaderOrUnknown(response, TOKENS_LIMIT);
+        var remainingTokens = getFirstHeaderOrUnknown(response, REMAINING_TOKENS);
+        var requestLimit = getFirstHeaderOrUnknown(response, REQUESTS_LIMIT);
+        var remainingRequests = getFirstHeaderOrUnknown(response, REMAINING_REQUESTS);
+        var requestReset = getFirstHeaderOrUnknown(response, REQUEST_RESET);
+        var tokensReset = getFirstHeaderOrUnknown(response, TOKENS_RESET);
+        var retryAfter = getFirstHeaderOrUnknown(response, RETRY_AFTER);
+
+        var usageMessage = Strings.format(
+            "Token limit [%s], remaining tokens [%s], tokens reset [%s]. "
+                + "Request limit [%s], remaining requests [%s], request reset [%s]. 
Retry after [%s]", + tokenLimit, + remainingTokens, + tokensReset, + requestLimit, + remainingRequests, + requestReset, + retryAfter + ); + + return RATE_LIMIT + ". " + usageMessage; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googlevertexai/GoogleVertexAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googlevertexai/GoogleVertexAiResponseHandler.java new file mode 100644 index 0000000000000..872bf51f3662a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googlevertexai/GoogleVertexAiResponseHandler.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.googlevertexai; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.googlevertexai.GoogleVertexAiErrorResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; + +public class GoogleVertexAiResponseHandler extends BaseResponseHandler { + + static final String GOOGLE_VERTEX_AI_UNAVAILABLE = "The Google Vertex AI service may be temporarily overloaded or down"; + + public GoogleVertexAiResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, GoogleVertexAiErrorResponseEntity::fromResponse); + } + + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + checkForFailureStatusCode(request, result); + checkForEmptyBody(throttlerManager, logger, request, result); + } + + void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + // handle error codes + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 503) { + throw new RetryException(true, buildError(GOOGLE_VERTEX_AI_UNAVAILABLE, request, result)); + } else if (statusCode > 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 404) { + throw new RetryException(false, buildError(resourceNotFoundError(request), request, result)); + } else if (statusCode == 403) { + throw new RetryException(false, buildError(PERMISSION_DENIED, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, 
buildError(UNSUCCESSFUL, request, result));
+        }
+    }
+
+    private static String resourceNotFoundError(Request request) {
+        return format("Resource not found at [%s]", request.getURI());
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java
index ef5fec24c3d59..642b76d775173 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java
@@ -20,9 +20,9 @@ public class HttpSettings {
     // These settings are default scope for testing
     static final Setting<ByteSizeValue> MAX_HTTP_RESPONSE_SIZE = Setting.byteSizeSetting(
         "xpack.inference.http.max_response_size",
-        new ByteSizeValue(10, ByteSizeUnit.MB), // default
+        new ByteSizeValue(50, ByteSizeUnit.MB), // default
         ByteSizeValue.ONE, // min
-        new ByteSizeValue(50, ByteSizeUnit.MB), // max
+        new ByteSizeValue(100, ByteSizeUnit.MB), // max
         Setting.Property.NodeScope,
         Setting.Property.Dynamic
     );
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java
index ffe10ffe3b6ae..dd45501564e4e 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java
@@ -17,6 +17,7 @@
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.common.SizeLimitInputStream;
 import org.elasticsearch.xpack.inference.external.http.HttpClient;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.request.Request;
@@ -26,12 +27,16 @@
 import java.net.UnknownHostException;
 import java.util.Objects;
 import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Supplier;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
 
 public class RetryingHttpSender implements RequestSender {
+
+    static final int MAX_RETRIES = 3;
+
     private final HttpClient httpClient;
     private final ThrottlerManager throttlerManager;
     private final RetrySettings retrySettings;
@@ -68,6 +73,7 @@ private class InternalRetrier extends RetryableAction<InferenceServiceResults> {
         private final Logger logger;
         private final HttpClientContext context;
         private final Supplier<Boolean> hasRequestCompletedFunction;
+        private final AtomicInteger retryCount;
 
         InternalRetrier(
             Logger logger,
@@ -91,10 +97,12 @@ private class InternalRetrier extends RetryableAction<InferenceServiceResults> {
             this.context = Objects.requireNonNull(context);
             this.responseHandler = Objects.requireNonNull(responseHandler);
             this.hasRequestCompletedFunction = Objects.requireNonNull(hasRequestCompletedFunction);
+            this.retryCount = new AtomicInteger(0);
         }
 
         @Override
         public void tryAction(ActionListener<InferenceServiceResults> listener) {
+            retryCount.incrementAndGet();
             // A timeout likely occurred so let's stop attempting to execute the request
             if (hasRequestCompletedFunction.get()) {
                 return;
@@ -140,10 +148,10 @@
private Exception transformIfRetryable(Exception e) {
                 RestStatus.BAD_REQUEST,
                 e
             );
-        }
-
-        if (e instanceof IOException) {
-            exceptionToReturn = new RetryException(true, e);
+        } else if (e instanceof SizeLimitInputStream.InputStreamTooLargeException) {
+            return e;
+        } else if (e instanceof IOException) {
+            return new RetryException(true, e);
         }
 
         return exceptionToReturn;
@@ -164,6 +172,10 @@ private Exception wrapWithElasticsearchException(Exception e, String inferenceEntityId) {
 
         @Override
         public boolean shouldRetry(Exception e) {
+            if (retryCount.get() >= MAX_RETRIES) {
+                return false;
+            }
+
             if (e instanceof Retryable retry) {
                 request = retry.rebuildRequest(request);
                 return retry.shouldRetry();
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java
new file mode 100644
index 0000000000000..7dd1a66db13e7
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.anthropic.AnthropicResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicChatCompletionRequest;
+import org.elasticsearch.xpack.inference.external.response.anthropic.AnthropicChatCompletionResponseEntity;
+import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+public class AnthropicCompletionRequestManager extends AnthropicRequestManager {
+
+    private static final Logger logger = LogManager.getLogger(AnthropicCompletionRequestManager.class);
+
+    private static final ResponseHandler HANDLER = createCompletionHandler();
+
+    public static AnthropicCompletionRequestManager of(AnthropicChatCompletionModel model, ThreadPool threadPool) {
+        return new AnthropicCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
+    }
+
+    private final AnthropicChatCompletionModel model;
+
+    private AnthropicCompletionRequestManager(AnthropicChatCompletionModel model, ThreadPool threadPool) {
+        super(threadPool, model);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public void execute(
+        @Nullable String query,
+        List<String> input,
+        RequestSender requestSender,
+        Supplier<Boolean> hasRequestCompletedFunction,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(input, model);
+
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER,
hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new AnthropicResponseHandler("anthropic completions", AnthropicChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicRequestManager.java new file mode 100644 index 0000000000000..a47910c0b37c8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicRequestManager.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.anthropic.AnthropicAccount; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicModel; + +import java.util.Objects; + +abstract class AnthropicRequestManager extends BaseRequestManager { + + protected AnthropicRequestManager(ThreadPool threadPool, AnthropicModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int accountHash, int modelIdHash) { + public static RateLimitGrouping of(AnthropicModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(AnthropicAccount.of(model).hashCode(), model.rateLimitServiceSettings().modelId().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index 002fa71b7fb5d..e295cf5cc43dd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -15,8 +15,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioChatCompletionRequest; -import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiErrorResponseEntity; import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiExternalResponseHandler; +import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; import org.elasticsearch.xpack.inference.external.response.azureaistudio.AzureAiStudioChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel; @@ -52,7 +52,7 @@ private static ResponseHandler createCompletionHandler() { return new AzureMistralOpenAiExternalResponseHandler( "azure ai studio completion", new AzureAiStudioChatCompletionResponseEntity(), - AzureMistralOpenAiErrorResponseEntity::fromResponse + 
ErrorMessageResponseEntity::fromResponse ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index ec5ab2fee6a57..f0f87402fb3a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -16,8 +16,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiErrorResponseEntity; import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiExternalResponseHandler; +import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; import org.elasticsearch.xpack.inference.external.response.azureaistudio.AzureAiStudioEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel; @@ -56,7 +56,7 @@ private static ResponseHandler createEmbeddingsHandler() { return new AzureMistralOpenAiExternalResponseHandler( "azure ai studio text embedding", new AzureAiStudioEmbeddingsResponseEntity(), - AzureMistralOpenAiErrorResponseEntity::fromResponse + ErrorMessageResponseEntity::fromResponse ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..c79e1a088ad5f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.googlevertexai.GoogleVertexAiResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.googlevertexai.GoogleVertexAiEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.googlevertexai.GoogleVertexAiEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class GoogleVertexAiEmbeddingsRequestManager extends GoogleVertexAiRequestManager { + + private static final Logger logger = LogManager.getLogger(GoogleVertexAiEmbeddingsRequestManager.class); + + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new GoogleVertexAiResponseHandler("google vertex ai embeddings", GoogleVertexAiEmbeddingsResponseEntity::fromResponse); + } + + private final GoogleVertexAiEmbeddingsModel model; + + private final Truncator truncator; + + public GoogleVertexAiEmbeddingsRequestManager(GoogleVertexAiEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + this.truncator = Objects.requireNonNull(truncator); + } + + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + var request = new GoogleVertexAiEmbeddingsRequest(truncator, truncatedInput, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRequestManager.java new file mode 100644 index 0000000000000..698bce3e337d6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRequestManager.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel; + +import java.util.Objects; + +public abstract class GoogleVertexAiRequestManager extends BaseRequestManager { + + GoogleVertexAiRequestManager(ThreadPool threadPool, GoogleVertexAiModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int modelIdHash) { + public static RateLimitGrouping of(GoogleVertexAiModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java index ab6a1bfb31372..1807712a31ac5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -16,8 +16,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.mistral.MistralEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiErrorResponseEntity; import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiExternalResponseHandler; +import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; import org.elasticsearch.xpack.inference.external.response.mistral.MistralEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; @@ -38,7 +38,7 @@ private static ResponseHandler createEmbeddingsHandler() { return new AzureMistralOpenAiExternalResponseHandler( "mistral text embedding", new MistralEmbeddingsResponseEntity(), - AzureMistralOpenAiErrorResponseEntity::fromResponse + ErrorMessageResponseEntity::fromResponse ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequest.java new file mode 100644 index 0000000000000..fa6bb31d0f401 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequest.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.anthropic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.anthropic.AnthropicAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils.createVersionHeader; + +public class AnthropicChatCompletionRequest implements Request { + + private final AnthropicAccount account; + private final List input; + private final AnthropicChatCompletionModel model; + + public AnthropicChatCompletionRequest(List input, AnthropicChatCompletionModel model) { + this.account = AnthropicAccount.of(model); + this.input = Objects.requireNonNull(input); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(account.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new AnthropicChatCompletionRequestEntity(input, model.getServiceSettings(), model.getTaskSettings())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(AnthropicRequestUtils.createAuthBearerHeader(account.apiKey())); + httpPost.setHeader(createVersionHeader()); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return account.uri(); + } + + @Override + public Request truncate() { + // No truncation for Anthropic completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Anthropic completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..4186ad0a722ce --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntity.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.anthropic; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class AnthropicChatCompletionRequestEntity implements ToXContentObject { + + private static final String MESSAGES_FIELD = "messages"; + private static final String MODEL_FIELD = "model"; + + private static final String ROLE_FIELD = "role"; + private static final String USER_VALUE = "user"; + private static final String CONTENT_FIELD = "content"; + private static final String MAX_TOKENS_FIELD = "max_tokens"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String TOP_P_FIELD = "top_p"; + private static final String TOP_K_FIELD = "top_k"; + + private final List messages; + private final AnthropicChatCompletionServiceSettings serviceSettings; + private final AnthropicChatCompletionTaskSettings taskSettings; + + public AnthropicChatCompletionRequestEntity( + List messages, + AnthropicChatCompletionServiceSettings serviceSettings, + AnthropicChatCompletionTaskSettings taskSettings + ) { + this.messages = Objects.requireNonNull(messages); + this.serviceSettings = Objects.requireNonNull(serviceSettings); + this.taskSettings = Objects.requireNonNull(taskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.startArray(MESSAGES_FIELD); + { + for (String message : messages) { + builder.startObject(); + + { + builder.field(ROLE_FIELD, USER_VALUE); + builder.field(CONTENT_FIELD, message); + } + + builder.endObject(); + } + } + builder.endArray(); + + builder.field(MODEL_FIELD, serviceSettings.modelId()); + builder.field(MAX_TOKENS_FIELD, taskSettings.maxTokens()); + + if (taskSettings.temperature() != null) { + builder.field(TEMPERATURE_FIELD, taskSettings.temperature()); + } + + if (taskSettings.topP() != null) { + builder.field(TOP_P_FIELD, taskSettings.topP()); + } + + if (taskSettings.topK() != null) { + builder.field(TOP_K_FIELD, taskSettings.topK()); + } + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicRequestUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicRequestUtils.java new file mode 100644 index 0000000000000..2e8ce980dcc08 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicRequestUtils.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.anthropic; + +import org.apache.http.Header; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.settings.SecureString; + +public class AnthropicRequestUtils { + public static final String HOST = "api.anthropic.com"; + public static final String API_VERSION_1 = "v1"; + public static final String MESSAGES_PATH = "messages"; + + public static final String ANTHROPIC_VERSION_2023_06_01 = "2023-06-01"; + + public static final String X_API_KEY = "x-api-key"; + public static final String VERSION = "anthropic-version"; + + public static Header createAuthBearerHeader(SecureString apiKey) { + return new BasicHeader(X_API_KEY, apiKey.toString()); + } + + public static Header createVersionHeader() { + return new BasicHeader(VERSION, ANTHROPIC_VERSION_2023_06_01); + } + + private AnthropicRequestUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequest.java new file mode 100644 index 0000000000000..c0e36baf2e98f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequest.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googlevertexai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class GoogleVertexAiEmbeddingsRequest implements GoogleVertexAiRequest { + + private final Truncator truncator; + + private final Truncator.TruncationResult truncationResult; + + private final GoogleVertexAiEmbeddingsModel model; + + public GoogleVertexAiEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, GoogleVertexAiEmbeddingsModel model) { + this.truncator = Objects.requireNonNull(truncator); + this.truncationResult = Objects.requireNonNull(input); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), model.getTaskSettings().autoTruncate())) + .getBytes(StandardCharsets.UTF_8) + ); + + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + decorateWithAuth(httpPost); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + public void decorateWithAuth(HttpPost httpPost) { + 
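+        // What happens in the call below (see GoogleVertexAiRequest.decorateWithBearerToken further down):
+        // the configured service account JSON is exchanged for a short-lived OAuth2 access token, which is
+        // attached to this request as an Authorization bearer header.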
GoogleVertexAiRequest.decorateWithBearerToken(httpPost, model.getSecretSettings()); + } + + Truncator truncator() { + return truncator; + } + + Truncator.TruncationResult truncationResult() { + return truncationResult; + } + + GoogleVertexAiEmbeddingsModel model() { + return model; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + + return new GoogleVertexAiEmbeddingsRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..2fae999599ba2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntity.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googlevertexai; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record GoogleVertexAiEmbeddingsRequestEntity(List inputs, @Nullable Boolean autoTruncation) implements ToXContentObject { + + private static final String INSTANCES_FIELD = "instances"; + private static final String CONTENT_FIELD = "content"; + private static final String PARAMETERS_FIELD = "parameters"; + private static final String AUTO_TRUNCATE_FIELD = "autoTruncate"; + + public GoogleVertexAiEmbeddingsRequestEntity { + Objects.requireNonNull(inputs); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(INSTANCES_FIELD); + + for (String input : inputs) { + builder.startObject(); + { + builder.field(CONTENT_FIELD, input); + } + builder.endObject(); + } + + builder.endArray(); + + if (autoTruncation != null) { + builder.startObject(PARAMETERS_FIELD); + { + builder.field(AUTO_TRUNCATE_FIELD, autoTruncation); + } + builder.endObject(); + } + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRequest.java new file mode 100644 index 0000000000000..69859ef3de642 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRequest.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googlevertexai; + +import com.google.auth.oauth2.GoogleCredentials; +import com.google.auth.oauth2.ServiceAccountCredentials; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.SpecialPermission; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public interface GoogleVertexAiRequest extends Request { + List AUTH_SCOPE = Collections.singletonList("https://www.googleapis.com/auth/cloud-platform"); + + static void decorateWithBearerToken(HttpPost httpPost, GoogleVertexAiSecretSettings secretSettings) { + SpecialPermission.check(); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + GoogleCredentials credentials = ServiceAccountCredentials.fromStream( + new ByteArrayInputStream(secretSettings.serviceAccountJson().toString().getBytes(StandardCharsets.UTF_8)) + ).createScoped(AUTH_SCOPE); + credentials.refreshIfExpired(); + httpPost.setHeader(createAuthBearerHeader(new SecureString(credentials.getAccessToken().getTokenValue().toCharArray()))); + + return null; + }); + } catch (Exception e) { + throw new ElasticsearchStatusException(e.getMessage(), RestStatus.FORBIDDEN, e); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiUtils.java new file mode 100644 index 0000000000000..8258679bc6dfe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiUtils.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googlevertexai; + +public final class GoogleVertexAiUtils { + + public static final String GOOGLE_VERTEX_AI_HOST_SUFFIX = "-aiplatform.googleapis.com"; + + public static final String V1 = "v1"; + + public static final String PROJECTS = "projects"; + + public static final String LOCATIONS = "locations"; + + public static final String PUBLISHERS = "publishers"; + + public static final String PUBLISHER_GOOGLE = "google"; + + public static final String MODELS = "models"; + + public static final String PREDICT = "predict"; + + private GoogleVertexAiUtils() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java index dfdb6712d5e45..e4e96ca644c7f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java @@ -116,7 +116,7 @@ public static boolean isContentTooLarge(HttpResult result) { } if (statusCode == 400) { - var errorEntity = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); + var errorEntity = ErrorMessageResponseEntity.fromResponse(result); return errorEntity != null && errorEntity.getErrorMessage().contains(CONTENT_TOO_LARGE_MESSAGE); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java similarity index 90% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiErrorResponseEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java index 83ea7801dfd58..dbf2b37955b22 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java @@ -31,10 +31,10 @@ * This currently covers error handling for Azure AI Studio, however this pattern * can be used to simplify and refactor handling for Azure OpenAI and OpenAI responses. 
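 *
 * At a minimum, these providers return an error object carrying a message field; an
 * illustrative sketch (not an exact payload from any one provider) of the smallest
 * body this entity can parse:
 *
 *     {
 *         "error": {
 *             "message": "some error message"
 *         }
 *     }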
*/ -public class AzureMistralOpenAiErrorResponseEntity implements ErrorMessage { +public class ErrorMessageResponseEntity implements ErrorMessage { protected String errorMessage; - public AzureMistralOpenAiErrorResponseEntity(String errorMessage) { + public ErrorMessageResponseEntity(String errorMessage) { this.errorMessage = errorMessage; } @@ -62,7 +62,7 @@ public static ErrorMessage fromResponse(HttpResult response) { if (error != null) { var message = (String) error.get("message"); if (message != null) { - return new AzureMistralOpenAiErrorResponseEntity(message); + return new ErrorMessageResponseEntity(message); } } } catch (Exception e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/anthropic/AnthropicChatCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/anthropic/AnthropicChatCompletionResponseEntity.java new file mode 100644 index 0000000000000..75b504cbd8102 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/anthropic/AnthropicChatCompletionResponseEntity.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.anthropic; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class AnthropicChatCompletionResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Anthropic chat completions response"; + + /** + * Parses the Anthropic chat completion response. + * For a request like: + * + *
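+     * (Each element of "inputs" is sent to Anthropic as a separate user-role message;
+     * see AnthropicChatCompletionRequestEntity above.)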
+     * <pre>
+     *     <code>
+     *         {
+     *             "inputs": ["Please summarize this text: some text"]
+     *         }
+     *     </code>
+     * </pre>
+     *
+     * The response would look like:
+     *
+     * <pre>
+     *     <code>
+     *  {
+     *      "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb",
+     *      "type": "message",
+     *      "role": "assistant",
+     *      "model": "claude-3-opus-20240229",
+     *      "content": [
+     *          {
+     *              "type": "text",
+     *              "text": "result"
+     *          }
+     *      ],
+     *      "stop_reason": "end_turn",
+     *      "stop_sequence": null,
+     *      "usage": {
+     *          "input_tokens": 16,
+     *          "output_tokens": 326
+     *      }
+     *  }
+     *     </code>
+     * </pre>
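+     *
+     * For the example above, parsing yields a ChatCompletionResults containing a single result
+     * with the text "result". Content objects whose type is not "text" (for example "tool_use")
+     * are ignored, as are entries without a text field.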
    + */ + + public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "content", FAILED_TO_FIND_FIELD_TEMPLATE); + + var completionResults = doParse(jsonParser); + + return new ChatCompletionResults(completionResults); + } + } + + private static List doParse(XContentParser parser) throws IOException { + var parsedResults = parseList(parser, (listParser) -> { + var parsedObject = TextObject.parse(parser); + // Anthropic also supports a tool_use type, we want to ignore those objects + if (parsedObject.type == null || parsedObject.type.equals("text") == false || parsedObject.text == null) { + return null; + } + + return new ChatCompletionResults.Result(parsedObject.text); + }); + + parsedResults.removeIf(Objects::isNull); + return parsedResults; + } + + private record TextObject(@Nullable String type, @Nullable String text) { + + private static final ParseField TEXT = new ParseField("text"); + private static final ParseField TYPE = new ParseField("type"); + private static final ObjectParser PARSER = new ObjectParser<>( + "anthropic_chat_completions_response", + true, + Builder::new + ); + + static { + PARSER.declareString(Builder::setText, TEXT); + PARSER.declareString(Builder::setType, TYPE); + } + + public static TextObject parse(XContentParser parser) throws IOException { + Builder builder = PARSER.apply(parser, null); + return builder.build(); + } + + private static final class Builder { + + private String type; + private String text; + + private Builder() {} + + public Builder setType(String type) { + this.type = type; + return this; + } + + public Builder setText(String text) { + this.text = text; + return this; + } + + public TextObject build() { + return new TextObject(type, text); + } + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..7205ea83d0a7a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsResponseEntity.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.googlevertexai; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class GoogleVertexAiEmbeddingsResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in Google Vertex AI embeddings response"; + + /** + * Parses the Google Vertex AI get embeddings response. + * For a request like: + * + *
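+     * (Each element of "inputs" becomes one entry in the "instances" array of the
+     * underlying Vertex AI predict request; see GoogleVertexAiEmbeddingsRequestEntity above.)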
+     * <pre>
+     *     <code>
+     *         {
+     *             "inputs": ["Embed this", "Embed this, too"]
+     *         }
+     *     </code>
+     * </pre>
+     *
+     * The response would look like:
+     *
+     * <pre>
+     *     <code>
+     *         {
+     *           "predictions": [
+     *              {
+     *                "embeddings": {
+     *                  "statistics": {
+     *                    "truncated": false,
+     *                    "token_count": 6
+     *                  },
+     *                  "values": [ ... ]
+     *                }
+     *              }
+     *           ]
+     *         }
+     *     </code>
+     * </pre>
+ */
+
+    public static InferenceTextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            moveToFirstToken(jsonParser);
+
+            XContentParser.Token token = jsonParser.currentToken();
+            ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+
+            positionParserAtTokenAfterField(jsonParser, "predictions", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding> embeddingList = parseList(
+                jsonParser,
+                GoogleVertexAiEmbeddingsResponseEntity::parseEmbeddingObject
+            );
+
+            return new InferenceTextEmbeddingFloatResults(embeddingList);
+        }
+    }
+
+    private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseEmbeddingObject(XContentParser parser)
+        throws IOException {
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+
+        positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+
+        positionParserAtTokenAfterField(parser, "values", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+        List<Float> embeddingValueList = parseList(parser, GoogleVertexAiEmbeddingsResponseEntity::parseEmbeddingList);
+
+        // parse and discard the rest of the two objects
+        consumeUntilObjectEnd(parser);
+        consumeUntilObjectEnd(parser);
+
+        return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValueList);
+    }
+
+    private static float parseEmbeddingList(XContentParser parser) throws IOException {
+        XContentParser.Token token = parser.currentToken();
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
+        return parser.floatValue();
+    }
+
+    private GoogleVertexAiEmbeddingsResponseEntity() {}
+}
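Note: the parser above drills predictions → embeddings → values and discards the statistics object along the way. A hedged, dependency-free sketch of the same traversal over an already-decoded response map (names here are illustrative, not part of this patch):

```java
import java.util.List;
import java.util.Map;

class VertexAiValuesExtractor {
    // Walks predictions -> embeddings -> values on a response that has already been decoded into maps and lists.
    @SuppressWarnings("unchecked")
    static List<List<Double>> allValues(Map<String, Object> response) {
        var predictions = (List<Map<String, Object>>) response.get("predictions");
        return predictions.stream()
            .map(prediction -> (Map<String, Object>) prediction.get("embeddings"))
            .map(embeddings -> (List<Double>) embeddings.get("values"))
            .toList();
    }

    public static void main(String[] args) {
        Map<String, Object> response = Map.of(
            "predictions",
            List.of(Map.of("embeddings", Map.of("statistics", Map.of("truncated", false, "token_count", 6), "values", List.of(0.1, 0.2))))
        );
        System.out.println(allValues(response)); // [[0.1, 0.2]]
    }
}
```

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiErrorResponseEntity.java
new file mode 100644
index 0000000000000..bf14d751db868
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiErrorResponseEntity.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.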
+ */ + +package org.elasticsearch.xpack.inference.external.response.googlevertexai; + +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage; + +import java.util.Map; + +public class GoogleVertexAiErrorResponseEntity implements ErrorMessage { + + private final String errorMessage; + + private GoogleVertexAiErrorResponseEntity(String errorMessage) { + this.errorMessage = errorMessage; + } + + @Override + public String getErrorMessage() { + return errorMessage; + } + + /** + * An example error response for invalid auth would look like + * + * { + * "error": { + * "code": 401, + * "message": "some error message", + * "status": "UNAUTHENTICATED", + * "details": [ + * { + * "@type": "type.googleapis.com/google.rpc.ErrorInfo", + * "reason": "CREDENTIALS_MISSING", + * "domain": "googleapis.com", + * "metadata": { + * "method": "google.cloud.aiplatform.v1.PredictionService.Predict", + * "service": "aiplatform.googleapis.com" + * } + * } + * ] + * } + * } + * + * + * @param response The error response + * @return An error entity if the response is JSON with the above structure + * or null if the response does not contain the `error.message` field + */ + @SuppressWarnings("unchecked") + public static GoogleVertexAiErrorResponseEntity fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var error = (Map) responseMap.get("error"); + if (error != null) { + var message = (String) error.get("message"); + if (message != null) { + return new GoogleVertexAiErrorResponseEntity(message); + } + } + } catch (Exception e) { + // swallow the error + } + + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java index cb72b4d02302b..1af79a69839ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceFields.java @@ -14,6 +14,7 @@ public final class ServiceFields { public static final String SIMILARITY = "similarity"; public static final String DIMENSIONS = "dimensions"; + // Typically we use this to define the maximum tokens for the input text (text being sent to an integration) public static final String MAX_INPUT_TOKENS = "max_input_tokens"; public static final String URL = "url"; public static final String MODEL_ID = "model_id"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index f9aca89969614..966cc029232b1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -233,11 +233,7 @@ public static String invalidSettingError(String settingName, String scope) { public static URI 
convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
         try {
-            if (url == null) {
-                return null;
-            }
-
-            return createUri(url);
+            return createOptionalUri(url);
         } catch (IllegalArgumentException cause) {
             validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope, cause.getMessage()));
             return null;
         }
@@ -355,6 +351,32 @@ public static String extractOptionalString(
         return optionalField;
     }
 
+    public static Integer extractRequiredPositiveInteger(
+        Map<String, Object> map,
+        String settingName,
+        String scope,
+        ValidationException validationException
+    ) {
+        int initialValidationErrorCount = validationException.validationErrors().size();
+        Integer field = ServiceUtils.removeAsType(map, settingName, Integer.class, validationException);
+
+        if (validationException.validationErrors().size() > initialValidationErrorCount) {
+            return null;
+        }
+
+        if (field == null) {
+            validationException.addValidationError(ServiceUtils.missingSettingErrorMsg(settingName, scope));
+        } else if (field <= 0) {
+            validationException.addValidationError(ServiceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, field));
+        }
+
+        if (validationException.validationErrors().size() > initialValidationErrorCount) {
+            return null;
+        }
+
+        return field;
+    }
+
     public static Integer extractOptionalPositiveInteger(
         Map<String, Object> map,
         String settingName,
@@ -625,5 +647,9 @@ public static SecureString apiKey(@Nullable ApiKeySecrets secrets) {
         return secrets == null ? new SecureString(new char[0]) : secrets.apiKey();
     }
 
+    public static <T> T nonNullOrDefault(@Nullable T requestValue, @Nullable T originalSettingsValue) {
+        return requestValue == null ? originalSettingsValue : requestValue;
+    }
+
     private ServiceUtils() {}
 }
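Note: a hedged sketch of how the new extractRequiredPositiveInteger helper reports failures. The call shape follows the signature added above; the surrounding setup is hypothetical and assumes the Elasticsearch classpath (org.elasticsearch.common.ValidationException and the inference plugin classes):

```java
// Hypothetical caller, shown only to illustrate the helper's error-accumulation contract.
Map<String, Object> settings = new HashMap<>(Map.of("max_tokens", -5));
ValidationException validationException = new ValidationException();

Integer maxTokens = ServiceUtils.extractRequiredPositiveInteger(
    settings,
    "max_tokens",
    ModelConfigurations.TASK_SETTINGS,
    validationException
);

// maxTokens is null here and validationException now carries a "must be a positive integer" error;
// a missing key would instead add a "missing setting" error.
assert maxTokens == null;
assert validationException.validationErrors().isEmpty() == false;
```

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicModel.java
new file mode 100644
index 0000000000000..88d4e3e0d0c82
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicModel.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.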
+ */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.anthropic.AnthropicActionVisitor; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +public abstract class AnthropicModel extends Model { + + private final AnthropicRateLimitServiceSettings rateLimitServiceSettings; + private final SecureString apiKey; + private final URI uri; + + public AnthropicModel( + ModelConfigurations configurations, + ModelSecrets secrets, + AnthropicRateLimitServiceSettings rateLimitServiceSettings, + CheckedSupplier uriSupplier, + @Nullable ApiKeySecrets apiKeySecrets + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + apiKey = ServiceUtils.apiKey(apiKeySecrets); + + try { + uri = uriSupplier.get(); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + Strings.format("Failed to construct %s URL", configurations.getService()), + RestStatus.BAD_REQUEST, + e + ); + } + } + + protected AnthropicModel(AnthropicModel model, TaskSettings taskSettings) { + super(model, taskSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + uri = model.getUri(); + } + + protected AnthropicModel(AnthropicModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + apiKey = model.apiKey(); + uri = model.getUri(); + } + + public URI getUri() { + return uri; + } + + public SecureString apiKey() { + return apiKey; + } + + public AnthropicRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } + + public abstract ExecutableAction accept(AnthropicActionVisitor creator, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicRateLimitServiceSettings.java new file mode 100644 index 0000000000000..1d452e6415bc9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicRateLimitServiceSettings.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +/** + * The service setting fields for anthropic that determine how to rate limit requests. + */ +public interface AnthropicRateLimitServiceSettings { + String modelId(); + + RateLimitSettings rateLimitSettings(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java new file mode 100644 index 0000000000000..d1db6f260351b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -0,0 +1,217 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.anthropic.AnthropicActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class AnthropicService extends SenderService { + public static final String NAME = "anthropic"; + + public AnthropicService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = 
removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + AnthropicModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static AnthropicModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static AnthropicModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case COMPLETION -> new AnthropicChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public AnthropicModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof AnthropicModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + AnthropicModel anthropicModel = (AnthropicModel) model; + var actionCreator = new AnthropicActionCreator(getSender(), getServiceComponents()); + + var action = anthropicModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Anthropic service does not support inference 
with query input"); + } + + @Override + protected void doChunkedInfer( + Model model, + @Nullable String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + throw new UnsupportedOperationException("Anthropic service does not support chunked inference"); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_ANTHROPIC_INTEGRATION_ADDED; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceFields.java new file mode 100644 index 0000000000000..f633df963a098 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceFields.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +public class AnthropicServiceFields { + + public static final String MAX_TOKENS = "max_tokens"; + public static final String TEMPERATURE_FIELD = "temperature"; + public static final String TOP_P_FIELD = "top_p"; + public static final String TOP_K_FIELD = "top_k"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionModel.java new file mode 100644 index 0000000000000..942cae8960daf --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionModel.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.anthropic.completion; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.anthropic.AnthropicActionVisitor; +import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +public class AnthropicChatCompletionModel extends AnthropicModel { + + public static AnthropicChatCompletionModel of(AnthropicChatCompletionModel model, Map taskSettings) { + if (taskSettings == null || taskSettings.isEmpty()) { + return model; + } + + var requestTaskSettings = AnthropicChatCompletionRequestTaskSettings.fromMap(taskSettings); + return new AnthropicChatCompletionModel( + model, + AnthropicChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings) + ); + } + + public AnthropicChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + AnthropicChatCompletionServiceSettings.fromMap(serviceSettings, context), + AnthropicChatCompletionTaskSettings.fromMap(taskSettings, context), + DefaultSecretSettings.fromMap(secrets) + ); + } + + AnthropicChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + AnthropicChatCompletionServiceSettings serviceSettings, + AnthropicChatCompletionTaskSettings taskSettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings, + AnthropicChatCompletionModel::buildDefaultUri, + secrets + ); + } + + // This should only be used for testing + AnthropicChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + String url, + AnthropicChatCompletionServiceSettings serviceSettings, + AnthropicChatCompletionTaskSettings taskSettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings, + () -> ServiceUtils.createUri(url), + secrets + ); + } + + private AnthropicChatCompletionModel(AnthropicChatCompletionModel originalModel, AnthropicChatCompletionTaskSettings taskSettings) { + super(originalModel, taskSettings); + } + + @Override + public AnthropicChatCompletionServiceSettings getServiceSettings() { + return (AnthropicChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public AnthropicChatCompletionTaskSettings getTaskSettings() { + return (AnthropicChatCompletionTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) 
super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(AnthropicActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } + + private static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(AnthropicRequestUtils.HOST) + .setPathSegments(AnthropicRequestUtils.API_VERSION_1, AnthropicRequestUtils.MESSAGES_PATH) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionRequestTaskSettings.java new file mode 100644 index 0000000000000..85fdc12685fde --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionRequestTaskSettings.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.MAX_TOKENS; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.TOP_P_FIELD; + +/** + * This class handles extracting Anthropic task settings from a request. The difference between this class and + * {@link AnthropicChatCompletionTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field + * is missing. This allows overriding persistent task settings. + * @param maxTokens the number of tokens to generate before stopping + */ +public record AnthropicChatCompletionRequestTaskSettings( + @Nullable Integer maxTokens, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer topK +) { + + public static final AnthropicChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AnthropicChatCompletionRequestTaskSettings( + null, + null, + null, + null + ); + + /** + * Extracts the task settings from a map. All settings are considered optional and the absence of a setting + * does not throw an error. 
+ * + * @param map the settings received from a request + * @return a {@link AnthropicChatCompletionRequestTaskSettings} + */ + public static AnthropicChatCompletionRequestTaskSettings fromMap(Map map) { + if (map.isEmpty()) { + return AnthropicChatCompletionRequestTaskSettings.EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + Integer maxTokens = extractOptionalPositiveInteger(map, MAX_TOKENS, ModelConfigurations.SERVICE_SETTINGS, validationException); + // At the time of writing the allowed values are -1, and range 0-1. I'm intentionally not validating the values here, we'll let + // Anthropic return an error when we send it instead. + Double temperature = removeAsType(map, TEMPERATURE_FIELD, Double.class); + Double topP = removeAsType(map, TOP_P_FIELD, Double.class); + Integer topK = removeAsType(map, TOP_K_FIELD, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AnthropicChatCompletionRequestTaskSettings(maxTokens, temperature, topP, topK); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..3a70a26a82387 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionServiceSettings.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +/** + * Defines the service settings for interacting with Anthropic's chat completion models. 
+ */
+public class AnthropicChatCompletionServiceSettings extends FilteredXContentObject
+    implements
+        ServiceSettings,
+        AnthropicRateLimitServiceSettings {
+
+    public static final String NAME = "anthropic_completion_service_settings";
+
+    // The rate limit for build tier 1 is 50 requests per minute
+    // Details are here https://docs.anthropic.com/en/api/rate-limits
+    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(50);
+
+    public static AnthropicChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+
+        String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            AnthropicService.NAME,
+            context
+        );
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AnthropicChatCompletionServiceSettings(modelId, rateLimitSettings);
+    }
+
+    private final String modelId;
+
+    private final RateLimitSettings rateLimitSettings;
+
+    public AnthropicChatCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) {
+        this.modelId = Objects.requireNonNull(modelId);
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    public AnthropicChatCompletionServiceSettings(StreamInput in) throws IOException {
+        this.modelId = in.readString();
+        rateLimitSettings = new RateLimitSettings(in);
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitSettings;
+    }
+
+    @Override
+    public String modelId() {
+        return modelId;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        toXContentFragmentOfExposedFields(builder, params);
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID, modelId);
+
+        rateLimitSettings.toXContent(builder, params);
+
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_COMPLETION_INFERENCE_SERVICE_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (this == object) return true;
+        if (object == null || getClass() != object.getClass()) return false;
+        AnthropicChatCompletionServiceSettings that = (AnthropicChatCompletionServiceSettings) object;
+        return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelId, rateLimitSettings);
+    }
+}
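Note: the default of 50 requests per minute mirrors the Anthropic build tier 1 limit referenced above. A minimal fixed-window sketch of that kind of per-minute limiting (illustrative only; this is not the RateLimitSettings implementation):

```java
import java.util.concurrent.TimeUnit;

class MinuteRateLimiterSketch {
    private final int limitPerMinute;
    private long windowStart = System.nanoTime();
    private int used = 0;

    MinuteRateLimiterSketch(int limitPerMinute) {
        this.limitPerMinute = limitPerMinute;
    }

    // Grants up to limitPerMinute requests per fixed one-minute window, then denies until the window rolls over.
    synchronized boolean tryAcquire() {
        long now = System.nanoTime();
        if (now - windowStart >= TimeUnit.MINUTES.toNanos(1)) {
            windowStart = now;
            used = 0;
        }
        if (used < limitPerMinute) {
            used++;
            return true;
        }
        return false;
    }

    public static void main(String[] args) {
        var limiter = new MinuteRateLimiterSketch(50);
        for (int i = 0; i < 52; i++) {
            System.out.println(i + " -> " + limiter.tryAcquire()); // true for the first 50, then false
        }
    }
}
```

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java
new file mode 100644
index 0000000000000..a1457dda64e40
--- /dev/null
+++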
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettings.java @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.nonNullOrDefault; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.MAX_TOKENS; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields.TOP_P_FIELD; + +public class AnthropicChatCompletionTaskSettings implements TaskSettings { + + public static final String NAME = "anthropic_completion_task_settings"; + + public static AnthropicChatCompletionTaskSettings fromMap(Map map, ConfigurationParseContext context) { + return switch (context) { + case REQUEST -> fromRequestMap(map); + case PERSISTENT -> fromPersistedMap(map); + }; + } + + private static AnthropicChatCompletionTaskSettings fromRequestMap(Map map) { + ValidationException validationException = new ValidationException(); + + var commonFields = fromMap(map, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AnthropicChatCompletionTaskSettings(commonFields); + } + + private static AnthropicChatCompletionTaskSettings fromPersistedMap(Map map) { + var commonFields = fromMap(map, new ValidationException()); + + return new AnthropicChatCompletionTaskSettings(commonFields); + } + + private record CommonFields(int maxTokens, Double temperature, Double topP, Integer topK) {} + + private static CommonFields fromMap(Map map, ValidationException validationException) { + Integer maxTokens = extractRequiredPositiveInteger(map, MAX_TOKENS, ModelConfigurations.TASK_SETTINGS, validationException); + + // At the time of writing the allowed values for the temperature field are -1, and range 0-1. + // I'm intentionally not validating the values here, we'll let Anthropic return an error when we send it instead. 
+        Double temperature = removeAsType(map, TEMPERATURE_FIELD, Double.class);
+
+        // I'm intentionally not validating these so that Anthropic will return an error if they aren't in the correct range
+        Double topP = removeAsType(map, TOP_P_FIELD, Double.class);
+        Integer topK = removeAsType(map, TOP_K_FIELD, Integer.class);
+
+        return new CommonFields(Objects.requireNonNullElse(maxTokens, -1), temperature, topP, topK);
+    }
+
+    public static AnthropicChatCompletionTaskSettings of(
+        AnthropicChatCompletionTaskSettings originalSettings,
+        AnthropicChatCompletionRequestTaskSettings requestSettings
+    ) {
+        return new AnthropicChatCompletionTaskSettings(
+            Objects.requireNonNullElse(requestSettings.maxTokens(), originalSettings.maxTokens),
+            nonNullOrDefault(requestSettings.temperature(), originalSettings.temperature),
+            nonNullOrDefault(requestSettings.topP(), originalSettings.topP),
+            nonNullOrDefault(requestSettings.topK(), originalSettings.topK)
+        );
+    }
+
+    private final int maxTokens;
+    private final Double temperature;
+    private final Double topP;
+    private final Integer topK;
+
+    public AnthropicChatCompletionTaskSettings(int maxTokens, @Nullable Double temperature, @Nullable Double topP, @Nullable Integer topK) {
+        this.maxTokens = maxTokens;
+        this.temperature = temperature;
+        this.topP = topP;
+        this.topK = topK;
+    }
+
+    public AnthropicChatCompletionTaskSettings(StreamInput in) throws IOException {
+        this.maxTokens = in.readVInt();
+        this.temperature = in.readOptionalDouble();
+        this.topP = in.readOptionalDouble();
+        this.topK = in.readOptionalInt();
+    }
+
+    private AnthropicChatCompletionTaskSettings(CommonFields commonFields) {
+        this(commonFields.maxTokens, commonFields.temperature, commonFields.topP, commonFields.topK);
+    }
+
+    public int maxTokens() {
+        return maxTokens;
+    }
+
+    public Double temperature() {
+        return temperature;
+    }
+
+    public Double topP() {
+        return topP;
+    }
+
+    public Integer topK() {
+        return topK;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        builder.field(MAX_TOKENS, maxTokens);
+
+        if (temperature != null) {
+            builder.field(TEMPERATURE_FIELD, temperature);
+        }
+
+        if (topP != null) {
+            builder.field(TOP_P_FIELD, topP);
+        }
+
+        if (topK != null) {
+            builder.field(TOP_K_FIELD, topK);
+        }
+
+        builder.endObject();
+
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_ANTHROPIC_INTEGRATION_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeVInt(maxTokens);
+        out.writeOptionalDouble(temperature);
+        out.writeOptionalDouble(topP);
+        out.writeOptionalInt(topK);
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (this == object) return true;
+        if (object == null || getClass() != object.getClass()) return false;
+        AnthropicChatCompletionTaskSettings that = (AnthropicChatCompletionTaskSettings) object;
+        return Objects.equals(maxTokens, that.maxTokens)
+            && Objects.equals(temperature, that.temperature)
+            && Objects.equals(topP, that.topP)
+            && Objects.equals(topK, that.topK);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(maxTokens, temperature, topP, topK);
+    }
+}
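Note: `of(originalSettings, requestSettings)` above merges persisted task settings with per-request overrides, with the request value winning whenever it is non-null. A self-contained sketch of that merge pattern (record and method names are illustrative, not part of this patch):

```java
class SettingsMergeSketch {
    record Settings(Integer maxTokens, Double temperature) {}

    // Request values win when present; otherwise the persisted value is kept.
    static <T> T nonNullOrDefault(T requestValue, T originalValue) {
        return requestValue == null ? originalValue : requestValue;
    }

    static Settings merge(Settings persisted, Settings request) {
        return new Settings(
            nonNullOrDefault(request.maxTokens(), persisted.maxTokens()),
            nonNullOrDefault(request.temperature(), persisted.temperature())
        );
    }

    public static void main(String[] args) {
        var persisted = new Settings(128, 0.5);
        var request = new Settings(null, 0.9);
        System.out.println(merge(persisted, request)); // Settings[maxTokens=128, temperature=0.9]
    }
}
```

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java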
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java new file mode 100644 index 0000000000000..17e6ec2152e7e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor; + +import java.util.Map; +import java.util.Objects; + +public abstract class GoogleVertexAiModel extends Model { + + private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings; + + public GoogleVertexAiModel( + ModelConfigurations configurations, + ModelSecrets secrets, + GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + } + + public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + + public abstract ExecutableAction accept(GoogleVertexAiActionVisitor creator, Map taskSettings); + + public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiRateLimitServiceSettings.java new file mode 100644 index 0000000000000..bd1373ae3ab8f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiRateLimitServiceSettings.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface GoogleVertexAiRateLimitServiceSettings { + + String modelId(); + + RateLimitSettings rateLimitSettings(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java new file mode 100644 index 0000000000000..57c8d61f9f9a5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; + +public class GoogleVertexAiSecretSettings implements SecretSettings { + + public static final String NAME = "google_vertex_ai_secret_settings"; + + public static final String SERVICE_ACCOUNT_JSON = "service_account_json"; + + private final SecureString serviceAccountJson; + + public static GoogleVertexAiSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + SecureString secureServiceAccountJson = extractRequiredSecureString( + map, + SERVICE_ACCOUNT_JSON, + ModelSecrets.SECRET_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleVertexAiSecretSettings(secureServiceAccountJson); + } + + public GoogleVertexAiSecretSettings(SecureString serviceAccountJson) { + this.serviceAccountJson = Objects.requireNonNull(serviceAccountJson); + } + + public GoogleVertexAiSecretSettings(StreamInput in) throws IOException { + this(in.readSecureString()); + } + + public SecureString serviceAccountJson() { + return serviceAccountJson; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(SERVICE_ACCOUNT_JSON, serviceAccountJson.toString()); + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_VERTEX_AI_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeSecureString(serviceAccountJson); + } + 
+ @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleVertexAiSecretSettings that = (GoogleVertexAiSecretSettings) object; + return Objects.equals(serviceAccountJson, that.serviceAccountJson); + } + + @Override + public int hashCode() { + return Objects.hash(serviceAccountJson); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java new file mode 100644 index 0000000000000..4708d5b7d5300 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -0,0 +1,273 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; + +public class 
GoogleVertexAiService extends SenderService { + + public static final String NAME = "googlevertexai"; + + public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parseModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + GoogleVertexAiModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parseModelListener.onResponse(model); + } catch (Exception e) { + parseModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_VERTEX_AI_EMBEDDINGS_ADDED; + } + + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof GoogleVertexAiEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof GoogleVertexAiModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model; + + var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); + + var action = googleVertexAiModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + 
ActionListener<InferenceServiceResults> listener
+    ) {
+        throw new UnsupportedOperationException("Query input not supported for Google Vertex AI");
+    }
+
+    @Override
+    protected void doChunkedInfer(
+        Model model,
+        String query,
+        List<String> input,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        ChunkingOptions chunkingOptions,
+        TimeValue timeout,
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
+    ) {
+        GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;
+        var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
+
+        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
+            .batchRequestsWithListeners(listener);
+        for (var request : batchedRequests) {
+            var action = googleVertexAiModel.accept(actionCreator, taskSettings);
+            action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
+        }
+    }
+
+    private GoogleVertexAiEmbeddingsModel updateModelWithEmbeddingDetails(GoogleVertexAiEmbeddingsModel model, int embeddingSize) {
+        if (model.getServiceSettings().dimensionsSetByUser()
+            && model.getServiceSettings().dimensions() != null
+            && model.getServiceSettings().dimensions() != embeddingSize) {
+            throw new ElasticsearchStatusException(
+                Strings.format(
+                    "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
+                        + "Please recreate the [%s] configuration with the correct dimensions",
+                    embeddingSize,
+                    model.getServiceSettings().dimensions(),
+                    model.getConfigurations().getInferenceEntityId()
+                ),
+                RestStatus.BAD_REQUEST
+            );
+        }
+
+        GoogleVertexAiEmbeddingsServiceSettings serviceSettings = new GoogleVertexAiEmbeddingsServiceSettings(
+            model.getServiceSettings().location(),
+            model.getServiceSettings().projectId(),
+            model.getServiceSettings().modelId(),
+            model.getServiceSettings().dimensionsSetByUser(),
+            model.getServiceSettings().maxInputTokens(),
+            embeddingSize,
+            model.getServiceSettings().similarity(),
+            model.getServiceSettings().rateLimitSettings()
+        );
+
+        return new GoogleVertexAiEmbeddingsModel(model, serviceSettings);
+    }
+
+    private static GoogleVertexAiModel createModelFromPersistent(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        Map<String, Object> secretSettings,
+        String failureMessage
+    ) {
+        return createModel(
+            inferenceEntityId,
+            taskType,
+            serviceSettings,
+            taskSettings,
+            secretSettings,
+            failureMessage,
+            ConfigurationParseContext.PERSISTENT
+        );
+    }
+
+    private static GoogleVertexAiModel createModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secretSettings,
+        String failureMessage,
+        ConfigurationParseContext context
+    ) {
+        return switch (taskType) {
+            case TEXT_EMBEDDING -> new GoogleVertexAiEmbeddingsModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                context
+            );
+            default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+        };
+    }
+}
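Note: doChunkedInfer above fans each batch out as its own request via EmbeddingRequestChunker. A dependency-free sketch of slicing inputs into batches of at most EMBEDDING_MAX_BATCH_SIZE (the chunker itself does more than this; names here are illustrative):

```java
import java.util.ArrayList;
import java.util.List;

class BatchingSketch {
    // Splits inputs into consecutive groups of at most batchSize, mirroring the chunked-infer fan-out.
    static List<List<String>> partition(List<String> inputs, int batchSize) {
        List<List<String>> batches = new ArrayList<>();
        for (int i = 0; i < inputs.size(); i += batchSize) {
            batches.add(inputs.subList(i, Math.min(i + batchSize, inputs.size())));
        }
        return batches;
    }

    public static void main(String[] args) {
        var inputs = List.of("a", "b", "c", "d", "e", "f", "g");
        System.out.println(partition(inputs, 5)); // [[a, b, c, d, e], [f, g]]
    }
}
```

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java
new file mode 100644
index 0000000000000..c669155a6cf2c
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java
@@ -0,0 +1,24 @@
+/*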
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java new file mode 100644 index 0000000000000..c669155a6cf2c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceFields.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai; + +public class GoogleVertexAiServiceFields { + + public static final String LOCATION = "location"; + + public static final String PROJECT_ID = "project_id"; + + /** + * In `us-central1` the max input size is `250`, but in every other region it's `5` according + * to these docs: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings. + * + * Therefore, being conservative and setting it to `5`. + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 5; + +}
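Because EMBEDDING_MAX_BATCH_SIZE caps each upstream call at five inputs, the doChunkedInfer method earlier in this diff fans one chunked request out into several predict calls. A minimal sketch of the batch-count arithmetic only (the actual splitting is done by EmbeddingRequestChunker):

// 12 inputs with a max batch size of 5 yield 3 upstream requests of sizes 5, 5 and 2
int inputCount = 12;
int maxBatchSize = 5; // GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE
int requestCount = (inputCount + maxBatchSize - 1) / maxBatchSize; // ceiling division = 3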
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java new file mode 100644 index 0000000000000..eb49e3f182a5e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor; +import org.elasticsearch.xpack.inference.external.request.googlevertexai.GoogleVertexAiUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel { + + private URI uri; + + public GoogleVertexAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map<String, Object> serviceSettings, + Map<String, Object> taskSettings, + Map<String, Object> secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + GoogleVertexAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), + GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettings), + GoogleVertexAiSecretSettings.fromMap(secrets) + ); + } + + public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, GoogleVertexAiEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + // Should only be used directly for testing + GoogleVertexAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + GoogleVertexAiEmbeddingsServiceSettings serviceSettings, + GoogleVertexAiEmbeddingsTaskSettings taskSettings, + @Nullable GoogleVertexAiSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + // Should only be used directly for testing + protected GoogleVertexAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + String uri, + GoogleVertexAiEmbeddingsServiceSettings serviceSettings, + GoogleVertexAiEmbeddingsTaskSettings taskSettings, + @Nullable GoogleVertexAiSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = new URI(uri); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + public GoogleVertexAiEmbeddingsServiceSettings getServiceSettings() { + return (GoogleVertexAiEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public GoogleVertexAiEmbeddingsTaskSettings getTaskSettings() { + return (GoogleVertexAiEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public GoogleVertexAiSecretSettings getSecretSettings() { + return (GoogleVertexAiSecretSettings) super.getSecretSettings(); + } + + public URI uri() { + return uri; + } + + @Override + public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) { + return visitor.create(this, taskSettings); + } + + public static URI buildUri(String location, String projectId, String modelId) throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX)) + .setPathSegments( + GoogleVertexAiUtils.V1, + GoogleVertexAiUtils.PROJECTS, + projectId, + GoogleVertexAiUtils.LOCATIONS, + location, + GoogleVertexAiUtils.PUBLISHERS, + GoogleVertexAiUtils.PUBLISHER_GOOGLE, + GoogleVertexAiUtils.MODELS, + format("%s:%s", modelId, GoogleVertexAiUtils.PREDICT) + ) + .build(); + } +}
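As a sanity check of buildUri above, and assuming GOOGLE_VERTEX_AI_HOST_SUFFIX resolves to "-aiplatform.googleapis.com" (the GoogleVertexAiUtils constants are not part of this excerpt), a hypothetical location/project/model combination maps to the regional predict endpoint like so:

// buildUri("us-central1", "my-gcp-project", "text-embedding-004") would yield:
// https://us-central1-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/publishers/google/models/text-embedding-004:predict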
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettings.java new file mode 100644 index 0000000000000..14a67a64377e2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettings.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; + +public record GoogleVertexAiEmbeddingsRequestTaskSettings(@Nullable Boolean autoTruncate) { + + public static final GoogleVertexAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsRequestTaskSettings(null); + + public static GoogleVertexAiEmbeddingsRequestTaskSettings fromMap(Map<String, Object> map) { + if (map.isEmpty()) { + return GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + Boolean autoTruncate = extractOptionalBoolean(map, GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..5f037f4530999 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java @@ -0,0 +1,275 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; +import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; + +public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + GoogleVertexAiRateLimitServiceSettings { + + public static final String NAME = "google_vertex_ai_embeddings_service_settings"; + + public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + // See online prediction requests per minute: https://cloud.google.com/vertex-ai/docs/quotas. 
+ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(30_000); + + public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String location = extractRequiredString(map, LOCATION, ModelConfigurations.SERVICE_SETTINGS, validationException); + String projectId = extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + GoogleVertexAiService.NAME, + context + ); + + Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); + + switch (context) { + case REQUEST -> { + if (dimensionsSetByUser != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + dimensionsSetByUser = dims != null; + } + case PERSISTENT -> { + if (dimensionsSetByUser == null) { + validationException.addValidationError( + ServiceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + } + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleVertexAiEmbeddingsServiceSettings( + location, + projectId, + model, + dimensionsSetByUser, + maxInputTokens, + dims, + similarityMeasure, + rateLimitSettings + ); + } + + private final String location; + + private final String projectId; + + private final String modelId; + + private final Integer dims; + + private final SimilarityMeasure similarity; + private final Integer maxInputTokens; + + private final RateLimitSettings rateLimitSettings; + + private final Boolean dimensionsSetByUser; + + public GoogleVertexAiEmbeddingsServiceSettings( + String location, + String projectId, + String modelId, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable Integer dims, + @Nullable SimilarityMeasure similarity, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.location = location; + this.projectId = projectId; + this.modelId = modelId; + this.dimensionsSetByUser = dimensionsSetByUser; + this.maxInputTokens = maxInputTokens; + this.dims = dims; + this.similarity = Objects.requireNonNullElse(similarity, SimilarityMeasure.DOT_PRODUCT); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public GoogleVertexAiEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.location = in.readString(); + this.projectId = in.readString(); + this.modelId = in.readString(); + this.dimensionsSetByUser = in.readBoolean(); + this.maxInputTokens = in.readOptionalVInt(); + this.dims = in.readOptionalVInt(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.rateLimitSettings = new 
RateLimitSettings(in); + } + + public String projectId() { + return projectId; + } + + public String location() { + return location; + } + + @Override + public String modelId() { + return modelId; + } + + public Boolean dimensionsSetByUser() { + return dimensionsSetByUser; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public Integer dimensions() { + return dims; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_VERTEX_AI_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(location); + out.writeString(projectId); + out.writeString(modelId); + out.writeBoolean(dimensionsSetByUser); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalVInt(dims); + out.writeOptionalEnum(similarity); + rateLimitSettings.writeTo(out); + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(LOCATION, location); + builder.field(PROJECT_ID, projectId); + builder.field(MODEL_ID, modelId); + + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + + if (dims != null) { + builder.field(DIMENSIONS, dims); + } + + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleVertexAiEmbeddingsServiceSettings that = (GoogleVertexAiEmbeddingsServiceSettings) object; + return Objects.equals(location, that.location) + && Objects.equals(projectId, that.projectId) + && Objects.equals(modelId, that.modelId) + && Objects.equals(dims, that.dims) + && similarity == that.similarity + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser); + } + + @Override + public int hashCode() { + return Objects.hash(location, projectId, modelId, dims, similarity, maxInputTokens, rateLimitSettings, dimensionsSetByUser); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..6de44fe470a2f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java @@ -0,0 +1,105 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; + +public class GoogleVertexAiEmbeddingsTaskSettings implements TaskSettings { + + public static final String NAME = "google_vertex_ai_embeddings_task_settings"; + + public static final String AUTO_TRUNCATE = "auto_truncate"; + + public static final GoogleVertexAiEmbeddingsTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsTaskSettings( + (Boolean) null + ); + + public static GoogleVertexAiEmbeddingsTaskSettings fromMap(Map<String, Object> map) { + ValidationException validationException = new ValidationException(); + + Boolean autoTruncate = extractOptionalBoolean(map, AUTO_TRUNCATE, validationException); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate); + } + + public static GoogleVertexAiEmbeddingsTaskSettings of( + GoogleVertexAiEmbeddingsTaskSettings originalSettings, + GoogleVertexAiEmbeddingsRequestTaskSettings requestSettings + ) { + var autoTruncate = requestSettings.autoTruncate() == null ? 
originalSettings.autoTruncate : requestSettings.autoTruncate(); + return new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate); + } + + private final Boolean autoTruncate; + + public GoogleVertexAiEmbeddingsTaskSettings(@Nullable Boolean autoTruncate) { + this.autoTruncate = autoTruncate; + } + + public GoogleVertexAiEmbeddingsTaskSettings(StreamInput in) throws IOException { + this.autoTruncate = in.readOptionalBoolean(); + } + + public Boolean autoTruncate() { + return autoTruncate; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_VERTEX_AI_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalBoolean(this.autoTruncate); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (autoTruncate != null) { + builder.field(AUTO_TRUNCATE, autoTruncate); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleVertexAiEmbeddingsTaskSettings that = (GoogleVertexAiEmbeddingsTaskSettings) object; + return Objects.equals(autoTruncate, that.autoTruncate); + } + + @Override + public int hashCode() { + return Objects.hash(autoTruncate); + } +}
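The of(...) merge above gives request-level task settings precedence over the persisted ones, falling back to the stored value when the request leaves auto_truncate unset. A minimal illustration of those semantics:

// request override wins over the persisted value
var persisted = new GoogleVertexAiEmbeddingsTaskSettings(false);
var override = new GoogleVertexAiEmbeddingsRequestTaskSettings(true);
assert Boolean.TRUE.equals(GoogleVertexAiEmbeddingsTaskSettings.of(persisted, override).autoTruncate());

// no override supplied: the persisted value survives
var noOverride = GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS;
assert Boolean.FALSE.equals(GoogleVertexAiEmbeddingsTaskSettings.of(persisted, noOverride).autoTruncate());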
diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..f21a46521a7f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +grant { + // required by: com.google.api.client.json.JsonParser#parseValue + permission java.lang.RuntimePermission "accessDeclaredMembers"; + // required by: com.google.api.client.json.GenericJson#<init> + permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + // required to add google certs to the gcs client truststore + permission java.lang.RuntimePermission "setFactory"; + + // gcs client opens socket connections to access the repository + permission java.net.SocketPermission "*", "connect"; +}; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index a352116278e7a..4545327b62272 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -16,8 +16,10 @@ import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; @@ -27,8 +29,10 @@ import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.hamcrest.Matchers; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -42,6 +46,7 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -161,4 +166,51 @@ public static SimilarityMeasure randomSimilarityMeasure() { } public record PersistedConfig(Map<String, Object> config, Map<String, Object> secrets) {} + + public static PersistedConfig getPersistedConfigMap( + Map<String, Object> serviceSettings, + Map<String, Object> taskSettings, + Map<String, Object> secretSettings + ) { + + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + public static PersistedConfig getPersistedConfigMap(Map<String, Object> serviceSettings, Map<String, Object> taskSettings) { + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + public static Map<String, Object> getRequestConfigMap( + Map<String, Object> serviceSettings, + Map<String, Object> taskSettings, + Map<String, Object> secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + public static Map<String, Object> buildExpectationCompletions(List<String> completions) { + return Map.of( + 
ChatCompletionResults.COMPLETION, + completions.stream().map(completion -> Map.of(ChatCompletionResults.Result.RESULT, completion)).collect(Collectors.toList()) + ); + } + + public static ActionListener<Model> getModelListenerForException(Class<?> exceptionClass, String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + assertThat(e, Matchers.instanceOf(exceptionClass)); + assertThat(e.getMessage(), is(expectedMessage)); + }); + } +}
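Note the asymmetry between the helpers above: getPersistedConfigMap keeps secrets under their own top-level key, while getRequestConfigMap folds the secret settings into service_settings, mirroring how API keys are supplied inline when an endpoint is created. A hypothetical call from a service test (key names are illustrative only):

var requestConfig = Utils.getRequestConfigMap(
    Map.<String, Object>of("model_id", "my-model"),  // service settings
    Map.<String, Object>of("auto_truncate", true),   // task settings
    Map.<String, Object>of("api_key", "secret")      // merged into service_settings
);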
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java new file mode 100644 index 0000000000000..a3114300c5ddc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java @@ -0,0 +1,199 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.anthropic; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettingsTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class AnthropicActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_ChatCompletionModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "San Francisco has a cool-summer Mediterranean climate." + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AnthropicChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model", 0); + var actionCreator = new AnthropicActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 2.0, -3.0, 3); + var action = actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("San Francisco has a cool-summer Mediterranean climate.")))); + assertThat(webServer.requests(), hasSize(1)); + + var request = webServer.requests().get(0); + + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(request.getHeader(AnthropicRequestUtils.X_API_KEY), equalTo("secret")); + assertThat(request.getHeader(AnthropicRequestUtils.VERSION), equalTo(AnthropicRequestUtils.ANTHROPIC_VERSION_2023_06_01)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(6)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("max_tokens"), is(1)); + assertThat(requestMap.get("temperature"), is(2.0)); + assertThat(requestMap.get("top_p"), is(-3.0)); + assertThat(requestMap.get("top_k"), is(3)); + } + } + + public void testCreate_ChatCompletionModel_FailsFromInvalidResponseFormat() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content_does_not_exist": [ + { + "type": "text", + "text": "San Francisco has a cool-summer Mediterranean climate." + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AnthropicChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model", 0); + var actionCreator = new AnthropicActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, null, null, null); + var action = actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is(format("Failed to send Anthropic chat completions request to [%s]", getUrl(webServer))) + ); + assertThat( + thrownException.getCause().getMessage(), + is("Failed to find required field [content] in Anthropic chat completions response") + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(AnthropicRequestUtils.X_API_KEY), equalTo("secret")); + assertThat( + webServer.requests().get(0).getHeader(AnthropicRequestUtils.VERSION), + equalTo(AnthropicRequestUtils.ANTHROPIC_VERSION_2023_06_01) + ); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(3)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("max_tokens"), is(1)); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java new file mode 100644 index 0000000000000..ffa0ac307490e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java @@ -0,0 +1,242 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.anthropic; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class AnthropicChatCompletionActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + 
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "San Francisco has a cool-summer Mediterranean climate." + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), "secret", "model", 1, sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("San Francisco has a cool-summer Mediterranean climate.")))); + assertThat(webServer.requests(), hasSize(1)); + + MockRequest request = webServer.requests().get(0); + + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(request.getHeader(AnthropicRequestUtils.X_API_KEY), equalTo("secret")); + assertThat(request.getHeader(AnthropicRequestUtils.VERSION), equalTo(AnthropicRequestUtils.ANTHROPIC_VERSION_2023_06_01)); + + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap.size(), is(3)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("max_tokens"), is(1)); + } + } + + public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException { + try (var sender = mock(Sender.class)) { + var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("^^", "secret", "model", 1, sender)); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", 1, sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", 1, sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> 
listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Anthropic chat completions request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", 1, sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Anthropic chat completions request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "San Francisco has a cool-summer Mediterranean climate." + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), "secret", "model", 1, sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Anthropic completions only accepts 1 input")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + private AnthropicChatCompletionAction createAction(String url, String apiKey, String modelName, int maxTokens, Sender sender) { + var model = AnthropicChatCompletionModelTests.createChatCompletionModel(url, apiKey, modelName, maxTokens); + + return new AnthropicChatCompletionAction(sender, model, createWithEmptySettings(threadPool)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsActionTests.java new file mode 100644 index 0000000000000..17a2c29e195f1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiEmbeddingsActionTests.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.googlevertexai; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class GoogleVertexAiEmbeddingsActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + // Successful case tested via end-to-end notebook tests in AppEx repo + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "location", "projectId", "model", sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + 
}).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "location", "projectId", "model", sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google Vertex AI embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "location", "projectId", "model", sender); + + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google Vertex AI embeddings request to [%s]", getUrl(webServer))) + ); + } + + private GoogleVertexAiEmbeddingsAction createAction(String url, String location, String projectId, String modelName, Sender sender) { + var model = createModel(location, projectId, modelName, url, "{}"); + + return new GoogleVertexAiEmbeddingsAction(sender, model, createWithEmptySettings(threadPool)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandlerTests.java new file mode 100644 index 0000000000000..0b9390f293ff9 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandlerTests.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.anthropic; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AnthropicResponseHandlerTests extends ESTestCase { + + public void testCheckForFailureStatusCode_DoesNotThrowFor200() { + callCheckForFailureStatusCode(200, "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor529_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(529, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString( + "Received an Anthropic server is temporarily overloaded status code for request from inference entity id [id] status [529]" + ) + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor505_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(505, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [505]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString( + "Received a rate limit status code. Token limit [unknown], remaining tokens [unknown], tokens reset [unknown]. " + + "Request limit [unknown], remaining requests [unknown], request reset [unknown]. 
" + + "Retry after [unknown] for request from inference entity id [id] status [429]" + ) + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_ShouldRetry_RetrievesFieldsFromHeaders() { + int statusCode = 429; + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + var httpResult = new HttpResult(response, new byte[] {}); + + when(response.getFirstHeader(AnthropicResponseHandler.REQUESTS_LIMIT)).thenReturn( + new BasicHeader(AnthropicResponseHandler.REQUESTS_LIMIT, "3000") + ); + when(response.getFirstHeader(AnthropicResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AnthropicResponseHandler.REMAINING_REQUESTS, "2999") + ); + when(response.getFirstHeader(AnthropicResponseHandler.TOKENS_LIMIT)).thenReturn( + new BasicHeader(AnthropicResponseHandler.TOKENS_LIMIT, "10000") + ); + when(response.getFirstHeader(AnthropicResponseHandler.REMAINING_TOKENS)).thenReturn( + new BasicHeader(AnthropicResponseHandler.REMAINING_TOKENS, "99800") + ); + when(response.getFirstHeader(AnthropicResponseHandler.REQUEST_RESET)).thenReturn( + new BasicHeader(AnthropicResponseHandler.REQUEST_RESET, "123") + ); + when(response.getFirstHeader(AnthropicResponseHandler.TOKENS_RESET)).thenReturn( + new BasicHeader(AnthropicResponseHandler.TOKENS_RESET, "456") + ); + when(response.getFirstHeader(AnthropicResponseHandler.RETRY_AFTER)).thenReturn( + new BasicHeader(AnthropicResponseHandler.RETRY_AFTER, "2") + ); + + var error = AnthropicResponseHandler.buildRateLimitErrorMessage(httpResult); + assertThat( + error, + containsString( + "Received a rate limit status code. Token limit [10000], remaining tokens [99800], tokens reset [456]. " + + "Request limit [3000], remaining requests [2999], request reset [123]. 
Retry after [2]" + ) + ); + } + + public void testCheckForFailureStatusCode_ThrowsFor403_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(403, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a permission denied error status code for request from inference entity id [id] status [403]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.FORBIDDEN)); + } + + public void testCheckForFailureStatusCode_ThrowsFor300_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(300, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Unhandled redirection for request from inference entity id [id] status [300]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES)); + } + + public void testCheckForFailureStatusCode_ThrowsFor425_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(425, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received an unsuccessful status code for request from inference entity id [id] status [425]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String inferenceEntityId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(inferenceEntityId); + var httpResult = new HttpResult(httpResponse, new byte[] {}); + var handler = new AnthropicResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googlevertexai/GoogleVertexAiResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googlevertexai/GoogleVertexAiResponseHandlerTests.java new file mode 100644 index 0000000000000..f2de009edec44 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googlevertexai/GoogleVertexAiResponseHandlerTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.googlevertexai; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GoogleVertexAiResponseHandlerTests extends ESTestCase { + + public void testCheckForFailureStatusCode_DoesNotThrowFor200() { + callCheckForFailureStatusCode(200, "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor503_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString( + "The Google Vertex AI service may be temporarily overloaded or down for request from inference entity id [id] status [503]" + ) + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor505_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(505, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [505]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code for request from inference entity id [id] status [429]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor404_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(404, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Resource not found at [null] for request from inference entity id [id] status [404]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.NOT_FOUND)); + } + + public void 
testCheckForFailureStatusCode_ThrowsFor403_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(403, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a permission denied error status code for request from inference entity id [id] status [403]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.FORBIDDEN)); + } + + public void testCheckForFailureStatusCode_ThrowsFor300_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(300, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Unhandled redirection for request from inference entity id [id] status [300]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES)); + } + + public void testCheckForFailureStatusCode_ThrowsFor425_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(425, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received an unsuccessful status code for request from inference entity id [id] status [425]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, new byte[] {}); + var handler = new GoogleVertexAiResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java index 30bd40bdcc111..c2842a1278a49 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java @@ -18,9 +18,11 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.UncategorizedExecutionException; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.inference.external.http.HttpClient; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.HttpRequestTests; @@ -33,6 +35,7 @@ import 
java.net.UnknownHostException; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.createDefaultRetrySettings; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -455,6 +458,86 @@ public void testSend_ReturnsFailure_WhenHttpResultsListenerCallsOnFailure_WithNo verifyNoMoreInteractions(httpClient); } + public void testSend_DoesNotRetryIndefinitely() throws IOException { + var threadPool = new TestThreadPool(getTestName()); + try { + + var httpClient = mock(HttpClient.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener<HttpResult> listener = (ActionListener<HttpResult>) invocation.getArguments()[2]; + // respond with a retryable exception + listener.onFailure(new ConnectionClosedException("failed")); + + return Void.TYPE; + }).when(httpClient).send(any(), any(), any()); + + var handler = mock(ResponseHandler.class); + + var retrier = new RetryingHttpSender( + httpClient, + mock(ThrottlerManager.class), + createDefaultRetrySettings(), + threadPool, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + + var listener = new PlainActionFuture<InferenceServiceResults>(); + retrier.send(mock(Logger.class), mockRequest(), HttpClientContext.create(), () -> false, handler, listener); + + // Assert that the retrying sender stopped after max retries even though the exception is retryable + var thrownException = expectThrows(UncategorizedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getCause(), instanceOf(ConnectionClosedException.class)); + assertThat(thrownException.getMessage(), is("Failed execution")); + assertThat(thrownException.getSuppressed().length, is(0)); + verify(httpClient, times(RetryingHttpSender.MAX_RETIES)).send(any(), any(), any()); + verifyNoMoreInteractions(httpClient); + } finally { + terminate(threadPool); + } + } + + public void testSend_DoesNotRetryIndefinitely_WithAlwaysRetryingResponseHandler() throws IOException { + var threadPool = new TestThreadPool(getTestName()); + try { + + var httpClient = mock(HttpClient.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener<HttpResult> listener = (ActionListener<HttpResult>) invocation.getArguments()[2]; + listener.onFailure(new ConnectionClosedException("failed")); + + return Void.TYPE; + }).when(httpClient).send(any(), any(), any()); + + // This handler will always tell the sender to retry + var handler = createRetryingResponseHandler(); + + var retrier = new RetryingHttpSender( + httpClient, + mock(ThrottlerManager.class), + createDefaultRetrySettings(), + threadPool, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + + var listener = new PlainActionFuture<InferenceServiceResults>(); + retrier.send(mock(Logger.class), mockRequest(), HttpClientContext.create(), () -> false, handler, listener); + + // Assert that the retrying sender stopped after max retries + var thrownException = expectThrows(UncategorizedExecutionException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getCause(), instanceOf(ConnectionClosedException.class)); + assertThat(thrownException.getMessage(), is("Failed execution")); + assertThat(thrownException.getSuppressed().length, is(0)); + verify(httpClient, times(RetryingHttpSender.MAX_RETIES)).send(any(), any(), any()); + verifyNoMoreInteractions(httpClient); + } finally { + terminate(threadPool); + } + } + private static HttpResponse mockHttpResponse() { var statusLine = mock(StatusLine.class);
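+ // Shared stub: builds a response whose status line reports HTTP 200, the success path for the sender tests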
when(statusLine.getStatusCode()).thenReturn(200); @@ -499,4 +582,27 @@ private RetryingHttpSender createRetrier(HttpClient httpClient) { EsExecutors.DIRECT_EXECUTOR_SERVICE ); } + + private ResponseHandler createRetryingResponseHandler() { + // Returns a response handler that wants to retry. + // Does not need to handle parsing as it should only be used + // for testing failed requests + return new ResponseHandler() { + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + throw new RetryException(true, new IOException("response handler validate failed as designed")); + } + + @Override + public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException { + throw new RetryException(true, new IOException("response handler parse failed as designed")); + } + + @Override + public String getRequestType() { + return "foo"; + } + }; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..f293a59e47d11 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntityTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.anthropic; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class AnthropicChatCompletionRequestEntityTests extends ESTestCase { + + public void testXContent() throws IOException { + var entity = new AnthropicChatCompletionRequestEntity( + List.of("abc"), + new AnthropicChatCompletionServiceSettings("model", null), + new AnthropicChatCompletionTaskSettings(1, -1.0, 1.2, 3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"messages":[{"role":"user","content":"abc"}],"model":"model","max_tokens":1,"temperature":-1.0,"top_p":1.2,"top_k":3}""")); + + } + + public void testXContent_WithoutTemperature() throws IOException { + var entity = new AnthropicChatCompletionRequestEntity( + List.of("abc"), + new AnthropicChatCompletionServiceSettings("model", null), + new AnthropicChatCompletionTaskSettings(1, null, 1.2, 3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); +
assertThat(xContentResult, is(""" + {"messages":[{"role":"user","content":"abc"}],"model":"model","max_tokens":1,"top_p":1.2,"top_k":3}""")); + + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestTests.java new file mode 100644 index 0000000000000..0a606c522c13e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestTests.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.anthropic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class AnthropicChatCompletionRequestTests extends ESTestCase { + + public void testCreateRequest() throws IOException { + var request = createRequest("secret", "abc", "model", 2); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is(buildAnthropicUri())); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(AnthropicRequestUtils.X_API_KEY).getValue(), is("secret")); + assertThat( + httpPost.getLastHeader(AnthropicRequestUtils.VERSION).getValue(), + is(AnthropicRequestUtils.ANTHROPIC_VERSION_2023_06_01) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("max_tokens"), is(2)); + } + + public void testCreateRequest_TestUrl() throws IOException { + var request = createRequest("fake_url", "secret", "abc", "model", 2); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("fake_url")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(AnthropicRequestUtils.X_API_KEY).getValue(), is("secret")); + assertThat( + httpPost.getLastHeader(AnthropicRequestUtils.VERSION).getValue(), + is(AnthropicRequestUtils.ANTHROPIC_VERSION_2023_06_01) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); 
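+ // A custom URL should change only the target URI; the serialized body must still carry exactly the messages, model, and max_tokens fields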
+ assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("max_tokens"), is(2)); + } + + public void testTruncate_DoesNotReduceInputTextSize() throws IOException { + var request = createRequest("secret", "abc", "model", 2); + + var truncatedRequest = request.truncate(); + assertThat(request.getURI().toString(), is(buildAnthropicUri())); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(3)); + + // We do not truncate for Anthropic chat completions + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("max_tokens"), is(2)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("secret", "abc", "model", 2); + assertNull(request.getTruncationInfo()); + } + + public static AnthropicChatCompletionRequest createRequest(String apiKey, String input, String model, int maxTokens) { + var chatCompletionModel = AnthropicChatCompletionModelTests.createChatCompletionModel(apiKey, model, maxTokens); + return new AnthropicChatCompletionRequest(List.of(input), chatCompletionModel); + } + + public static AnthropicChatCompletionRequest createRequest(String url, String apiKey, String input, String model, int maxTokens) { + var chatCompletionModel = AnthropicChatCompletionModelTests.createChatCompletionModel(url, apiKey, model, maxTokens); + return new AnthropicChatCompletionRequest(List.of(input), chatCompletionModel); + } + + private static String buildAnthropicUri() { + return Strings.format( + "https://%s/%s/%s", + AnthropicRequestUtils.HOST, + AnthropicRequestUtils.API_VERSION_1, + AnthropicRequestUtils.MESSAGES_PATH + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..f4912e0862e60 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntityTests.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googlevertexai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase { + + public void testToXContent_SingleEmbeddingRequest_WritesAutoTruncationIfDefined() throws IOException { + var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "instances": [ + { + "content": "abc" + } + ], + "parameters": { + "autoTruncate": true + } + } + """)); + } + + public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException { + var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "instances": [ + { + "content": "abc" + } + ] + } + """)); + } + + public void testToXContent_MultipleEmbeddingsRequest_WritesAutoTruncationIfDefined() throws IOException { + var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "instances": [ + { + "content": "abc" + }, + { + "content": "def" + } + ], + "parameters": { + "autoTruncate": true + } + } + """)); + } + + public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException { + var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "instances": [ + { + "content": "abc" + }, + { + "content": "def" + } + ] + } + """)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..b28fd8d3a0cf9 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestTests.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googlevertexai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase { + + private static final String AUTH_HEADER_VALUE = "foo"; + + public void testCreateRequest_WithoutDimensionsSet_And_WithoutAutoTruncateSet() throws IOException { + var model = "model"; + var input = "input"; + + var request = createRequest(model, input, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input"))))); + } + + public void testCreateRequest_WithAutoTruncateSet() throws IOException { + var model = "model"; + var input = "input"; + var autoTruncate = true; + + var request = createRequest(model, input, autoTruncate); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of("autoTruncate", true)))); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var model = "model"; + var input = "abcd"; + + var request = createRequest(model, input, null); + var truncatedRequest = request.truncate(); + var httpRequest = truncatedRequest.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + 
assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab"))))); + } + + private static GoogleVertexAiEmbeddingsRequest createRequest(String modelId, String input, @Nullable Boolean autoTruncate) { + var embeddingsModel = GoogleVertexAiEmbeddingsModelTests.createModel(modelId, autoTruncate); + + return new GoogleVertexAiEmbeddingsWithoutAuthRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + embeddingsModel + ); + } + + /** + * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} + */ + private static class GoogleVertexAiEmbeddingsWithoutAuthRequest extends GoogleVertexAiEmbeddingsRequest { + + GoogleVertexAiEmbeddingsWithoutAuthRequest( + Truncator truncator, + Truncator.TruncationResult input, + GoogleVertexAiEmbeddingsModel model + ) { + super(truncator, input, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + + @Override + public Request truncate() { + GoogleVertexAiEmbeddingsRequest embeddingsRequest = (GoogleVertexAiEmbeddingsRequest) super.truncate(); + return new GoogleVertexAiEmbeddingsWithoutAuthRequest( + embeddingsRequest.truncator(), + embeddingsRequest.truncationResult(), + embeddingsRequest.model() + ); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java index 9ef9ab4daa0ae..53bb38943d35b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java @@ -45,7 +45,7 @@ public void testCheckForFailureStatusCode() { var handler = new AzureMistralOpenAiExternalResponseHandler( "", (request, result) -> null, - AzureMistralOpenAiErrorResponseEntity::fromResponse + ErrorMessageResponseEntity::fromResponse ); // 200 ok diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java similarity index 64% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java index 48a560341f392..d57d1537f6c30 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntityTests.java @@ -12,10 +12,12 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.external.http.HttpResult; +import java.nio.charset.StandardCharsets; + import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; -public class AzureAndOpenAiErrorResponseEntityTests extends ESTestCase { +public class ErrorMessageResponseEntityTests 
extends ESTestCase { private static HttpResult getMockResult(String jsonString) { var response = mock(HttpResponse.class); @@ -26,23 +28,38 @@ public void testErrorResponse_ExtractsError() { var result = getMockResult(""" {"error":{"message":"test_error_message"}}"""); - var error = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); + var error = ErrorMessageResponseEntity.fromResponse(result); assertNotNull(error); assertThat(error.getErrorMessage(), is("test_error_message")); } + public void testFromResponse_noMessage() { + String responseJson = """ + { + "error": { + "type": "not_found_error", + } + } + """; + + var errorMessage = ErrorMessageResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + assertNull(errorMessage); + } + public void testErrorResponse_ReturnsNullIfNoError() { var result = getMockResult(""" {"noerror":true}"""); - var error = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); + var error = ErrorMessageResponseEntity.fromResponse(result); assertNull(error); } public void testErrorResponse_ReturnsNullIfNotJson() { var result = getMockResult("not a json string"); - var error = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); + var error = ErrorMessageResponseEntity.fromResponse(result); assertNull(error); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/anthropic/AnthropicChatCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/anthropic/AnthropicChatCompletionResponseEntityTests.java new file mode 100644 index 0000000000000..e5490d9f8d3ca --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/anthropic/AnthropicChatCompletionResponseEntityTests.java @@ -0,0 +1,265 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.anthropic; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class AnthropicChatCompletionResponseEntityTests extends ESTestCase { + + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "result" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + ChatCompletionResults chatCompletionResults = AnthropicChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); + } + + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "result" + }, + { + "type": "text", + "text": "result2" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + ChatCompletionResults chatCompletionResults = AnthropicChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(2)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); + assertThat(chatCompletionResults.getResults().get(1).content(), is("result2")); + } + + public void testFromResponse_CreatesResultsForMultipleItems_IgnoresTools() throws IOException { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "result" + }, + { + "type": "tool_use", + "id": "toolu_01Dc8BGR8aEuToS2B9uz6HMX", + "name": "get_weather", + "input": { + "location": "San Francisco, CA" + } + }, + { + "type": "text", + "text": "result2" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + ChatCompletionResults chatCompletionResults = AnthropicChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(2)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); + assertThat(chatCompletionResults.getResults().get(1).content(), is("result2")); + } + + public void 
testFromResponse_FailsWhenContentIsNotPresent() { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "not_content": [ + { + "type": "text", + "text": "result" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> AnthropicChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [content] in Anthropic chat completions response")); + } + + public void testFromResponse_FailsWhenContentFieldNotAnArray() { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": { + "type": "text", + "text": "result" + }, + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + var thrownException = expectThrows( + ParsingException.class, + () -> AnthropicChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + } + + public void testFromResponse_FailsWhenTypeDoesNotExist() { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "not_content": [ + { + "text": "result" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> AnthropicChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [content] in Anthropic chat completions response")); + } + + public void testFromResponse_FailsWhenContentValueIsAString() { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": "hello", + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; + + var thrownException = expectThrows( + ParsingException.class, + () -> AnthropicChatCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [VALUE_STRING]") + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsRequestTaskSettingsTests.java new file mode 100644 index 0000000000000..87edbddb257a0 --- /dev/null 
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsRequestTaskSettingsTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googlevertexai; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsRequestTaskSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; + +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiEmbeddingsRequestTaskSettingsTests extends ESTestCase { + + public void testFromMap_ReturnsEmptySettings_IfMapEmpty() { + var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>()); + assertThat(requestTaskSettings, is(GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_DoesNotThrowValidationException_IfAutoTruncateIsMissing() { + var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of("unrelated", true))); + assertThat(requestTaskSettings, is(new GoogleVertexAiEmbeddingsRequestTaskSettings(null))); + } + + public void testFromMap_ExtractsAutoTruncate() { + var autoTruncate = true; + var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap( + new HashMap<>(Map.of(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, autoTruncate)) + ); + assertThat(requestTaskSettings, is(new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..39bf08a21a76b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsResponseEntityTests.java @@ -0,0 +1,208 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.googlevertexai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleVertexAiEmbeddingsResponseEntityTests extends ESTestCase { + + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "predictions": [ + { + "embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "values": [ + -0.123, + 0.123 + ] + } + } + ] + } + """; + + InferenceTextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is(List.of(InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(List.of(-0.123F, 0.123F)))) + ); + } + + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + String responseJson = """ + { + "predictions": [ + { + "embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "values": [ + -0.123, + 0.123 + ] + } + }, + { + "embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "values": [ + -0.456, + 0.456 + ] + } + } + ] + } + """; + + InferenceTextEmbeddingFloatResults parsedResults = GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(List.of(-0.123F, 0.123F)), + InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(List.of(-0.456F, 0.456F)) + ) + ) + ); + } + + public void testFromResponse_FailsWhenPredictionsFieldIsNotPresent() { + String responseJson = """ + { + "not_predictions": [ + { + "embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "values": [ + -0.123, + 0.123 + ] + } + }, + { + "embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "values": [ + -0.456, + 0.456 + ] + } + } + ] + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [predictions] in Google Vertex AI embeddings response")); + } + + public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() { + String responseJson = """ + { + "predictions": [ + { + "not_embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "values": [ + -0.123, + 0.123 + ] + } + } + ] + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + 
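+ // The failure message must explicitly name the missing [embeddings] field rather than surface a generic parse error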
+ assertThat(thrownException.getMessage(), is("Failed to find required field [embeddings] in Google Vertex AI embeddings response")); + } + + public void testFromResponse_FailsWhenValuesFieldIsNotPresent() { + String responseJson = """ + { + "predictions": [ + { + "embeddings": { + "statistics": { + "truncated": false, + "token_count": 6 + }, + "not_values": [ + -0.123, + 0.123 + ] + } + } + ] + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> GoogleVertexAiEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [values] in Google Vertex AI embeddings response")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsTaskSettingsTests.java new file mode 100644 index 0000000000000..23e4e836ff510 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiEmbeddingsTaskSettingsTests.java @@ -0,0 +1,135 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googlevertexai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsRequestTaskSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE; +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase { + + public void testFromMap_AutoTruncateIsSet() { + var autoTruncate = true; + var taskSettingsMap = getTaskSettingsMap(autoTruncate); + var taskSettings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettingsMap); + + assertThat(taskSettings, is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + } + + public void testFromMap_ThrowsValidationException_IfAutoTruncateIsInvalidValue() { + var taskSettings = getTaskSettingsMap("invalid"); + + expectThrows(ValidationException.class, () -> GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettings)); + } + + public void testFromMap_AutoTruncateIsNull() { + var taskSettingsMap = getTaskSettingsMap(null); + var taskSettings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettingsMap); + // needed, because of 
constructors being ambiguous otherwise + Boolean nullBoolean = null; + + assertThat(taskSettings, is(new GoogleVertexAiEmbeddingsTaskSettings(nullBoolean))); + } + + public void testFromMap_DoesNotThrow_WithEmptyMap() { + assertNull(GoogleVertexAiEmbeddingsTaskSettings.fromMap(new HashMap<>()).autoTruncate()); + } + + public void testOf_UseRequestSettings() { + var originalAutoTruncate = true; + var originalSettings = new GoogleVertexAiEmbeddingsTaskSettings(originalAutoTruncate); + + var requestAutoTruncate = originalAutoTruncate == false; + var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(requestAutoTruncate); + + assertThat(GoogleVertexAiEmbeddingsTaskSettings.of(originalSettings, requestTaskSettings).autoTruncate(), is(requestAutoTruncate)); + } + + public void testOf_UseOriginalSettings() { + var originalAutoTruncate = true; + var originalSettings = new GoogleVertexAiEmbeddingsTaskSettings(originalAutoTruncate); + + var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(null); + + assertThat(GoogleVertexAiEmbeddingsTaskSettings.of(originalSettings, requestTaskSettings).autoTruncate(), is(originalAutoTruncate)); + } + + public void testToXContent_WritesAutoTruncateIfNotNull() throws IOException { + var settings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(getTaskSettingsMap(true)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + settings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"auto_truncate":true}""")); + } + + public void testToXContent_DoesNotWriteAutoTruncateIfNull() throws IOException { + var settings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(getTaskSettingsMap(null)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + settings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return GoogleVertexAiEmbeddingsTaskSettings::new; + } + + @Override + protected GoogleVertexAiEmbeddingsTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected GoogleVertexAiEmbeddingsTaskSettings mutateInstance(GoogleVertexAiEmbeddingsTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleVertexAiEmbeddingsTaskSettingsTests::createRandom); + } + + @Override + protected GoogleVertexAiEmbeddingsTaskSettings mutateInstanceForVersion( + GoogleVertexAiEmbeddingsTaskSettings instance, + TransportVersion version + ) { + return instance; + } + + private static GoogleVertexAiEmbeddingsTaskSettings createRandom() { + return new GoogleVertexAiEmbeddingsTaskSettings(randomFrom(new Boolean[] { null, randomBoolean() })); + } + + private static Map getTaskSettingsMap(@Nullable Object autoTruncate) { + var map = new HashMap(); + + if (autoTruncate != null) { + map.put(AUTO_TRUNCATE, autoTruncate); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiErrorResponseEntityTests.java new file mode 100644 index 0000000000000..e2c9ebed2c164 --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiErrorResponseEntityTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googlevertexai; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleVertexAiErrorResponseEntityTests extends ESTestCase { + + private static HttpResult getMockResult(String jsonString) { + var response = mock(HttpResponse.class); + return new HttpResult(response, Strings.toUTF8Bytes(jsonString)); + } + + public void testErrorResponse_ExtractsError() { + var result = getMockResult(""" + { + "error": { + "code": 400, + "message": "error message", + "status": "INVALID_ARGUMENT", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.BadRequest", + "fieldViolations": [ + { + "description": "Invalid JSON payload received. Unknown name \\"abc\\": Cannot find field." + } + ] + } + ] + } + } + """); + + var error = GoogleVertexAiErrorResponseEntity.fromResponse(result); + assertNotNull(error); + assertThat(error.getErrorMessage(), is("error message")); + } + + public void testErrorResponse_ReturnsNullIfNoError() { + var result = getMockResult(""" + { + "foo": "bar" + } + """); + + var error = GoogleVertexAiErrorResponseEntity.fromResponse(result); + assertNull(error); + } + + public void testErrorResponse_ReturnsNullIfNotJson() { + var result = getMockResult("error message"); + + var error = GoogleVertexAiErrorResponseEntity.fromResponse(result); + assertNull(error); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 599df8d1cfb3b..a14b42d51c6f8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -36,6 +36,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalTimeValue; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.getEmbeddingSize; @@ -466,6 +467,41 @@ public void testExtractOptionalPositiveLong() { assertThat(validation.validationErrors(), hasSize(1)); } + public void testExtractRequiredPositiveInteger_ReturnsValue() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", 1)); + var 
parsedInt = extractRequiredPositiveInteger(map, "key", "scope", validation); + + assertThat(validation.validationErrors(), hasSize(1)); + assertNotNull(parsedInt); + assertThat(parsedInt, is(1)); + assertTrue(map.isEmpty()); + } + + public void testExtractRequiredPositiveInteger_AddsErrorForNegativeValue() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", -1)); + var parsedInt = extractRequiredPositiveInteger(map, "key", "scope", validation); + + assertThat(validation.validationErrors(), hasSize(2)); + assertNull(parsedInt); + assertTrue(map.isEmpty()); + assertThat(validation.validationErrors().get(1), is("[scope] Invalid value [-1]. [key] must be a positive integer")); + } + + public void testExtractRequiredPositiveInteger_AddsErrorWhenKeyIsMissing() { + var validation = new ValidationException(); + validation.addValidationError("previous error"); + Map map = modifiableMap(Map.of("key", -1)); + var parsedInt = extractRequiredPositiveInteger(map, "not_key", "scope", validation); + + assertThat(validation.validationErrors(), hasSize(2)); + assertNull(parsedInt); + assertThat(validation.validationErrors().get(1), is("[scope] does not contain the required setting [not_key]")); + } + public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() { var validation = new ValidationException(); Map map = modifiableMap(Map.of("key", "value")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java new file mode 100644 index 0000000000000..5e32344ab3840 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -0,0 +1,537 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.anthropic; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettingsTests; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.buildExpectationCompletions; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.getModelListenerForException; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; +import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AnthropicServiceTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final 
MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesACompletionModel() throws IOException { + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createServiceWithMockSender()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); + + var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of(AnthropicServiceFields.MAX_TOKENS, 1)), + getSecretSettingsMap(apiKey) + ), + Set.of(), + modelListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createServiceWithMockSender()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [anthropic] service does not support task type [sparse_embedding]" + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + new HashMap<>(Map.of()), + getSecretSettingsMap("secret") + ), + Set.of(), + failureListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createServiceWithMockSender()) { + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, null, null, null), + getSecretSettingsMap("secret") + ); + config.put("extra_key", "value"); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createServiceWithMockSender()) { + Map serviceSettings = new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, null, null, null), + getSecretSettingsMap("api_key") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createServiceWithMockSender()) { + var taskSettingsMap = AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, null, null, null); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createServiceWithMockSender()) { + Map secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, null, null, null), + secretSettings + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [anthropic] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createServiceWithMockSender()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); + + var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createServiceWithMockSender()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3), + getSecretSettingsMap(apiKey) + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); + + var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + 
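// Editor's note (not part of the original diff): the DoesNotThrow...ExtraKey tests in this file pin down a
// deliberate asymmetry. parseRequestConfig rejects unknown keys so that typos in a new endpoint request fail
// fast, while parsePersistedConfigWithSecrets silently ignores them so that configuration written by a newer
// node can still be read after an upgrade or downgrade. A minimal sketch of the lenient read side, using
// hypothetical removeFromMap/extractRequiredString helpers in the spirit of ServiceUtils:
//
//   Map<String, Object> serviceSettings = removeFromMap(config, ModelConfigurations.SERVICE_SETTINGS);
//   String modelId = extractRequiredString(serviceSettings, ServiceFields.MODEL_ID, scope, validationException);
//   // persisted path: keys still left in serviceSettings are simply dropped rather than reported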
assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } +
+ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; +
+ try (var service = createServiceWithMockSender()) { + var secretSettingsMap = getSecretSettingsMap(apiKey); + secretSettingsMap.put("extra_key", "value"); +
+ var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3), + secretSettingsMap + ); +
+ var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); +
+ assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); +
+ var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } +
+ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; +
+ try (var service = createServiceWithMockSender()) { + Map<String, Object> serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + serviceSettingsMap.put("extra_key", "value"); +
+ var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3), + getSecretSettingsMap(apiKey) + ); +
+ var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); +
+ assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); +
+ var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } +
+ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; +
+ try (var service = createServiceWithMockSender()) { + Map<String, Object> taskSettings = AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3); + taskSettings.put("extra_key", "value"); +
+ var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + taskSettings, + getSecretSettingsMap(apiKey) + ); +
+ var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); +
+ assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); +
+ var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } +
+ public void testParsePersistedConfig_CreatesACompletionModel() throws IOException { + var modelId = "model"; +
+ try (var service
= createServiceWithMockSender()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3) + ); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); + + var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var modelId = "model"; + + try (var service = createServiceWithMockSender()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3) + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); + + var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var modelId = "model"; + + try (var service = createServiceWithMockSender()) { + Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3) + ); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); + + var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var modelId = "model"; + + try (var service = createServiceWithMockSender()) { + Map taskSettings = AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, 1.0, 2.1, 3); + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), taskSettings); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AnthropicChatCompletionModel.class)); + + var completionModel = (AnthropicChatCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(new AnthropicChatCompletionTaskSettings(1, 1.0, 2.1, 3))); + assertNull(completionModel.getSecretSettings()); + } + } + + 
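// Editor's note (not part of the original diff): the testOf_* cases later in this diff check that request-time
// task settings win over persisted ones, field by field. This is a sketch of the merge that the production
// AnthropicChatCompletionTaskSettings.of(...) is expected to perform, assuming the record accessors used by
// these tests (maxTokens, temperature, topP, topK); the null checks are what make each field independently
// overridable:
//
//   public static AnthropicChatCompletionTaskSettings of(
//       AnthropicChatCompletionTaskSettings original,
//       AnthropicChatCompletionRequestTaskSettings request
//   ) {
//       return new AnthropicChatCompletionTaskSettings(
//           request.maxTokens() != null ? request.maxTokens() : original.maxTokens(),
//           request.temperature() != null ? request.temperature() : original.temperature(),
//           request.topP() != null ? request.topP() : original.topP(),
//           request.topK() != null ? request.topK() : original.topK()
//       );
//   }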
public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException { + var sender = mock(Sender.class); +
+ var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); +
+ var mockModel = getInvalidModel("model_id", "service_name"); +
+ try (var service = new AnthropicService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); +
+ var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); +
+ verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } +
+ verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } +
+ public void testInfer_SendsCompletionRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); +
+ try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "id": "msg_01XzZQmG41BMGe5NZ5p2vEWb", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "text", + "text": "result" + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 326 + } + } + """; +
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); +
+ var model = AnthropicChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model", 1); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("input"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + var result = listener.actionGet(TIMEOUT); +
+ assertThat(result.asMap(), is(buildExpectationCompletions(List.of("result")))); + var request = webServer.requests().get(0); + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); + assertThat(request.getHeader(AnthropicRequestUtils.X_API_KEY), Matchers.equalTo("secret")); + assertThat( + request.getHeader(AnthropicRequestUtils.VERSION), + Matchers.equalTo(AnthropicRequestUtils.ANTHROPIC_VERSION_2023_06_01) + ); +
+ var requestMap = entityAsMap(request.getBody()); + assertThat( + requestMap, + is(Map.of("messages", List.of(Map.of("role", "user", "content", "input")), "model", "model", "max_tokens", 1)) + ); + } + } +
+ private AnthropicService createServiceWithMockSender() { + return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionModelTests.java new file mode 100644 index 0000000000000..85535b1400b86 --- /dev/null +++
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionModelTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +
+package org.elasticsearch.xpack.inference.services.anthropic.completion; +
+import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +
+import java.util.Map; +
+import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; +
+public class AnthropicChatCompletionModelTests extends ESTestCase { +
+ public void testOverrideWith_OverridesMaxInput() { + var model = createChatCompletionModel("url", "api_key", "model_name", 0); + var requestTaskSettingsMap = AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap(1, null, null, null); +
+ var overriddenModel = AnthropicChatCompletionModel.of(model, requestTaskSettingsMap); +
+ assertThat(overriddenModel, is(createChatCompletionModel("url", "api_key", "model_name", 1))); + } +
+ public void testOverrideWith_EmptyMap() { + var model = createChatCompletionModel("url", "api_key", "model_name", 0); +
+ Map<String, Object> requestTaskSettingsMap = Map.of(); +
+ var overriddenModel = AnthropicChatCompletionModel.of(model, requestTaskSettingsMap); + assertThat(overriddenModel, sameInstance(model)); + } +
+ public void testOverrideWith_NullMap() { + var model = createChatCompletionModel("url", "api_key", "model_name", 0); +
+ var overriddenModel = AnthropicChatCompletionModel.of(model, null); + assertThat(overriddenModel, sameInstance(model)); + } +
+ public static AnthropicChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelName, int maxTokens) { + return new AnthropicChatCompletionModel( + "id", + TaskType.COMPLETION, + "service", + url, + new AnthropicChatCompletionServiceSettings(modelName, null), + new AnthropicChatCompletionTaskSettings(maxTokens, null, null, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +
+ public static AnthropicChatCompletionModel createChatCompletionModel(String apiKey, String modelName, int maxTokens) { + return new AnthropicChatCompletionModel( + "id", + TaskType.COMPLETION, + "service", + new AnthropicChatCompletionServiceSettings(modelName, null), + new AnthropicChatCompletionTaskSettings(maxTokens, null, null, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionRequestTaskSettingsTests.java new file mode 100644 index 0000000000000..86a6b36947f25 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionRequestTaskSettingsTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements.
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.anthropic.completion; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields; + +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap; +import static org.hamcrest.Matchers.is; + +public class AnthropicChatCompletionRequestTaskSettingsTests extends ESTestCase { + + public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() { + var settings = AnthropicChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of())); + assertNull(settings.maxTokens()); + } + + public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { + var settings = AnthropicChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "value"))); + assertNull(settings.maxTokens()); + } + + public void testFromMap_ReturnsMaxTokens() { + var settings = AnthropicChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AnthropicServiceFields.MAX_TOKENS, 1))); + assertThat(settings.maxTokens(), is(1)); + } + + public void testFromMap_ReturnsAllValues() { + var settings = AnthropicChatCompletionRequestTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1, -1.1, 0.1, 1)); + assertThat(settings.maxTokens(), is(1)); + assertThat(settings.temperature(), is(-1.1)); + assertThat(settings.topP(), is(0.1)); + assertThat(settings.topK(), is(1)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..11c2cd56e7955 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionServiceSettingsTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.anthropic.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class AnthropicChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + AnthropicChatCompletionServiceSettings> { + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var modelId = "some model"; + + var serviceSettings = AnthropicChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new AnthropicChatCompletionServiceSettings(modelId, null))); + } + + public void testFromMap_Request_CreatesSettingsCorrectly_WithRateLimit() { + var modelId = "some model"; + var rateLimit = 2; + var serviceSettings = AnthropicChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, rateLimit)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new AnthropicChatCompletionServiceSettings(modelId, new RateLimitSettings(2)))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = new AnthropicChatCompletionServiceSettings("model", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model","rate_limit":{"requests_per_minute":50}}""")); + } + + public void testToXContent_WritesAllValues_WithCustomRateLimit() throws IOException { + var serviceSettings = new AnthropicChatCompletionServiceSettings("model", new RateLimitSettings(2)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model","rate_limit":{"requests_per_minute":2}}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return AnthropicChatCompletionServiceSettings::new; + } + + @Override + protected AnthropicChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AnthropicChatCompletionServiceSettings mutateInstance(AnthropicChatCompletionServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, AnthropicChatCompletionServiceSettingsTests::createRandom); + } + + private static AnthropicChatCompletionServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + + return new AnthropicChatCompletionServiceSettings(modelId, 
RateLimitSettingsTests.createRandom()); + } +
+ @Override + protected AnthropicChatCompletionServiceSettings mutateInstanceForVersion( + AnthropicChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettingsTests.java new file mode 100644 index 0000000000000..78762af6eee8c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/completion/AnthropicChatCompletionTaskSettingsTests.java @@ -0,0 +1,130 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +
+package org.elasticsearch.xpack.inference.services.anthropic.completion; +
+import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.anthropic.AnthropicServiceFields; +
+import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +
+import static org.hamcrest.Matchers.is; +
+public class AnthropicChatCompletionTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AnthropicChatCompletionTaskSettings> { +
+ public static Map<String, Object> getChatCompletionTaskSettingsMap( + @Nullable Integer maxTokens, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer topK + ) { + var map = new HashMap<String, Object>(); +
+ if (maxTokens != null) { + map.put(AnthropicServiceFields.MAX_TOKENS, maxTokens); + } +
+ if (temperature != null) { + map.put(AnthropicServiceFields.TEMPERATURE_FIELD, temperature); + } +
+ if (topP != null) { + map.put(AnthropicServiceFields.TOP_P_FIELD, topP); + } +
+ if (topK != null) { + map.put(AnthropicServiceFields.TOP_K_FIELD, topK); + } +
+ return map; + } +
+ public static AnthropicChatCompletionTaskSettings createRandom() { + return new AnthropicChatCompletionTaskSettings(randomNonNegativeInt(), randomDouble(), randomDouble(), randomInt()); + } +
+ public void testFromMap_WithMaxTokens() { + assertEquals( + new AnthropicChatCompletionTaskSettings(1, null, null, null), + AnthropicChatCompletionTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(1, null, null, null), + ConfigurationParseContext.REQUEST + ) + ); + } +
+ public void testFromMap_AllValues() { + assertEquals( + new AnthropicChatCompletionTaskSettings(1, -1.1, 2.2, 3), + AnthropicChatCompletionTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(1, -1.1, 2.2, 3), + ConfigurationParseContext.REQUEST + ) + ); + } +
+ public void testFromMap_WithoutMaxTokens_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AnthropicChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.REQUEST) + ); +
+ assertThat( + thrownException.getMessage(), + is("Validation Failed: 1: [task_settings] does not contain the required setting [max_tokens];") + ); + } +
+ public void
testOf_KeepsOriginalValuesWithOverridesAreEmpty() { + var taskSettings = new AnthropicChatCompletionTaskSettings(1, null, null, null); + + var overriddenTaskSettings = AnthropicChatCompletionTaskSettings.of( + taskSettings, + AnthropicChatCompletionRequestTaskSettings.EMPTY_SETTINGS + ); + assertThat(overriddenTaskSettings, is(taskSettings)); + } + + public void testOf_UsesOverriddenSettings() { + var taskSettings = new AnthropicChatCompletionTaskSettings(1, -1.2, 2.1, 3); + + var requestTaskSettings = new AnthropicChatCompletionRequestTaskSettings(2, 3.0, 4.0, 4); + + var overriddenTaskSettings = AnthropicChatCompletionTaskSettings.of(taskSettings, requestTaskSettings); + assertThat(overriddenTaskSettings, is(new AnthropicChatCompletionTaskSettings(2, 3.0, 4.0, 4))); + } + + @Override + protected AnthropicChatCompletionTaskSettings mutateInstanceForVersion( + AnthropicChatCompletionTaskSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AnthropicChatCompletionTaskSettings::new; + } + + @Override + protected AnthropicChatCompletionTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AnthropicChatCompletionTaskSettings mutateInstance(AnthropicChatCompletionTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettingsTests.java new file mode 100644 index 0000000000000..95d3522b863a9 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettingsTests.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +
+package org.elasticsearch.xpack.inference.services.googlevertexai; +
+import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +
+import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +
+import static org.hamcrest.Matchers.is; +
+public class GoogleVertexAiSecretSettingsTests extends AbstractBWCWireSerializationTestCase<GoogleVertexAiSecretSettings> { +
+ public static GoogleVertexAiSecretSettings createRandom() { + return new GoogleVertexAiSecretSettings(randomSecureStringOfLength(30)); + } +
+ public void testFromMap_ReturnsNull_WhenMapIsNull() { + assertNull(GoogleVertexAiSecretSettings.fromMap(null)); + } +
+ public void testFromMap_ThrowsError_IfServiceAccountJsonIsMissing() { + expectThrows(ValidationException.class, () -> GoogleVertexAiSecretSettings.fromMap(new HashMap<>())); + } +
+ public void testFromMap_ThrowsError_IfServiceAccountJsonIsEmpty() { + expectThrows( + ValidationException.class, + () -> GoogleVertexAiSecretSettings.fromMap(new HashMap<>(Map.of(GoogleVertexAiSecretSettings.SERVICE_ACCOUNT_JSON, ""))) + ); + } +
+ public void testToXContent_WritesServiceAccountJson() throws IOException { + var secretSettings = new GoogleVertexAiSecretSettings(new SecureString("json".toCharArray())); +
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + secretSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); +
+ assertThat(xContentResult, is(""" + {"service_account_json":"json"}""")); + } +
+ @Override + protected Writeable.Reader<GoogleVertexAiSecretSettings> instanceReader() { + return GoogleVertexAiSecretSettings::new; + } +
+ @Override + protected GoogleVertexAiSecretSettings createTestInstance() { + return createRandom(); + } +
+ @Override + protected GoogleVertexAiSecretSettings mutateInstance(GoogleVertexAiSecretSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleVertexAiSecretSettingsTests::createRandom); + } +
+ @Override + protected GoogleVertexAiSecretSettings mutateInstanceForVersion(GoogleVertexAiSecretSettings instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java new file mode 100644 index 0000000000000..a8e1dd3997ca0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -0,0 +1,556 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ +
+package org.elasticsearch.xpack.inference.services.googlevertexai; +
+import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; +
+import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +
+public class GoogleVertexAiServiceTests extends ESTestCase { +
+ private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; +
+ private HttpClientManager clientManager; +
+ @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } +
+ @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } +
+ public void testParseRequestConfig_CreatesGoogleVertexAiEmbeddingsModel() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var serviceAccountJson = """ + { + "some json" + } + """; +
+ try (var service = createGoogleVertexAiService()) { + ActionListener<Model> modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); +
+ var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; +
+ assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); +
+ service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, +
GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + new HashMap<>(Map.of()), + getSecretSettingsMap(serviceAccountJson) + ), + Set.of(), + modelListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createGoogleVertexAiService()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [googlevertexai] service does not support task type [sparse_embedding]" + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project" + ) + ), + new HashMap<>(Map.of()), + getSecretSettingsMap("{}") + ), + Set.of(), + failureListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createGoogleVertexAiService()) { + var config = getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project" + ) + ), + getTaskSettingsMap(true), + getSecretSettingsMap("{}") + ); + config.put("extra_key", "value"); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createGoogleVertexAiService()) { + Map serviceSettings = new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project" + ) + ); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap(true), getSecretSettingsMap("{}")); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createGoogleVertexAiService()) { + Map taskSettingsMap = new HashMap<>(); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project" + ) + ), + taskSettingsMap, + getSecretSettingsMap("{}") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createGoogleVertexAiService()) { + Map secretSettings = getSecretSettingsMap("{}"); + secretSettings.put("extra_key", "value"); + + 
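// Editor's note (not part of the original diff): getRequestConfigMap, defined near the bottom of this file,
// folds the secret settings into service_settings before wrapping both under the usual two top-level keys,
// mirroring how a create-endpoint request body arrives over REST. With the values used in this test, the map
// handed to parseRequestConfig looks roughly like:
//
//   {
//     "service_settings": { "model_id": "model", "location": "location", "project_id": "project",
//                           "service_account_json": "{}", "extra_key": "value" },
//     "task_settings": { "auto_truncate": true }
//   }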
var config = getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project" + ) + ), + getTaskSettingsMap(true), + secretSettings + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googlevertexai] service" + ); + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), failureListener); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsModel() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + getSecretSettingsMap(serviceAccountJson) + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + 
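// Editor's note (not part of the original diff): DIMENSIONS_SET_BY_USER appears only in the persisted maps,
// never in the request-path tests above, because it is bookkeeping the service writes when the endpoint is
// created: it records whether the embedding dimension was chosen by the user or discovered from the model,
// which later determines whether a mismatching dimension should be reported or silently corrected. An
// illustrative guard, not the actual implementation:
//
//   if (serviceSettings.dimensionsSetByUser() && requestedDimensions != observedEmbeddingSize) {
//       throw new ElasticsearchStatusException("dimensions mismatch", RestStatus.BAD_REQUEST);
//   }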
assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var secretSettingsMap = getSecretSettingsMap(serviceAccountJson); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project", + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var serviceSettingsMap = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project", + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + getTaskSettingsMap(autoTruncate), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } 
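// Editor's note (not part of the original diff): every parsePersistedConfigWithSecrets test feeds the service
// two separate maps because that is how the plugin stores them: secrets are kept apart from the regular
// configuration, presumably so they can be protected independently, and the service is responsible for
// reuniting them into a single model. The local record at the bottom of this file captures exactly that pair:
//
//   private record PersistedConfig(Map<String, Object> config, Map<String, Object> secrets) {}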
+ } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var taskSettings = getTaskSettingsMap(autoTruncate); + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project", + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + taskSettings, + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + // testInfer tested via end-to-end notebook tests in AppEx repo + + private GoogleVertexAiService createGoogleVertexAiService() { + return new GoogleVertexAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + // TODO: deduplicate + private PersistedConfig getPersistedConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + private PersistedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + private record PersistedConfig(Map config, Map secrets) {} + + private static Map getSecretSettingsMap(String serviceAccountJson) { + return new HashMap<>(Map.of(GoogleVertexAiSecretSettings.SERVICE_ACCOUNT_JSON, serviceAccountJson)); + } + + private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + assertThat(e, Matchers.instanceOf(exceptionClass)); + assertThat(e.getMessage(), CoreMatchers.is(expectedMessage)); + }); + } + + private static Map 
getTaskSettingsMap(Boolean autoTruncate) { + var taskSettings = new HashMap<String, Object>(); + + taskSettings.put(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, autoTruncate); + + return taskSettings; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java new file mode 100644 index 0000000000000..ca38bdb6e2c6c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; + +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase { + + public void testBuildUri() throws URISyntaxException { + var location = "location"; + var projectId = "project"; + var modelId = "model"; + + URI uri = GoogleVertexAiEmbeddingsModel.buildUri(location, projectId, modelId); + + assertThat( + uri, + is( + new URI( + Strings.format( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", + location, + projectId, + location, + modelId + ) + ) + ) + ); + } + + public static GoogleVertexAiEmbeddingsModel createModel( + String location, + String projectId, + String modelId, + String uri, + String serviceAccountJson + ) { + return new GoogleVertexAiEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + uri, + new GoogleVertexAiEmbeddingsServiceSettings(location, projectId, modelId, false, null, null, null, null), + new GoogleVertexAiEmbeddingsTaskSettings(Boolean.FALSE), + new GoogleVertexAiSecretSettings(new SecureString(serviceAccountJson.toCharArray())) + ); + } + + public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullable Boolean autoTruncate) { + return new GoogleVertexAiEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + new GoogleVertexAiEmbeddingsServiceSettings( + randomAlphaOfLength(8), + randomAlphaOfLength(8), + modelId, + false, + null, + null, + SimilarityMeasure.DOT_PRODUCT, + null + ), + new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate), + new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray())) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettingsTests.java new file mode 100644 index
0000000000000..2b8630ec7e041 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettingsTests.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.hamcrest.Matchers.is; + +public class GoogleVertexAiEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + GoogleVertexAiEmbeddingsServiceSettings> { + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var location = randomAlphaOfLength(8); + var projectId = randomAlphaOfLength(8); + var model = randomAlphaOfLength(8); + var dimensionsSetByUser = randomBoolean(); + var maxInputTokens = randomFrom(new Integer[] { null, randomNonNegativeInt() }); + var similarityMeasure = randomFrom(new SimilarityMeasure[] { null, randomFrom(SimilarityMeasure.values()) }); + var similarityMeasureString = similarityMeasure == null ? 
null : similarityMeasure.toString(); + var dims = randomFrom(new Integer[] { null, randomNonNegativeInt() }); + var configurationParseContext = ConfigurationParseContext.PERSISTENT; + + var serviceSettings = GoogleVertexAiEmbeddingsServiceSettings.fromMap(new HashMap<>() { + { + put(GoogleVertexAiServiceFields.LOCATION, location); + put(GoogleVertexAiServiceFields.PROJECT_ID, projectId); + put(ServiceFields.MODEL_ID, model); + put(GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + put(ServiceFields.SIMILARITY, similarityMeasureString); + put(ServiceFields.DIMENSIONS, dims); + } + }, configurationParseContext); + + assertThat( + serviceSettings, + is( + new GoogleVertexAiEmbeddingsServiceSettings( + location, + projectId, + model, + dimensionsSetByUser, + maxInputTokens, + dims, + similarityMeasure, + null + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new GoogleVertexAiEmbeddingsServiceSettings( + "location", + "projectId", + "modelId", + true, + 10, + 10, + SimilarityMeasure.DOT_PRODUCT, + null + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "location": "location", + "project_id": "projectId", + "model_id": "modelId", + "max_input_tokens": 10, + "dimensions": 10, + "similarity": "dot_product", + "rate_limit": { + "requests_per_minute": 30000 + }, + "dimensions_set_by_user": true + } + """)); + } + + public void testFilteredXContentObject_WritesAllValues_ExceptDimensionsSetByUser() throws IOException { + var entity = new GoogleVertexAiEmbeddingsServiceSettings( + "location", + "projectId", + "modelId", + true, + 10, + 10, + SimilarityMeasure.DOT_PRODUCT, + null + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "location": "location", + "project_id": "projectId", + "model_id": "modelId", + "max_input_tokens": 10, + "dimensions": 10, + "similarity": "dot_product", + "rate_limit": { + "requests_per_minute": 30000 + } + } + """)); + } + + @Override + protected Writeable.Reader<GoogleVertexAiEmbeddingsServiceSettings> instanceReader() { + return GoogleVertexAiEmbeddingsServiceSettings::new; + } + + @Override + protected GoogleVertexAiEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected GoogleVertexAiEmbeddingsServiceSettings mutateInstance(GoogleVertexAiEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleVertexAiEmbeddingsServiceSettingsTests::createRandom); + } + + @Override + protected GoogleVertexAiEmbeddingsServiceSettings mutateInstanceForVersion( + GoogleVertexAiEmbeddingsServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static GoogleVertexAiEmbeddingsServiceSettings createRandom() { + return new GoogleVertexAiEmbeddingsServiceSettings( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomBoolean(), + randomFrom(new Integer[] { null, randomNonNegativeInt() }), + randomFrom(new Integer[] { null, randomNonNegativeInt() }), +
randomFrom(new SimilarityMeasure[] { null, randomFrom(SimilarityMeasure.values()) }), + randomFrom(new RateLimitSettings[] { null, RateLimitSettingsTests.createRandom() }) + ); + } +} diff --git a/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/resources/rest-api-spec/test/30_sort.yml b/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/resources/rest-api-spec/test/30_sort.yml new file mode 100644 index 0000000000000..8d489b8211eb1 --- /dev/null +++ b/x-pack/plugin/mapper-constant-keyword/src/yamlRestTest/resources/rest-api-spec/test/30_sort.yml @@ -0,0 +1,65 @@ +setup: + - do: + indices.create: + index: test + body: + mappings: + properties: + keyword: + type: keyword + + - do: + indices.create: + index: test_numeric + body: + mappings: + properties: + keyword: + type: long + + - do: + indices.create: + index: test_constant + body: + mappings: + properties: + keyword: + type: constant_keyword + value: value + + - do: + bulk: + refresh: true + body: | + { "index": {"_index" : "test", "_id": 3} } + { "keyword": "abc" } + { "index": {"_index" : "test_numeric", "_id": 2} } + { "keyword": 42 } + { "index": {"_index" : "test_constant", "_id": 1} } + {} + +--- +"constant_keyword mixed sort": + - do: + search: + index: test,test_constant + body: + sort: keyword + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.1._id: "1" } + +--- +"constant_keyword invalid mixed sort": + - requires: + cluster_features: [ "gte_v8.15.0" ] + reason: Better error message in 8.15.0 + + - do: + catch: /Can't sort on field \[keyword\]\; the field has incompatible sort types/ + search: + index: test* + body: + sort: keyword + diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java index 6b6ab43e10c58..c28fc8f44c3fa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java @@ -62,10 +62,18 @@ static InferenceResults processResult( if (chunkResults) { var embeddings = new ArrayList<MlChunkedTextEmbeddingFloatResults.EmbeddingChunk>(); for (int i = 0; i < pyTorchResult.getInferenceResult()[0].length; i++) { - int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset(); - int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1; - int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset(); - String matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset); + String matchedText; + if (tokenization.getTokenization(i).tokens().get(0).isEmpty() == false) { + int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset(); + int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1; + int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset(); + matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset); + + } else { + // No tokens in the input, this should only happen with an empty string + assert tokenization.getTokenization(i).input().get(0).isEmpty(); + matchedText = ""; + } embeddings.add( new MlChunkedTextEmbeddingFloatResults.EmbeddingChunk(matchedText, pyTorchResult.getInferenceResult()[0][i]) diff --git
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java index 3939bbef4052a..2efeb7e6564f3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessor.java @@ -75,10 +75,17 @@ static InferenceResults processResult( var chunkedResults = new ArrayList<MlChunkedTextExpansionResults.ChunkedResult>(); for (int i = 0; i < pyTorchResult.getInferenceResult()[0].length; i++) { - int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset(); - int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1; - int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset(); - String matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset); + String matchedText; + if (tokenization.getTokenization(i).tokens().get(0).isEmpty() == false) { + int startOffset = tokenization.getTokenization(i).tokens().get(0).get(0).startOffset(); + int lastIndex = tokenization.getTokenization(i).tokens().get(0).size() - 1; + int endOffset = tokenization.getTokenization(i).tokens().get(0).get(lastIndex).endOffset(); + matchedText = tokenization.getTokenization(i).input().get(0).substring(startOffset, endOffset); + } else { + // No tokens in the input, this should only happen with an empty string + assert tokenization.getTokenization(i).input().get(0).isEmpty(); + matchedText = ""; + } var weightedTokens = sparseVectorToTokenWeights(pyTorchResult.getInferenceResult()[0][i], tokenization, replacementVocab); weightedTokens.sort((t1, t2) -> Float.compare(t2.weight(), t1.weight())); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java index bba2844784117..8369412580b88 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessorTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; @@ -16,9 +17,13 @@ import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer; import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult; +import java.util.Map; + +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.core.IsNot.not; public class TextEmbeddingProcessorTests extends ESTestCase { @@ -67,4 +72,26 @@ public void testChunking() { assertThat(chunkedResult.getChunks().get(1).embedding().length, greaterThan(0)); } } + + public void testChunkingWithEmptyString() { + try ( + BertTokenizer tokenizer = BertTokenizer.builder( +
TextExpansionProcessorTests.TEST_CASED_VOCAB, + new BertTokenization(null, false, 5, Tokenization.Truncate.NONE, 0) + ).build() + ) { + var pytorchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0, 2.0, 3.0, 4.0, 5.0 } } }); + + var input = ""; + var tokenization = tokenizer.tokenize(input, Tokenization.Truncate.NONE, 0, 0, null); + var tokenizationResult = new BertTokenizationResult(TextExpansionProcessorTests.TEST_CASED_VOCAB, tokenization, 0); + var inferenceResult = TextExpansionProcessor.processResult(tokenizationResult, pytorchResult, Map.of(), "foo", true); + assertThat(inferenceResult, instanceOf(MlChunkedTextExpansionResults.class)); + + var chunkedResult = (MlChunkedTextExpansionResults) inferenceResult; + assertThat(chunkedResult.getChunks(), hasSize(1)); + assertEquals("", chunkedResult.getChunks().get(0).matchedText()); + assertThat(chunkedResult.getChunks().get(0).weightedTokens(), not(empty())); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java index 9803467644db9..1991275dbd2f7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextExpansionProcessorTests.java @@ -147,4 +147,26 @@ public void testChunking() { assertThat(chunkedResult.getChunks().get(1).weightedTokens(), not(empty())); } } + + public void testChunkingWithEmptyString() { + try ( + BertTokenizer tokenizer = BertTokenizer.builder( + TEST_CASED_VOCAB, + new BertTokenization(null, false, 5, Tokenization.Truncate.NONE, 0) + ).build() + ) { + var pytorchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0, 2.0, 3.0, 4.0, 5.0 } } }); + + var input = ""; + var tokenization = tokenizer.tokenize(input, Tokenization.Truncate.NONE, 0, 0, null); + var tokenizationResult = new BertTokenizationResult(TEST_CASED_VOCAB, tokenization, 0); + var inferenceResult = TextExpansionProcessor.processResult(tokenizationResult, pytorchResult, Map.of(), "foo", true); + assertThat(inferenceResult, instanceOf(MlChunkedTextExpansionResults.class)); + + var chunkedResult = (MlChunkedTextExpansionResults) inferenceResult; + assertThat(chunkedResult.getChunks(), hasSize(1)); + assertEquals("", chunkedResult.getChunks().get(0).matchedText()); + assertThat(chunkedResult.getChunks().get(0).weightedTokens(), not(empty())); + } + } } diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/HostMetadata.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/HostMetadata.java index 29f3b66956d55..acfcd228b731e 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/HostMetadata.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/HostMetadata.java @@ -25,9 +25,9 @@ final class HostMetadata implements ToXContentObject { final int profilingNumCores; // number of cores on the profiling host machine HostMetadata(String hostID, InstanceType instanceType, String hostArchitecture, Integer profilingNumCores) { - this.hostID = hostID; - this.instanceType = instanceType; - this.hostArchitecture = hostArchitecture; + this.hostID = hostID != null ? hostID : ""; + this.instanceType = instanceType != null ? 
instanceType : new InstanceType("", "", ""); + this.hostArchitecture = hostArchitecture != null ? hostArchitecture : ""; this.profilingNumCores = profilingNumCores != null ? profilingNumCores : DEFAULT_PROFILING_NUM_CORES; } diff --git a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/persistence/ProfilingIndexTemplateRegistry.java b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/persistence/ProfilingIndexTemplateRegistry.java index d3af7402b7da6..282cded9418cc 100644 --- a/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/persistence/ProfilingIndexTemplateRegistry.java +++ b/x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/persistence/ProfilingIndexTemplateRegistry.java @@ -51,7 +51,8 @@ public class ProfilingIndexTemplateRegistry extends IndexTemplateRegistry { // version 8: Changed from disabled _source to synthetic _source for profiling-events-* and profiling-metrics // version 9: Changed sort order for profiling-events-* // version 10: changed mapping profiling-events @timestamp to 'date_nanos' from 'date' - public static final int INDEX_TEMPLATE_VERSION = 10; + // version 11: Added 'profiling.agent.protocol' keyword mapping to profiling-hosts + public static final int INDEX_TEMPLATE_VERSION = 11; // history for individual indices / index templates. Only bump these for breaking changes that require to create a new index public static final int PROFILING_EVENTS_VERSION = 4; diff --git a/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/CO2CalculatorTests.java b/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/CO2CalculatorTests.java index a7b9a97b71acc..ff698465a56c5 100644 --- a/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/CO2CalculatorTests.java +++ b/x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/CO2CalculatorTests.java @@ -82,6 +82,41 @@ public void testCreateFromRegularSource() { checkCO2Calculation(co2Calculator.getAnnualCO2Tons(HOST_ID_D, samples), annualCoreHours, 1.7d, 0.000379069d, 2.8d); } + // Make sure that malformed data doesn't cause the CO2 calculation to fail. 
+ public void testCreateFromMalformedSource() { + // tag::noformat + Map<String, HostMetadata> hostsTable = Map.ofEntries( + Map.entry(HOST_ID_A, + // known datacenter and instance type + new HostMetadata(HOST_ID_A, + new InstanceType( + "aws", + "eu-west-1", + "c5n.xlarge" + ), + null, + null + ) + ), + Map.entry(HOST_ID_B, + new HostMetadata(HOST_ID_B, + null, + null, + null + ) + ) + ); + // end::noformat + + double samplingDurationInSeconds = 1_800.0d; // 30 minutes + long samples = 100_000L; // 100k samples + double annualCoreHours = CostCalculator.annualCoreHours(samplingDurationInSeconds, samples, 20.0d); + CO2Calculator co2Calculator = new CO2Calculator(hostsTable, samplingDurationInSeconds, null, null, null, null); + + checkCO2Calculation(co2Calculator.getAnnualCO2Tons(HOST_ID_A, samples), annualCoreHours, 1.135d, 0.0002786d, 7.0d); + checkCO2Calculation(co2Calculator.getAnnualCO2Tons(HOST_ID_B, samples), annualCoreHours, 1.7d, 0.000379069d, 7.0d); + } + private void checkCO2Calculation( double calculatedAnnualCO2Tons, double annualCoreHours, diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java index c738033761b3e..56aec13cbab29 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java @@ -145,18 +145,17 @@ public void testCreateAndRestoreSearchableSnapshot() throws Exception { assertShardFolders(indexName, false); - assertThat( - clusterAdmin().prepareState() - .clear() - .setMetadata(true) - .setIndices(indexName) - .get() - .getState() - .metadata() - .index(indexName) - .getTimestampRange(), - sameInstance(IndexLongFieldRange.UNKNOWN) - ); + IndexMetadata indexMetadata = clusterAdmin().prepareState() + .clear() + .setMetadata(true) + .setIndices(indexName) + .get() + .getState() + .metadata() + .index(indexName); + + assertThat(indexMetadata.getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(indexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); final boolean deletedBeforeMount = randomBoolean(); if (deletedBeforeMount) { @@ -252,18 +251,17 @@ public void testCreateAndRestoreSearchableSnapshot() throws Exception { ensureGreen(restoredIndexName); assertBusy(() -> assertShardFolders(restoredIndexName, true), 30, TimeUnit.SECONDS); - assertThat( - clusterAdmin().prepareState() - .clear() - .setMetadata(true) - .setIndices(restoredIndexName) - .get() - .getState() - .metadata() - .index(restoredIndexName) - .getTimestampRange(), - sameInstance(IndexLongFieldRange.UNKNOWN) - ); + indexMetadata = clusterAdmin().prepareState() + .clear() + .setMetadata(true) + .setIndices(restoredIndexName) + .get() + .getState() + .metadata() + .index(restoredIndexName); + + assertThat(indexMetadata.getTimestampRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(indexMetadata.getEventIngestedRange(), sameInstance(IndexLongFieldRange.UNKNOWN)); if (deletedBeforeMount) { assertThat(indicesAdmin().prepareGetAliases(aliasName).get().getAliases().size(), equalTo(0)); @@ -684,21 +682,29 @@ public void testSnapshotMountedIndexLeavesBlobsUntouched() throws Exception { public
void testSnapshotMountedIndexWithTimestampsRecordsTimestampRangeInIndexMetadata() throws Exception { final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); - final int numShards = between(1, 3); + int numShards = between(1, 3); boolean indexed = randomBoolean(); - final String dateType = randomFrom("date", "date_nanos"); + String dateType = randomFrom("date", "date_nanos"); assertAcked( indicesAdmin().prepareCreate(indexName) .setMapping( XContentFactory.jsonBuilder() .startObject() .startObject("properties") + .startObject(DataStream.TIMESTAMP_FIELD_NAME) .field("type", dateType) .field("index", indexed) .field("format", "strict_date_optional_time_nanos") .endObject() + + .startObject(IndexMetadata.EVENT_INGESTED_FIELD_NAME) + .field("type", dateType) + .field("index", indexed) + .field("format", "strict_date_optional_time_nanos") + .endObject() + .endObject() .endObject() ) @@ -712,6 +718,15 @@ public void testSnapshotMountedIndexWithTimestampsRecordsTimestampRangeInIndexMe indexRequestBuilders.add( prepareIndex(indexName).setSource( DataStream.TIMESTAMP_FIELD_NAME, + String.format( + Locale.ROOT, + "2020-11-26T%02d:%02d:%02d.%09dZ", + between(0, 23), + between(0, 59), + between(0, 59), + randomLongBetween(0, 999999999L) + ), + IndexMetadata.EVENT_INGESTED_FIELD_NAME, String.format( Locale.ROOT, "2020-11-26T%02d:%02d:%02d.%09dZ", @@ -740,32 +755,45 @@ public void testSnapshotMountedIndexWithTimestampsRecordsTimestampRangeInIndexMe mountSnapshot(repositoryName, snapshotOne.getName(), indexName, indexName, Settings.EMPTY); ensureGreen(indexName); - final IndexLongFieldRange timestampRange = clusterAdmin().prepareState() + final IndexMetadata indexMetadata = clusterAdmin().prepareState() .clear() .setMetadata(true) .setIndices(indexName) .get() .getState() .metadata() - .index(indexName) - .getTimestampRange(); + .index(indexName); + final IndexLongFieldRange timestampRange = indexMetadata.getTimestampRange(); assertTrue(timestampRange.isComplete()); + final IndexLongFieldRange eventIngestedRange = indexMetadata.getEventIngestedRange(); + assertTrue(eventIngestedRange.isComplete()); + if (indexed) { assertThat(timestampRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); + assertThat(eventIngestedRange, not(sameInstance(IndexLongFieldRange.UNKNOWN))); if (docCount == 0) { assertThat(timestampRange, sameInstance(IndexLongFieldRange.EMPTY)); + assertThat(eventIngestedRange, sameInstance(IndexLongFieldRange.EMPTY)); } else { assertThat(timestampRange, not(sameInstance(IndexLongFieldRange.EMPTY))); + assertThat(eventIngestedRange, not(sameInstance(IndexLongFieldRange.EMPTY))); + + // both @timestamp and event.ingested have the same resolution in this test DateFieldMapper.Resolution resolution = dateType.equals("date") ? 
DateFieldMapper.Resolution.MILLISECONDS : DateFieldMapper.Resolution.NANOSECONDS; + assertThat(timestampRange.getMin(), greaterThanOrEqualTo(resolution.convert(Instant.parse("2020-11-26T00:00:00Z")))); assertThat(timestampRange.getMin(), lessThanOrEqualTo(resolution.convert(Instant.parse("2020-11-27T00:00:00Z")))); + + assertThat(eventIngestedRange.getMin(), greaterThanOrEqualTo(resolution.convert(Instant.parse("2020-11-26T00:00:00Z")))); + assertThat(eventIngestedRange.getMin(), lessThanOrEqualTo(resolution.convert(Instant.parse("2020-11-27T00:00:00Z")))); } } else { assertThat(timestampRange, sameInstance(IndexLongFieldRange.UNKNOWN)); + assertThat(eventIngestedRange, sameInstance(IndexLongFieldRange.UNKNOWN)); } } diff --git a/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java b/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java index ebce936cb385e..e863235e2bdb5 100644 --- a/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java +++ b/x-pack/plugin/slm/src/test/java/org/elasticsearch/xpack/slm/action/ReservedSnapshotLifecycleStateServiceTests.java @@ -40,15 +40,21 @@ import org.elasticsearch.xpack.core.slm.SnapshotLifecycleMetadata; import org.elasticsearch.xpack.core.slm.action.DeleteSnapshotLifecycleAction; import org.elasticsearch.xpack.core.slm.action.PutSnapshotLifecycleAction; +import org.junit.Assert; import java.io.IOException; -import java.util.Collections; import java.util.List; +import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; @@ -69,7 +75,7 @@ private TransformState processJSON(ReservedSnapshotAction action, TransformState public void testDependencies() { var action = new ReservedSnapshotAction(); - assertTrue(action.optionalDependencies().contains(ReservedRepositoryAction.NAME)); + assertThat(action.optionalDependencies(), contains(ReservedRepositoryAction.NAME)); } public void testValidationFails() { @@ -79,7 +85,7 @@ public void testValidationFails() { ClusterState state = ClusterState.builder(clusterName).build(); ReservedSnapshotAction action = new ReservedSnapshotAction(); - TransformState prevState = new TransformState(state, Collections.emptySet()); + TransformState prevState = new TransformState(state, Set.of()); String badPolicyJSON = """ { @@ -100,9 +106,9 @@ public void testValidationFails() { } }"""; - assertEquals( - "Required [schedule]", - expectThrows(IllegalArgumentException.class, () -> processJSON(action, prevState, badPolicyJSON)).getMessage() + assertThat( + expectThrows(IllegalArgumentException.class, () -> processJSON(action, prevState, badPolicyJSON)).getMessage(), + is("Required [schedule]") ); } @@ -121,10 +127,10 @@ public void testActionAddRemove() throws Exception { String emptyJSON = ""; - TransformState prevState = new TransformState(state, Collections.emptySet()); + TransformState prevState = new TransformState(state, 
Set.of()); TransformState updatedState = processJSON(action, prevState, emptyJSON); - assertEquals(0, updatedState.keys().size()); + assertThat(updatedState.keys(), empty()); assertEquals(prevState.state(), updatedState.state()); String twoPoliciesJSON = """ @@ -337,9 +343,9 @@ public void testOperatorControllerFromJSONContent() throws IOException { AtomicReference x = new AtomicReference<>(); try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, testJSON)) { - controller.process("operator", parser, (e) -> x.set(e)); + controller.process("operator", parser, x::set); - assertTrue(x.get() instanceof IllegalStateException); + assertThat(x.get(), instanceOf(IllegalStateException.class)); assertThat(x.get().getMessage(), containsString("Error processing state change request for operator")); } @@ -357,11 +363,7 @@ public void testOperatorControllerFromJSONContent() throws IOException { ); try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, testJSON)) { - controller.process("operator", parser, (e) -> { - if (e != null) { - fail("Should not fail"); - } - }); + controller.process("operator", parser, Assert::assertNull); } } @@ -375,7 +377,7 @@ public void testDeleteSLMReservedStateHandler() { mock(ActionFilters.class), mock(IndexNameExpressionResolver.class) ); - assertEquals(ReservedSnapshotAction.NAME, deleteAction.reservedStateHandlerName().get()); + assertThat(deleteAction.reservedStateHandlerName().get(), equalTo(ReservedSnapshotAction.NAME)); var request = new DeleteSnapshotLifecycleAction.Request(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, "daily-snapshots1"); assertThat(deleteAction.modifiedKeys(request), containsInAnyOrder("daily-snapshots1")); @@ -391,7 +393,7 @@ public void testPutSLMReservedStateHandler() throws Exception { mock(ActionFilters.class), mock(IndexNameExpressionResolver.class) ); - assertEquals(ReservedSnapshotAction.NAME, putAction.reservedStateHandlerName().get()); + assertThat(putAction.reservedStateHandlerName().get(), equalTo(ReservedSnapshotAction.NAME)); String json = """ { diff --git a/x-pack/plugin/snapshot-based-recoveries/src/internalClusterTest/java/org/elasticsearch/xpack/snapshotbasedrecoveries/recovery/SnapshotBasedIndexRecoveryIT.java b/x-pack/plugin/snapshot-based-recoveries/src/internalClusterTest/java/org/elasticsearch/xpack/snapshotbasedrecoveries/recovery/SnapshotBasedIndexRecoveryIT.java index caae6dd393a0c..d2e5896a4cf77 100644 --- a/x-pack/plugin/snapshot-based-recoveries/src/internalClusterTest/java/org/elasticsearch/xpack/snapshotbasedrecoveries/recovery/SnapshotBasedIndexRecoveryIT.java +++ b/x-pack/plugin/snapshot-based-recoveries/src/internalClusterTest/java/org/elasticsearch/xpack/snapshotbasedrecoveries/recovery/SnapshotBasedIndexRecoveryIT.java @@ -15,7 +15,6 @@ import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse; import org.elasticsearch.action.admin.indices.recovery.RecoveryResponse; import org.elasticsearch.action.admin.indices.stats.CommonStatsFlags; -import org.elasticsearch.action.admin.indices.stats.ShardStats; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; @@ -25,7 +24,6 @@ import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.RecoverySource; import org.elasticsearch.cluster.service.ClusterService; -import 
org.elasticsearch.common.Strings; import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.common.blobstore.support.FilterBlobContainer; @@ -1563,26 +1561,23 @@ private void indexDocs(String indexName, int docIdOffset, int docCount) throws E // Ensure that the safe commit == latest commit assertBusy(() -> { - ShardStats stats = indicesAdmin().prepareStats(indexName) - .clear() - .get() - .asMap() - .entrySet() - .stream() - .filter(e -> e.getKey().shardId().getId() == 0) - .map(Map.Entry::getValue) - .findFirst() - .orElse(null); - assertThat(stats, is(notNullValue())); - assertThat(stats.getSeqNoStats(), is(notNullValue())); - - assertThat(stats.getSeqNoStats().getMaxSeqNo(), is(greaterThan(-1L))); - assertThat(stats.getSeqNoStats().getGlobalCheckpoint(), is(greaterThan(-1L))); - assertThat( - Strings.toString(stats.getSeqNoStats()), - stats.getSeqNoStats().getMaxSeqNo(), - equalTo(stats.getSeqNoStats().getGlobalCheckpoint()) - ); + ClusterState clusterState = client().admin().cluster().prepareState().get().getState(); + var indexShardRoutingTable = clusterState.routingTable().index(indexName).shard(0); + assertThat(indexShardRoutingTable, is(notNullValue())); + + var assignedNodeId = indexShardRoutingTable.primaryShard().currentNodeId(); + var assignedNodeName = clusterState.nodes().resolveNode(assignedNodeId).getName(); + + var indexShard = internalCluster().getInstance(IndicesService.class, assignedNodeName) + .indexService(resolveIndex(indexName)) + .getShard(0); + assertThat(indexShard, is(notNullValue())); + + // The safe commit is determined using the last synced global checkpoint, hence we should wait until the translog is synced + // to cover cases where the translog is synced asynchronously + var lastSyncedGlobalCheckpoint = indexShard.getLastSyncedGlobalCheckpoint(); + var maxSeqNo = indexShard.seqNoStats().getMaxSeqNo(); + assertThat(lastSyncedGlobalCheckpoint, equalTo(maxSeqNo)); }, 60, TimeUnit.SECONDS); } diff --git a/x-pack/qa/xpack-prefix-rest-compat/src/yamlRestTestV7Compat/resources/rest-api-spec/api/xpack-ml.get_categories.json b/x-pack/qa/xpack-prefix-rest-compat/src/yamlRestTestV7Compat/resources/rest-api-spec/api/xpack-ml.get_categories.json index daf4d9bfc7889..4fce55f682248 100644 --- a/x-pack/qa/xpack-prefix-rest-compat/src/yamlRestTestV7Compat/resources/rest-api-spec/api/xpack-ml.get_categories.json +++ b/x-pack/qa/xpack-prefix-rest-compat/src/yamlRestTestV7Compat/resources/rest-api-spec/api/xpack-ml.get_categories.json @@ -34,7 +34,7 @@ } }, { - "path":"/_xpack/ml/anomaly_detectors/{job_id}/results/categories/", + "path":"/_xpack/ml/anomaly_detectors/{job_id}/results/categories", "methods":[ "GET", "POST"