From 2bc89903b62d8a983685c293f271c9817d0d539a Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Mon, 7 Oct 2024 10:12:56 -0700 Subject: [PATCH] Split construction of large rows into separate methods Split the processing of fields into batches to reduce the size of a single generated method. This makes it possible to call ROW(...) with a large number of fields. --- .../sql/gen/BytecodeGeneratorContext.java | 37 ++++++++++- .../sql/gen/CursorProcessorCompiler.java | 9 ++- .../sql/gen/JoinFilterFunctionCompiler.java | 4 +- .../sql/gen/LambdaBytecodeGenerator.java | 20 +++++- .../trino/sql/gen/PageFunctionCompiler.java | 8 ++- .../sql/gen/RowConstructorCodeGenerator.java | 61 ++++++++++++++++--- .../trino/sql/gen/RowExpressionCompiler.java | 23 +++++-- .../trino/sql/routine/SqlRoutineCompiler.java | 9 ++- 8 files changed, 149 insertions(+), 22 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java index dc373ac44d93a..2f3242780438e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/BytecodeGeneratorContext.java @@ -13,7 +13,10 @@ */ package io.trino.sql.gen; +import com.google.common.collect.ImmutableList; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.Parameter; import io.airlift.bytecode.Scope; import io.airlift.bytecode.Variable; import io.trino.metadata.FunctionManager; @@ -40,19 +43,24 @@ public class BytecodeGeneratorContext private final CachedInstanceBinder cachedInstanceBinder; private final FunctionManager functionManager; private final Variable wasNull; + private final ClassDefinition classDefinition; + private final List contextArguments; // arguments that need to be propagated to generated methods to be able to resolve underlying references, session, etc. 
public BytecodeGeneratorContext( RowExpressionCompiler rowExpressionCompiler, Scope scope, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, - FunctionManager functionManager) + FunctionManager functionManager, + ClassDefinition classDefinition, + List contextArguments) { requireNonNull(rowExpressionCompiler, "rowExpressionCompiler is null"); requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null"); requireNonNull(scope, "scope is null"); requireNonNull(callSiteBinder, "callSiteBinder is null"); requireNonNull(functionManager, "functionManager is null"); + requireNonNull(classDefinition, "classDefinition is null"); this.rowExpressionCompiler = rowExpressionCompiler; this.scope = scope; @@ -60,6 +68,8 @@ public BytecodeGeneratorContext( this.cachedInstanceBinder = cachedInstanceBinder; this.functionManager = functionManager; this.wasNull = scope.getVariable("wasNull"); + this.classDefinition = classDefinition; + this.contextArguments = ImmutableList.copyOf(contextArguments); } public Scope getScope() @@ -110,4 +120,29 @@ public Variable wasNull() { return wasNull; } + + public ClassDefinition getClassDefinition() + { + return classDefinition; + } + + public RowExpressionCompiler getRowExpressionCompiler() + { + return rowExpressionCompiler; + } + + public CachedInstanceBinder getCachedInstanceBinder() + { + return cachedInstanceBinder; + } + + public FunctionManager getFunctionManager() + { + return functionManager; + } + + public List getContextArguments() + { + return contextArguments; + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/CursorProcessorCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/CursorProcessorCompiler.java index 6203b475608f7..8091d3330d8d1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/CursorProcessorCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/CursorProcessorCompiler.java @@ -13,6 +13,7 @@ */ package io.trino.sql.gen; +import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.bytecode.BytecodeBlock; @@ -236,11 +237,13 @@ private void generateFilterMethod( Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); RowExpressionCompiler compiler = new RowExpressionCompiler( + classDefinition, callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(cursor), functionManager, - compiledLambdaMap); + compiledLambdaMap, + ImmutableList.of(session, cursor)); LabelNode end = new LabelNode("end"); method.getBody() @@ -276,11 +279,13 @@ private void generateProjectMethod( Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull"); RowExpressionCompiler compiler = new RowExpressionCompiler( + classDefinition, callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(cursor), functionManager, - compiledLambdaMap); + compiledLambdaMap, + ImmutableList.of(session, cursor, output)); method.getBody() .comment("boolean wasNull = false;") diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/JoinFilterFunctionCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/JoinFilterFunctionCompiler.java index dca68e9b29406..a9093e33fadfe 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/JoinFilterFunctionCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/JoinFilterFunctionCompiler.java @@ -194,11 +194,13 @@ private void generateFilterMethod( scope.declareVariable("session", body, method.getThis().getField(sessionField)); RowExpressionCompiler compiler = new RowExpressionCompiler( + classDefinition, callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder, leftPosition, leftPage, rightPosition, rightPage, leftBlocksSize), functionManager, - compiledLambdaMap); + compiledLambdaMap, + ImmutableList.of(leftPage, leftPosition, rightPage, rightPosition)); BytecodeNode visitorBody = compiler.compile(filter, 
scope); diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java index c7eebde3f3a5f..6e08c9e7602c3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/LambdaBytecodeGenerator.java @@ -131,11 +131,13 @@ public static CompiledLambda preGenerateLambdaExpression( } RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler( + classDefinition, callSiteBinder, cachedInstanceBinder, variableReferenceCompiler(parameterMapBuilder.buildOrThrow()), functionManager, - compiledLambdaMap); + compiledLambdaMap, + parameters.build()); return defineLambdaMethod( innerExpressionCompiler, @@ -266,18 +268,30 @@ public static Class> compileLambdaProvider(LambdaDefi scope.declareVariable("session", body, method.getThis().getField(sessionField)); RowExpressionCompiler rowExpressionCompiler = new RowExpressionCompiler( + lambdaProviderClassDefinition, callSiteBinder, cachedInstanceBinder, variableReferenceCompiler(ImmutableMap.of()), functionManager, - compiledLambdaMap); + compiledLambdaMap, + ImmutableList.of()); + + List parameters = new ArrayList<>(); + parameters.add(arg("session", ConnectorSession.class)); + for (int i = 0; i < lambdaExpression.arguments().size(); i++) { + Symbol argument = lambdaExpression.arguments().get(i); + Class type = Primitives.wrap(argument.type().getJavaType()); + parameters.add(arg("lambda_" + i + "_" + BytecodeUtils.sanitizeName(argument.name()), type)); + } BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext( rowExpressionCompiler, scope, callSiteBinder, cachedInstanceBinder, - functionManager); + functionManager, + lambdaProviderClassDefinition, + parameters); body.append( generateLambda( diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java 
b/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java index 23214eccf0c5e..a2918be1dab72 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java @@ -359,11 +359,13 @@ private MethodDefinition generateEvaluateMethod( Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); RowExpressionCompiler compiler = new RowExpressionCompiler( + classDefinition, callSiteBinder, cachedInstanceBinder, fieldReferenceCompilerProjection(callSiteBinder), functionManager, - compiledLambdaMap); + compiledLambdaMap, + ImmutableList.of(session, position)); body.append(thisVariable.getField(blockBuilder)) .append(compiler.compile(projection, scope)) @@ -543,11 +545,13 @@ private MethodDefinition generateFilterMethod( Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse()); RowExpressionCompiler compiler = new RowExpressionCompiler( + classDefinition, callSiteBinder, cachedInstanceBinder, fieldReferenceCompiler(callSiteBinder), functionManager, - compiledLambdaMap); + compiledLambdaMap, + ImmutableList.of(page, position)); Variable result = scope.declareVariable(boolean.class, "result"); body.append(compiler.compile(filter, scope)) diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java index b627203319181..25c4239953a74 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/RowConstructorCodeGenerator.java @@ -15,6 +15,9 @@ import io.airlift.bytecode.BytecodeBlock; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; import io.airlift.bytecode.Scope; import io.airlift.bytecode.Variable; 
import io.airlift.bytecode.control.IfStatement; @@ -28,8 +31,12 @@ import io.trino.sql.relational.RowExpression; import io.trino.sql.relational.SpecialForm; +import java.util.ArrayList; import java.util.List; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; import static io.airlift.bytecode.ParameterizedType.type; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; @@ -49,6 +56,9 @@ public class RowConstructorCodeGenerator // Arbitrary value chosen to balance the code size vs performance trade off. Not perf tested. private static final int MEGAMORPHIC_FIELD_COUNT = 64; + // number of fields to initialize in a single method for large rows + private static final int LARGE_ROW_BATCH_SIZE = 100; + public RowConstructorCodeGenerator(SpecialForm specialForm) { requireNonNull(specialForm, "specialForm is null"); @@ -109,18 +119,58 @@ private BytecodeNode generateExpressionForLargeRows(BytecodeGeneratorContext con BytecodeBlock block = new BytecodeBlock().setDescription("Constructor for " + rowType); CallSiteBinder binder = context.getCallSiteBinder(); Scope scope = context.getScope(); - List types = rowType.getTypeParameters(); Variable fieldBuilders = scope.getOrCreateTempVariable(BlockBuilder[].class); block.append(fieldBuilders.set(invokeStatic(RowConstructorCodeGenerator.class, "createFieldBlockBuildersForSingleRow", BlockBuilder[].class, constantType(binder, rowType)))); Variable blockBuilder = scope.getOrCreateTempVariable(BlockBuilder.class); - for (int i = 0; i < arguments.size(); ++i) { + for (int i = 0; i < arguments.size(); i += LARGE_ROW_BATCH_SIZE) { + MethodDefinition partialRowConstructor = generatePartialRowConstructor(i, Math.min(i + LARGE_ROW_BATCH_SIZE, arguments.size()), context); + block.getVariable(scope.getThis()); + for (Parameter argument : 
context.getContextArguments()) { + block.getVariable(argument); + } + block.getVariable(fieldBuilders); + block.invokeVirtual(partialRowConstructor); + } + scope.releaseTempVariableForReuse(blockBuilder); + + block.append(invokeStatic(RowConstructorCodeGenerator.class, "createSqlRowFromFieldBuildersForSingleRow", SqlRow.class, fieldBuilders)); + scope.releaseTempVariableForReuse(fieldBuilders); + block.append(context.wasNull().set(constantFalse())); + return block; + } + + private MethodDefinition generatePartialRowConstructor(int start, int end, BytecodeGeneratorContext parentContext) + { + ClassDefinition classDefinition = parentContext.getClassDefinition(); + CallSiteBinder binder = parentContext.getCallSiteBinder(); + + Parameter fieldBuilders = arg("fieldBuilders", BlockBuilder[].class); + + List parameters = new ArrayList<>(parentContext.getContextArguments()); + parameters.add(fieldBuilders); + + MethodDefinition methodDefinition = classDefinition.declareMethod( + a(PUBLIC), + "partialRowConstructor" + System.identityHashCode(this) + "_" + start, + type(void.class), + parameters); + + Scope scope = methodDefinition.getScope(); + BytecodeBlock block = methodDefinition.getBody(); + scope.declareVariable("wasNull", block, constantFalse()); + + BytecodeGeneratorContext context = new BytecodeGeneratorContext(parentContext.getRowExpressionCompiler(), scope, binder, parentContext.getCachedInstanceBinder(), parentContext.getFunctionManager(), classDefinition, parentContext.getContextArguments()); + Variable blockBuilder = scope.getOrCreateTempVariable(BlockBuilder.class); + List types = rowType.getTypeParameters(); + for (int i = start; i < end; i++) { Type fieldType = types.get(i); block.append(blockBuilder.set(fieldBuilders.getElement(constantInt(i)))); block.comment("Clean wasNull and Generate + " + i + "-th field of row"); + block.append(context.wasNull().set(constantFalse())); block.append(context.generate(arguments.get(i))); Variable field = 
scope.getOrCreateTempVariable(fieldType.getJavaType()); @@ -131,12 +181,9 @@ private BytecodeNode generateExpressionForLargeRows(BytecodeGeneratorContext con .ifFalse(constantType(binder, fieldType).writeValue(blockBuilder, field).pop())); scope.releaseTempVariableForReuse(field); } - scope.releaseTempVariableForReuse(blockBuilder); - block.append(invokeStatic(RowConstructorCodeGenerator.class, "createSqlRowFromFieldBuildersForSingleRow", SqlRow.class, fieldBuilders)); - scope.releaseTempVariableForReuse(fieldBuilders); - block.append(context.wasNull().set(constantFalse())); - return block; + block.ret(); + return methodDefinition; } @UsedByGeneratedCode diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java index 959831dffbc05..87bb405f59ca5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/RowExpressionCompiler.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableList; import io.airlift.bytecode.BytecodeBlock; import io.airlift.bytecode.BytecodeNode; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.Parameter; import io.airlift.bytecode.Scope; import io.airlift.bytecode.Variable; import io.trino.metadata.FunctionManager; @@ -31,6 +33,7 @@ import io.trino.sql.relational.SpecialForm; import io.trino.sql.relational.VariableReferenceExpression; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -45,24 +48,30 @@ public class RowExpressionCompiler { + private final ClassDefinition classDefinition; private final CallSiteBinder callSiteBinder; private final CachedInstanceBinder cachedInstanceBinder; private final RowExpressionVisitor fieldReferenceCompiler; private final FunctionManager functionManager; private final Map compiledLambdaMap; + private final List contextArguments; // arguments that need to be propagated 
to generated methods public RowExpressionCompiler( + ClassDefinition classDefinition, CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpressionVisitor fieldReferenceCompiler, FunctionManager functionManager, - Map compiledLambdaMap) + Map compiledLambdaMap, + List contextArguments) { + this.classDefinition = classDefinition; this.callSiteBinder = callSiteBinder; this.cachedInstanceBinder = cachedInstanceBinder; this.fieldReferenceCompiler = fieldReferenceCompiler; this.functionManager = functionManager; this.compiledLambdaMap = compiledLambdaMap; + this.contextArguments = ImmutableList.copyOf(contextArguments); } public BytecodeNode compile(RowExpression rowExpression, Scope scope) @@ -86,7 +95,9 @@ public BytecodeNode visitCall(CallExpression call, Context context) context.getScope(), callSiteBinder, cachedInstanceBinder, - functionManager); + functionManager, + classDefinition, + contextArguments); return generatorContext.generateFullCall(call.resolvedFunction(), call.arguments()); } @@ -116,7 +127,9 @@ public BytecodeNode visitSpecialForm(SpecialForm specialForm, Context context) context.getScope(), callSiteBinder, cachedInstanceBinder, - functionManager); + functionManager, + classDefinition, + contextArguments); return generator.generateExpression(generatorContext); } @@ -178,7 +191,9 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context conte context.getScope(), callSiteBinder, cachedInstanceBinder, - functionManager); + functionManager, + classDefinition, + contextArguments); return generateLambda( generatorContext, diff --git a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineCompiler.java b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineCompiler.java index af7b3d14e4e6e..0f19cb63e6246 100644 --- a/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/routine/SqlRoutineCompiler.java @@ -226,7 +226,7 @@ private 
void generateRunMethod( Map variables = VariableExtractor.extract(routine).stream().distinct() .collect(toImmutableMap(identity(), variable -> getOrDeclareVariable(scope, variable))); - BytecodeVisitor visitor = new BytecodeVisitor(cachedInstanceBinder, compiledLambdaMap, variables); + BytecodeVisitor visitor = new BytecodeVisitor(classDefinition, cachedInstanceBinder, compiledLambdaMap, variables); method.getBody().append(visitor.process(routine, scope)); } @@ -283,6 +283,7 @@ private static String name(int field) private class BytecodeVisitor implements IrNodeVisitor { + private final ClassDefinition classDefinition; private final CachedInstanceBinder cachedInstanceBinder; private final Map compiledLambdaMap; private final Map variables; @@ -291,10 +292,12 @@ private class BytecodeVisitor private final Map breakLabels = new HashMap<>(); public BytecodeVisitor( + ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder, Map compiledLambdaMap, Map variables) { + this.classDefinition = requireNonNull(classDefinition, "classDefinition is null"); this.cachedInstanceBinder = requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null"); this.compiledLambdaMap = requireNonNull(compiledLambdaMap, "compiledLambdaMap is null"); this.variables = requireNonNull(variables, "variables is null"); @@ -462,11 +465,13 @@ private BytecodeNode compile(RowExpression expression, Scope scope) } RowExpressionCompiler rowExpressionCompiler = new RowExpressionCompiler( + classDefinition, cachedInstanceBinder.getCallSiteBinder(), cachedInstanceBinder, FieldReferenceCompiler.INSTANCE, functionManager, - compiledLambdaMap); + compiledLambdaMap, + ImmutableList.of()); return new BytecodeBlock() .comment("boolean wasNull = false;")