Skip to content

Commit

Permalink
Split construction of large rows into separate methods
Browse files Browse the repository at this point in the history
Split the processing of fields into batches to reduce the
size of a single generated method. This makes it possible
to call ROW(...) with a large number of fields.
  • Loading branch information
martint committed Oct 10, 2024
1 parent 54827dc commit 2bc8990
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
*/
package io.trino.sql.gen;

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.trino.metadata.FunctionManager;
Expand All @@ -40,26 +43,33 @@ public class BytecodeGeneratorContext
private final CachedInstanceBinder cachedInstanceBinder;
private final FunctionManager functionManager;
private final Variable wasNull;
private final ClassDefinition classDefinition;
private final List<Parameter> contextArguments; // arguments that need to be propagated to generated methods to be able to resolve underlying references, session, etc.

public BytecodeGeneratorContext(
RowExpressionCompiler rowExpressionCompiler,
Scope scope,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
FunctionManager functionManager)
FunctionManager functionManager,
ClassDefinition classDefinition,
List<Parameter> contextArguments)
{
requireNonNull(rowExpressionCompiler, "rowExpressionCompiler is null");
requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null");
requireNonNull(scope, "scope is null");
requireNonNull(callSiteBinder, "callSiteBinder is null");
requireNonNull(functionManager, "functionManager is null");
requireNonNull(classDefinition, "classDefinition is null");

this.rowExpressionCompiler = rowExpressionCompiler;
this.scope = scope;
this.callSiteBinder = callSiteBinder;
this.cachedInstanceBinder = cachedInstanceBinder;
this.functionManager = functionManager;
this.wasNull = scope.getVariable("wasNull");
this.classDefinition = classDefinition;
this.contextArguments = ImmutableList.copyOf(contextArguments);
}

public Scope getScope()
Expand Down Expand Up @@ -110,4 +120,29 @@ public Variable wasNull()
{
return wasNull;
}

public ClassDefinition getClassDefinition()
{
return classDefinition;
}

public RowExpressionCompiler getRowExpressionCompiler()
{
return rowExpressionCompiler;
}

public CachedInstanceBinder getCachedInstanceBinder()
{
return cachedInstanceBinder;
}

public FunctionManager getFunctionManager()
{
return functionManager;
}

public List<Parameter> getContextArguments()
{
return contextArguments;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.sql.gen;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.bytecode.BytecodeBlock;
Expand Down Expand Up @@ -236,11 +237,13 @@ private void generateFilterMethod(
Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull");

RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(cursor),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(session, cursor));

LabelNode end = new LabelNode("end");
method.getBody()
Expand Down Expand Up @@ -276,11 +279,13 @@ private void generateProjectMethod(
Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull");

RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(cursor),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(session, cursor, output));

method.getBody()
.comment("boolean wasNull = false;")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,13 @@ private void generateFilterMethod(
scope.declareVariable("session", body, method.getThis().getField(sessionField));

RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(callSiteBinder, leftPosition, leftPage, rightPosition, rightPage, leftBlocksSize),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(leftPage, leftPosition, rightPage, rightPosition));

BytecodeNode visitorBody = compiler.compile(filter, scope);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,13 @@ public static CompiledLambda preGenerateLambdaExpression(
}

RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
variableReferenceCompiler(parameterMapBuilder.buildOrThrow()),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
parameters.build());

return defineLambdaMethod(
innerExpressionCompiler,
Expand Down Expand Up @@ -266,18 +268,30 @@ public static Class<? extends Supplier<Object>> compileLambdaProvider(LambdaDefi
scope.declareVariable("session", body, method.getThis().getField(sessionField));

RowExpressionCompiler rowExpressionCompiler = new RowExpressionCompiler(
lambdaProviderClassDefinition,
callSiteBinder,
cachedInstanceBinder,
variableReferenceCompiler(ImmutableMap.of()),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of());

List<Parameter> parameters = new ArrayList<>();
parameters.add(arg("session", ConnectorSession.class));
for (int i = 0; i < lambdaExpression.arguments().size(); i++) {
Symbol argument = lambdaExpression.arguments().get(i);
Class<?> type = Primitives.wrap(argument.type().getJavaType());
parameters.add(arg("lambda_" + i + "_" + BytecodeUtils.sanitizeName(argument.name()), type));
}

BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext(
rowExpressionCompiler,
scope,
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
lambdaProviderClassDefinition,
parameters);

body.append(
generateLambda(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,13 @@ private MethodDefinition generateEvaluateMethod(

Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse());
RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompilerProjection(callSiteBinder),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(session, position));

body.append(thisVariable.getField(blockBuilder))
.append(compiler.compile(projection, scope))
Expand Down Expand Up @@ -543,11 +545,13 @@ private MethodDefinition generateFilterMethod(

Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse());
RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(callSiteBinder),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(page, position));

Variable result = scope.declareVariable(boolean.class, "result");
body.append(compiler.compile(filter, scope))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.IfStatement;
Expand All @@ -28,8 +31,12 @@
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;

import java.util.ArrayList;
import java.util.List;

import static io.airlift.bytecode.Access.PUBLIC;
import static io.airlift.bytecode.Access.a;
import static io.airlift.bytecode.Parameter.arg;
import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
Expand All @@ -49,6 +56,9 @@ public class RowConstructorCodeGenerator
// Arbitrary value chosen to balance the code size vs performance trade off. Not perf tested.
private static final int MEGAMORPHIC_FIELD_COUNT = 64;

// number of fields to initialize in a single method for large rows
private static final int LARGE_ROW_BATCH_SIZE = 100;

public RowConstructorCodeGenerator(SpecialForm specialForm)
{
requireNonNull(specialForm, "specialForm is null");
Expand Down Expand Up @@ -109,18 +119,58 @@ private BytecodeNode generateExpressionForLargeRows(BytecodeGeneratorContext con
BytecodeBlock block = new BytecodeBlock().setDescription("Constructor for " + rowType);
CallSiteBinder binder = context.getCallSiteBinder();
Scope scope = context.getScope();
List<Type> types = rowType.getTypeParameters();

Variable fieldBuilders = scope.getOrCreateTempVariable(BlockBuilder[].class);
block.append(fieldBuilders.set(invokeStatic(RowConstructorCodeGenerator.class, "createFieldBlockBuildersForSingleRow", BlockBuilder[].class, constantType(binder, rowType))));

Variable blockBuilder = scope.getOrCreateTempVariable(BlockBuilder.class);
for (int i = 0; i < arguments.size(); ++i) {
for (int i = 0; i < arguments.size(); i += LARGE_ROW_BATCH_SIZE) {
MethodDefinition partialRowConstructor = generatePartialRowConstructor(i, Math.min(i + LARGE_ROW_BATCH_SIZE, arguments.size()), context);
block.getVariable(scope.getThis());
for (Parameter argument : context.getContextArguments()) {
block.getVariable(argument);
}
block.getVariable(fieldBuilders);
block.invokeVirtual(partialRowConstructor);
}
scope.releaseTempVariableForReuse(blockBuilder);

block.append(invokeStatic(RowConstructorCodeGenerator.class, "createSqlRowFromFieldBuildersForSingleRow", SqlRow.class, fieldBuilders));
scope.releaseTempVariableForReuse(fieldBuilders);
block.append(context.wasNull().set(constantFalse()));
return block;
}

private MethodDefinition generatePartialRowConstructor(int start, int end, BytecodeGeneratorContext parentContext)
{
ClassDefinition classDefinition = parentContext.getClassDefinition();
CallSiteBinder binder = parentContext.getCallSiteBinder();

Parameter fieldBuilders = arg("fieldBuilders", BlockBuilder[].class);

List<Parameter> parameters = new ArrayList<>(parentContext.getContextArguments());
parameters.add(fieldBuilders);

MethodDefinition methodDefinition = classDefinition.declareMethod(
a(PUBLIC),
"partialRowConstructor" + System.identityHashCode(this) + "_" + start,
type(void.class),
parameters);

Scope scope = methodDefinition.getScope();
BytecodeBlock block = methodDefinition.getBody();
scope.declareVariable("wasNull", block, constantFalse());

BytecodeGeneratorContext context = new BytecodeGeneratorContext(parentContext.getRowExpressionCompiler(), scope, binder, parentContext.getCachedInstanceBinder(), parentContext.getFunctionManager(), classDefinition, parentContext.getContextArguments());
Variable blockBuilder = scope.getOrCreateTempVariable(BlockBuilder.class);
List<Type> types = rowType.getTypeParameters();
for (int i = start; i < end; i++) {
Type fieldType = types.get(i);

block.append(blockBuilder.set(fieldBuilders.getElement(constantInt(i))));

block.comment("Clean wasNull and Generate + " + i + "-th field of row");

block.append(context.wasNull().set(constantFalse()));
block.append(context.generate(arguments.get(i)));
Variable field = scope.getOrCreateTempVariable(fieldType.getJavaType());
Expand All @@ -131,12 +181,9 @@ private BytecodeNode generateExpressionForLargeRows(BytecodeGeneratorContext con
.ifFalse(constantType(binder, fieldType).writeValue(blockBuilder, field).pop()));
scope.releaseTempVariableForReuse(field);
}
scope.releaseTempVariableForReuse(blockBuilder);

block.append(invokeStatic(RowConstructorCodeGenerator.class, "createSqlRowFromFieldBuildersForSingleRow", SqlRow.class, fieldBuilders));
scope.releaseTempVariableForReuse(fieldBuilders);
block.append(context.wasNull().set(constantFalse()));
return block;
block.ret();
return methodDefinition;
}

@UsedByGeneratedCode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.trino.metadata.FunctionManager;
Expand All @@ -31,6 +33,7 @@
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;

import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand All @@ -45,24 +48,30 @@

public class RowExpressionCompiler
{
private final ClassDefinition classDefinition;
private final CallSiteBinder callSiteBinder;
private final CachedInstanceBinder cachedInstanceBinder;
private final RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler;
private final FunctionManager functionManager;
private final Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap;
private final List<Parameter> contextArguments; // arguments that need to be propagates to generated methods

public RowExpressionCompiler(
ClassDefinition classDefinition,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler,
FunctionManager functionManager,
Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap)
Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap,
List<Parameter> contextArguments)
{
this.classDefinition = classDefinition;
this.callSiteBinder = callSiteBinder;
this.cachedInstanceBinder = cachedInstanceBinder;
this.fieldReferenceCompiler = fieldReferenceCompiler;
this.functionManager = functionManager;
this.compiledLambdaMap = compiledLambdaMap;
this.contextArguments = ImmutableList.copyOf(contextArguments);
}

public BytecodeNode compile(RowExpression rowExpression, Scope scope)
Expand All @@ -86,7 +95,9 @@ public BytecodeNode visitCall(CallExpression call, Context context)
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
classDefinition,
contextArguments);

return generatorContext.generateFullCall(call.resolvedFunction(), call.arguments());
}
Expand Down Expand Up @@ -116,7 +127,9 @@ public BytecodeNode visitSpecialForm(SpecialForm specialForm, Context context)
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
classDefinition,
contextArguments);

return generator.generateExpression(generatorContext);
}
Expand Down Expand Up @@ -178,7 +191,9 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context conte
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
classDefinition,
contextArguments);

return generateLambda(
generatorContext,
Expand Down
Loading

0 comments on commit 2bc8990

Please sign in to comment.