Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed nondeterminism in UdfIndex #7719

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@
* <li>If two methods exist that match given the above rules, and both
* have variable arguments, return the method with the more non-variable
* arguments.</li>
* <li>If two methods exist that match only null values, return the one
* that was added first.</li>
* <li>If two methods exist that match given the above rules, return the
* method with fewer generic arguments.</li>
* <li>If two methods exist that match given the above rules, an exception
* is thrown due to a vague function call</li>
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved
* </ul>
*/
public class UdfIndex<T extends FunctionSignature> {
Expand Down Expand Up @@ -112,7 +114,6 @@ void addFunction(final T function) {
);
}

final int order = allFunctions.size();

Node curr = root;
Node parent = curr;
Expand All @@ -125,48 +126,58 @@ void addFunction(final T function) {
if (function.isVariadic()) {
// first add the function to the parent to address the
// case of empty varargs
parent.update(function, order);
parent.update(function);

// then add a new child node with the parameter value type
// and add this function to that node
final ParamType varargSchema = Iterables.getLast(parameters);
final Parameter vararg = new Parameter(varargSchema, true);
final Node leaf = parent.children.computeIfAbsent(vararg, ignored -> new Node());
leaf.update(function, order);
leaf.update(function);

// add a self referential loop for varargs so that we can
// add as many of the same param at the end and still retrieve
// this node
leaf.children.putIfAbsent(vararg, leaf);
}

curr.update(function, order);
curr.update(function);
}

T getFunction(final List<SqlArgument> arguments) {
final List<Node> candidates = new ArrayList<>();

// first try to get the candidates without any implicit casting
getCandidates(arguments, 0, root, candidates, new HashMap<>(), false);
final Optional<T> fun = candidates
.stream()
.max(Node::compare)
.map(node -> node.value);

if (fun.isPresent()) {
return fun.get();
Optional<T> candidate = findMatchingCandidate(candidates, arguments, false);
if (candidate.isPresent()) {
return candidate.get();
} else if (!supportsImplicitCasts) {
throw createNoMatchingFunctionException(arguments);
}
candidates.clear();
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved

// if none were found (candidates is empty) try again with
// implicit casting
getCandidates(arguments, 0, root, candidates, new HashMap<>(), true);
return candidates
.stream()
.max(Node::compare)
.map(node -> node.value)
.orElseThrow(() -> createNoMatchingFunctionException(arguments));
//if non were found (candidate isn't present) try again with implicit casting
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved
candidate = findMatchingCandidate(candidates, arguments, true);
if (candidate.isPresent()) {
return candidate.get();
}
throw createNoMatchingFunctionException(arguments);
}

private Optional<T> findMatchingCandidate(final List<Node> candidates,
final List<SqlArgument> arguments, final boolean allowCasts) {

getCandidates(arguments, 0, root, candidates, new HashMap<>(), allowCasts);
candidates.sort(Node::compare);

final int len = candidates.size();
if (len == 1) {
return Optional.of(candidates.get(0).value);
} else if (len > 1 && candidates.get(len - 1).compare(candidates.get(len - 2)) > 0) {
return Optional.of(candidates.get(len - 1).value);
}

return Optional.empty();
}

private void getCandidates(
Expand Down Expand Up @@ -273,17 +284,15 @@ private final class Node {

private final Map<Parameter, Node> children;
private T value;
private int order = 0;

private Node() {
this.children = new HashMap<>();
this.value = null;
}

private void update(final T function, final int order) {
private void update(final T function) {
if (compareFunctions.compare(function, value) > 0) {
value = function;
this.order = order;
}
}

Expand All @@ -307,8 +316,15 @@ public String toString() {
}

int compare(final Node other) {
final int compare = compareFunctions.compare(value, other.value);
return compare == 0 ? -(order - other.order) : compare;
final int compareVal = compareFunctions.compare(value, other.value);
return compareVal == 0 ? countGenerics(other) - countGenerics(this) : compareVal;
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved
}

private int countGenerics(final Node node) {
return node.value.parameters().stream()
.filter(GenericsUtil::hasGenerics)
.mapToInt(p -> 1)
.sum();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static io.confluent.ksql.function.KsqlScalarFunction.INTERNAL_PATH;
import static io.confluent.ksql.function.types.ArrayType.of;
import static io.confluent.ksql.schema.ksql.types.SqlTypes.INTEGER;
import static io.confluent.ksql.schema.ksql.types.SqlTypes.BIGINT;
import static java.lang.System.lineSeparator;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
Expand Down Expand Up @@ -562,35 +563,6 @@ public void shouldFindNonVarargWithPartialNullValues() {
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseFirstAddedWithNullValues() {
// Given:
givenFunctions(
function(EXPECTED, false, STRING),
function(OTHER, false, INT)
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(Collections.singletonList(null));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldFindVarargWithNullValues() {
// Given:
givenFunctions(
function(EXPECTED, true, STRING_VARARGS)
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(new SqlArgument[]{null}));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved

@Test
public void shouldFindVarargWithSomeNullValues() {
// Given:
Expand Down Expand Up @@ -964,6 +936,60 @@ public void shouldThrowIfNoExactMatchAndImplicitCastDisabled() {
+ "(INTEGER)"));
}

@Test
public void shouldThrowWhenGivenVagueImplicitCast() {
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved
// Given:
givenFunctions(
function(FIRST_FUNC, false, LONG, LONG),
function(SECOND_FUNC, false, DOUBLE, DOUBLE)
);

// When:
final Exception e = assertThrows(Exception.class,
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved
() -> udfIndex
.getFunction(ImmutableList.of(SqlArgument.of(INTEGER), SqlArgument.of(BIGINT))));

// Then:
assertThat(e.getMessage(), containsString("Function 'name' does not accept parameters "
+ "(INTEGER, BIGINT)"));
}

@Test
public void shouldFindFewerGenerics() {
// Given:
givenFunctions(
function(EXPECTED, false, INT, GenericType.of("A"), INT),
function(OTHER, false, INT, GenericType.of("A"), GenericType.of("B"))
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(ImmutableList
.of(SqlArgument.of(INTEGER), SqlArgument.of(INTEGER), SqlArgument.of(INTEGER)));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldThrowOnComparablyEqualFunctionsWithSameGenericCount() {
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved
// Given:
givenFunctions(
function(FIRST_FUNC, false, LONG, GenericType.of("A"), GenericType.of("B")),
function(SECOND_FUNC, false, DOUBLE, GenericType.of("A"), GenericType.of("B"))
);

// When:
final Exception e = assertThrows(Exception.class,
() -> udfIndex
.getFunction(ImmutableList
.of(SqlArgument.of(INTEGER), SqlArgument.of(INTEGER), SqlArgument.of(INTEGER))));

// Then:
assertThat(e.getMessage(), containsString("Function 'name' does not accept parameters "
Sullivan-Patrick marked this conversation as resolved.
Show resolved Hide resolved
+ "(INTEGER, INTEGER, INTEGER)"));
}


private void givenFunctions(final KsqlScalarFunction... functions) {
Arrays.stream(functions).forEach(udfIndex::addFunction);
}
Expand Down