Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Scopes to resolve lambda argument references in ExpressionAnalyzer #9026

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
import javax.annotation.Nullable;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
Expand Down Expand Up @@ -255,12 +254,12 @@ public Map<NodeRef<Identifier>, LambdaArgumentDeclaration> getLambdaArgumentRefe
public Type analyze(Expression expression, Scope scope)
{
Visitor visitor = new Visitor(scope);
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda()));
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(Context.notInLambda(scope)));
}

private Type analyze(Expression expression, Scope scope, Context context)
private Type analyze(Expression expression, Scope baseScope, Context context)
{
Visitor visitor = new Visitor(scope);
Visitor visitor = new Visitor(baseScope);
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(context));
}

Expand All @@ -282,11 +281,11 @@ public Set<NodeRef<QuantifiedComparisonExpression>> getQuantifiedComparisons()
private class Visitor
extends StackableAstVisitor<Type, Context>
{
private final Scope scope;
private final Scope baseScope;

private Visitor(Scope scope)
public Visitor(Scope baseScope)
{
this.scope = requireNonNull(scope, "scope is null");
this.baseScope = requireNonNull(baseScope, "baseScope is null");
}

@SuppressWarnings("SuspiciousMethodCalls")
Expand Down Expand Up @@ -347,10 +346,9 @@ protected Type visitCurrentTime(CurrentTime node, StackableAstVisitorContext<Con
protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getName());
if (lambdaArgumentDeclaration != null) {
Type result = getExpressionType(lambdaArgumentDeclaration);
return setExpressionType(node, result);
Optional<ResolvedField> resolvedField = context.getContext().getScope().tryResolveField(node, QualifiedName.of(node.getName()));
if (resolvedField.isPresent() && context.getContext().getFieldToLambdaArgumentDeclaration().containsKey(FieldId.from(resolvedField.get()))) {
return setExpressionType(node, resolvedField.get().getType());
}
}
Type type = symbolTypes.get(Symbol.from(node));
Expand All @@ -360,24 +358,26 @@ protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorCon
@Override
protected Type visitIdentifier(Identifier node, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getNameToLambdaArgumentDeclarationMap().get(node.getValue());
if (lambdaArgumentDeclaration != null) {
lambdaArgumentReferences.put(NodeRef.of(node), lambdaArgumentDeclaration);
Type result = getExpressionType(lambdaArgumentDeclaration);
return setExpressionType(node, result);
}
}
return handleResolvedField(node, scope.resolveField(node, QualifiedName.of(node.getValue())));
ResolvedField resolvedField = context.getContext().getScope().resolveField(node, QualifiedName.of(node.getValue()));
return handleResolvedField(node, resolvedField, context);
}

private Type handleResolvedField(Expression node, ResolvedField resolvedField)
private Type handleResolvedField(Expression node, ResolvedField resolvedField, StackableAstVisitorContext<Context> context)
{
return handleResolvedField(node, FieldId.from(resolvedField), resolvedField.getType());
return handleResolvedField(node, FieldId.from(resolvedField), resolvedField.getType(), context);
}

private Type handleResolvedField(Expression node, FieldId fieldId, Type resolvedType)
private Type handleResolvedField(Expression node, FieldId fieldId, Type resolvedType, StackableAstVisitorContext<Context> context)
{
if (context.getContext().isInLambda()) {
LambdaArgumentDeclaration lambdaArgumentDeclaration = context.getContext().getFieldToLambdaArgumentDeclaration().get(fieldId);
if (lambdaArgumentDeclaration != null) {
// Lambda argument reference is not a column reference
lambdaArgumentReferences.put(NodeRef.of((Identifier) node), lambdaArgumentDeclaration);
return setExpressionType(node, resolvedType);
}
}

FieldId previous = columnReferences.put(NodeRef.of(node), fieldId);
checkState(previous == null, "%s already known to refer to %s", node, previous);
return setExpressionType(node, resolvedType);
Expand All @@ -388,17 +388,18 @@ protected Type visitDereferenceExpression(DereferenceExpression node, StackableA
{
QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node);

if (!context.getContext().isInLambda()) {
// If this Dereference looks like column reference, try match it to column first.
if (qualifiedName != null) {
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
return handleResolvedField(node, resolvedField.get());
}
if (!scope.isColumnReference(qualifiedName)) {
throw missingAttributeException(node, qualifiedName);
// If this Dereference looks like column reference, try match it to column first.
if (qualifiedName != null) {
Scope scope = context.getContext().getScope();
Optional<ResolvedField> resolvedField = scope.tryResolveField(node, qualifiedName);
if (resolvedField.isPresent()) {
if (!context.getContext().isInLambda() || !context.getContext().getFieldToLambdaArgumentDeclaration().containsKey(FieldId.from(resolvedField.get()))) {
return handleResolvedField(node, resolvedField.get(), context);
}
}
if (!scope.isColumnReference(qualifiedName)) {
throw missingAttributeException(node, qualifiedName);
}
}

Type baseType = process(node.getBase(), context);
Expand Down Expand Up @@ -792,11 +793,11 @@ protected Type visitFunctionCall(FunctionCall node, StackableAstVisitorContext<C
parameters,
isDescribe);
if (context.getContext().isInLambda()) {
for (LambdaArgumentDeclaration argument : context.getContext().getNameToLambdaArgumentDeclarationMap().values()) {
for (LambdaArgumentDeclaration argument : context.getContext().getFieldToLambdaArgumentDeclaration().values()) {
innerExpressionAnalyzer.setExpressionType(argument, getExpressionType(argument));
}
}
return innerExpressionAnalyzer.analyze(expression, scope, context.getContext().expectingLambda(types)).getTypeSignature();
return innerExpressionAnalyzer.analyze(expression, baseScope, context.getContext().expectingLambda(types)).getTypeSignature();
}));
}
else {
Expand Down Expand Up @@ -969,7 +970,7 @@ protected Type visitSubqueryExpression(SubqueryExpression node, StackableAstVisi
}
StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node);
Scope subqueryScope = Scope.builder()
.withParent(scope)
.withParent(context.getContext().getScope())
.build();
Scope queryScope = analyzer.analyze(node.getQuery(), subqueryScope);

Expand Down Expand Up @@ -1000,7 +1001,7 @@ else if (previousNode instanceof QuantifiedComparisonExpression) {
protected Type visitExists(ExistsPredicate node, StackableAstVisitorContext<Context> context)
{
StatementAnalyzer analyzer = statementAnalyzerFactory.apply(node);
Scope subqueryScope = Scope.builder().withParent(scope).build();
Scope subqueryScope = Scope.builder().withParent(context.getContext().getScope()).build();
analyzer.analyze(node.getSubquery(), subqueryScope);

existsSubqueries.add(NodeRef.of(node));
Expand Down Expand Up @@ -1035,7 +1036,7 @@ protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpressio
}
break;
default:
checkState(false, "Unexpected comparison type: %s", node.getComparisonType());
throw new IllegalStateException(format("Unexpected comparison type: %s", node.getComparisonType()));
}

return setExpressionType(node, BOOLEAN);
Expand All @@ -1044,8 +1045,8 @@ protected Type visitQuantifiedComparisonExpression(QuantifiedComparisonExpressio
@Override
public Type visitFieldReference(FieldReference node, StackableAstVisitorContext<Context> context)
{
Type type = scope.getRelationType().getFieldByIndex(node.getFieldIndex()).getType();
return handleResolvedField(node, new FieldId(scope.getRelationId(), node.getFieldIndex()), type);
Type type = baseScope.getRelationType().getFieldByIndex(node.getFieldIndex()).getType();
return handleResolvedField(node, new FieldId(baseScope.getRelationId(), node.getFieldIndex()), type, context);
}

@Override
Expand All @@ -1063,18 +1064,30 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC
throw new SemanticException(INVALID_PARAMETER_USAGE, node,
format("Expected a lambda that takes %s argument(s) but got %s", types.size(), lambdaArguments.size()));
}
verify(types.size() == lambdaArguments.size());

Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap = new HashMap<>();
if (context.getContext().isInLambda()) {
nameToLambdaArgumentDeclarationMap.putAll(context.getContext().getNameToLambdaArgumentDeclarationMap());
}
ImmutableList.Builder<Field> fields = ImmutableList.builder();
for (int i = 0; i < lambdaArguments.size(); i++) {
LambdaArgumentDeclaration lambdaArgument = lambdaArguments.get(i);
nameToLambdaArgumentDeclarationMap.put(lambdaArgument.getName().getValue(), lambdaArgument);
setExpressionType(lambdaArgument, types.get(i));
Type type = types.get(i);
fields.add(Field.newUnqualified(lambdaArgument.getName().getValue(), type));
setExpressionType(lambdaArgument, type);
}

Scope lambdaScope = Scope.builder()
.withParent(context.getContext().getScope())
.withRelationType(RelationId.of(node), new RelationType(fields.build()))
.build();

ImmutableMap.Builder<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration = ImmutableMap.builder();
if (context.getContext().isInLambda()) {
fieldToLambdaArgumentDeclaration.putAll(context.getContext().getFieldToLambdaArgumentDeclaration());
}
Type returnType = process(node.getBody(), new StackableAstVisitorContext<Context>(Context.inLambda(nameToLambdaArgumentDeclarationMap)));
for (LambdaArgumentDeclaration lambdaArgument : lambdaArguments) {
ResolvedField resolvedField = lambdaScope.resolveField(lambdaArgument, QualifiedName.of(lambdaArgument.getName().getValue()));
fieldToLambdaArgumentDeclaration.put(FieldId.from(resolvedField), lambdaArgument);
}

Type returnType = process(node.getBody(), new StackableAstVisitorContext<>(Context.inLambda(lambdaScope, fieldToLambdaArgumentDeclaration.build())));
FunctionType functionType = new FunctionType(types, returnType);
return setExpressionType(node, functionType);
}
Expand Down Expand Up @@ -1117,6 +1130,7 @@ protected Type visitNode(Node node, StackableAstVisitorContext<Context> context)
throw new SemanticException(NOT_SUPPORTED, node, "not yet implemented: " + node.getClass().getName());
}

@Override
public Type visitGroupingOperation(GroupingOperation node, StackableAstVisitorContext<Context> context)
{
if (node.getGroupingColumns().size() > MAX_NUMBER_GROUPING_ARGUMENTS_BIGINT) {
Expand Down Expand Up @@ -1252,60 +1266,69 @@ else if (typeOnlyCoercions.contains(ref)) {

private static class Context
{
private final Scope scope;

// functionInputTypes and nameToLambdaDeclarationMap can be null or non-null independently. All 4 combinations are possible.

// The list of types when expecting a lambda (i.e. processing lambda parameters of a function); null otherwise.
// Empty list represents expecting a lambda with no arguments.
private final List<Type> functionInputTypes;
// The mapping from names to corresponding lambda argument declarations when inside a lambda; null otherwise.
// Empty map means that the all lambda expressions surrounding the current node has no arguments.
private final Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap;
private final Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration;

private Context(
Scope scope,
List<Type> functionInputTypes,
Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration)
{
this.scope = requireNonNull(scope, "scope is null");
this.functionInputTypes = functionInputTypes;
this.nameToLambdaArgumentDeclarationMap = nameToLambdaArgumentDeclarationMap;
this.fieldToLambdaArgumentDeclaration = fieldToLambdaArgumentDeclaration;
}

static Context notInLambda(Scope scope)
{
return new Context(scope, null, null);
}

public static Context notInLambda()
static Context inLambda(Scope scope, Map<FieldId, LambdaArgumentDeclaration> fieldToLambdaArgumentDeclaration)
{
return new Context(null, null);
return new Context(scope, null, requireNonNull(fieldToLambdaArgumentDeclaration, "fieldToLambdaArgumentDeclaration is null"));
}

public static Context inLambda(Map<String, LambdaArgumentDeclaration> nameToLambdaArgumentDeclarationMap)
Context expectingLambda(List<Type> functionInputTypes)
{
return new Context(null, requireNonNull(nameToLambdaArgumentDeclarationMap, "nameToLambdaArgumentDeclarationMap is null"));
return new Context(scope, requireNonNull(functionInputTypes, "functionInputTypes is null"), this.fieldToLambdaArgumentDeclaration);
}

public Context expectingLambda(List<Type> functionInputTypes)
Context notExpectingLambda()
{
return new Context(requireNonNull(functionInputTypes, "functionInputTypes is null"), this.nameToLambdaArgumentDeclarationMap);
return new Context(scope, null, this.fieldToLambdaArgumentDeclaration);
}

public Context notExpectingLambda()
Scope getScope()
{
return new Context(null, this.nameToLambdaArgumentDeclarationMap);
return scope;
}

public boolean isInLambda()
boolean isInLambda()
{
return nameToLambdaArgumentDeclarationMap != null;
return fieldToLambdaArgumentDeclaration != null;
}

public boolean isExpectingLambda()
boolean isExpectingLambda()
{
return functionInputTypes != null;
}

public Map<String, LambdaArgumentDeclaration> getNameToLambdaArgumentDeclarationMap()
Map<FieldId, LambdaArgumentDeclaration> getFieldToLambdaArgumentDeclaration()
{
checkState(isInLambda());
return nameToLambdaArgumentDeclarationMap;
return fieldToLambdaArgumentDeclaration;
}

public List<Type> getFunctionInputTypes()
List<Type> getFunctionInputTypes()
{
checkState(isExpectingLambda());
return functionInputTypes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,18 @@ public void testLambdaCapture()
{
// Test for lambda expression without capture can be found in TestLambdaExpression

assertQuery("SELECT apply(0, x -> x + c1) FROM (VALUES 1) t(c1)", "VALUES 1");
assertQuery("SELECT apply(0, x -> x + t.c1) FROM (VALUES 1) t(c1)", "VALUES 1");
assertQuery("SELECT apply(c1, x -> x + c2) FROM (VALUES (1, 2), (3, 4), (5, 6)) t(c1, c2)", "VALUES 3, 7, 11");
assertQuery("SELECT apply(c1 + 10, x -> apply(x + 100, y -> c1)) FROM (VALUES 1) t(c1)", "VALUES 1");
assertQuery("SELECT apply(c1 + 10, x -> apply(x + 100, y -> t.c1)) FROM (VALUES 1) t(c1)", "VALUES 1");
assertQuery("SELECT apply(CAST(ROW(10) AS ROW(x INTEGER)), r -> r.x)", "VALUES 10");
assertQuery("SELECT apply(CAST(ROW(10) AS ROW(x INTEGER)), r -> r.x) FROM (VALUES 1) u(x)", "VALUES 10");
// assertQuery("SELECT apply(CAST(ROW(10) AS ROW(x INTEGER)), r -> r.x) FROM (VALUES 1) r(x)", "VALUES 10"); TODO #9025
assertQuery("SELECT apply(CAST(ROW(10) AS ROW(x INTEGER)), r -> apply(3, y -> y + r.x)) FROM (VALUES 1) u(x)", "VALUES 13");
// assertQuery("SELECT apply(CAST(ROW(10) AS ROW(x INTEGER)), r -> apply(3, y -> y + r.x)) FROM (VALUES 1) r(x)", "VALUES 13"); TODO #9025
// assertQuery("SELECT apply(CAST(ROW(10) AS ROW(x INTEGER)), r -> apply(3, y -> y + r.x)) FROM (VALUES 'a') r(x)", "VALUES 13"); TODO #9025
assertQuery("SELECT apply(CAST(ROW(10) AS ROW(x INTEGER)), z -> apply(3, y -> y + r.x)) FROM (VALUES 1) r(x)", "VALUES 4");

// reference lambda variable of the not-immediately-enclosing lambda
assertQuery("SELECT apply(1, x -> apply(10, y -> x)) FROM (VALUES 1000) t(x)", "VALUES 1");
Expand Down