Skip to content

Commit

Permalink
fix: Apply the ExtensionSecurityManager to UDAFs (#8776)
Browse files Browse the repository at this point in the history
* fix: Apply the ExtensionSecurityManager to UDAFs

Addresses #8662
  • Loading branch information
jnh5y authored Mar 24, 2022
1 parent 7f468bc commit a37688c
Show file tree
Hide file tree
Showing 12 changed files with 896 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlException;
import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -55,12 +56,17 @@ public BaseAggregateFunction(
) {
this.argIndexInValue = argIndexInValue;
this.initialValueSupplier = () -> {
final A val = initialValueSupplier.get();
if (val instanceof Struct && !((Struct) val).schema().isOptional()) {
throw new KsqlException("Initialize function for " + functionName
+ " must return struct with optional schema");
ExtensionSecurityManager.INSTANCE.pushInUdf();
try {
final A val = initialValueSupplier.get();
if (val instanceof Struct && !((Struct) val).schema().isOptional()) {
throw new KsqlException("Initialize function for " + functionName
+ " must return struct with optional schema");
}
return val;
} finally {
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
return val;
};
this.aggregateSchema = Objects.requireNonNull(aggregateType, "aggregateType");
this.outputSchema = Objects.requireNonNull(outputType, "outputType");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.confluent.ksql.schema.ksql.SqlTypeParser;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -262,13 +263,16 @@ private static SqlType invokeSchemaProviderMethod(
final String functionName
) {
try {
ExtensionSecurityManager.INSTANCE.pushInUdf();
return (SqlType) m.invoke(instance, args);
} catch (IllegalAccessException
| InvocationTargetException e) {
throw new KsqlException(String.format("Cannot invoke the schema provider "
+ "method %s for UDF %s. ",
m.getName(), functionName
), e);
} finally {
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.confluent.ksql.GenericKey;
import io.confluent.ksql.function.udaf.Udaf;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.security.ExtensionSecurityManager;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -106,9 +107,12 @@ private static Optional<Sensor> getSensor(
private static <T> T timed(final Optional<Sensor> maybeSensor, final Supplier<T> task) {
final long start = Time.SYSTEM.nanoseconds();
try {
// Since the timed() function wraps the calls to Udafs, we use it to protect the calls.
ExtensionSecurityManager.INSTANCE.pushInUdf();
return task.get();
} finally {
maybeSensor.ifPresent(sensor -> sensor.record(Time.SYSTEM.nanoseconds() - start));
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.SqlTypeParser;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
Expand Down Expand Up @@ -89,8 +90,8 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr
final List<SqlArgument> argTypeList) {
final Object[] factoryArgs = initArgs.args().toArray();
try {
ExtensionSecurityManager.INSTANCE.pushInUdf();
final Udaf udaf = (Udaf)method.invoke(null, factoryArgs);

udaf.initializeTypeArguments(argTypeList);
if (udaf instanceof Configurable) {
((Configurable) udaf).configure(initArgs.config());
Expand Down Expand Up @@ -134,6 +135,8 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr
} catch (final Exception e) {
LOG.error("Failed to invoke UDAF factory method", e);
throw new KsqlException("Failed to invoke UDAF factory method", e);
} finally {
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.confluent.ksql.function.udaf.TableUdaf;
import io.confluent.ksql.function.udaf.Udaf;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.security.ExtensionSecurityManager;
import java.util.List;
import java.util.Optional;
import org.apache.kafka.common.metrics.Metrics;
Expand All @@ -42,6 +43,11 @@ public UdafTableAggregateFunction(

@Override
public A undo(final I valueToUndo, final A aggregateValue) {
return ((TableUdaf<I, A, O>)udaf).undo(valueToUndo, aggregateValue);
ExtensionSecurityManager.INSTANCE.pushInUdf();
try {
return ((TableUdaf<I, A, O>)udaf).undo(valueToUndo, aggregateValue);
} finally {
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.confluent.ksql.function.udf.UdfMetadata;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.SqlTypeParser;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -186,8 +187,13 @@ private Function<KsqlConfig, Kudf> getUdfFactory(
final Object actualUdf = FunctionLoaderUtils.instantiateFunctionInstance(
method.getDeclaringClass(), udfDescriptionAnnotation.name());
if (actualUdf instanceof Configurable) {
((Configurable) actualUdf)
.configure(ksqlConfig.getKsqlFunctionsConfigProps(functionName));
ExtensionSecurityManager.INSTANCE.pushInUdf();
try {
((Configurable) actualUdf)
.configure(ksqlConfig.getKsqlFunctionsConfigProps(functionName));
} finally {
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
}
final PluggableUdf theUdf = new PluggableUdf(invoker, actualUdf);
return metrics.<Kudf>map(m -> new UdfMetricProducer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

package io.confluent.ksql.security;

import io.confluent.ksql.function.udf.PluggableUdf;
import java.security.AllPermission;
import java.security.CodeSource;
import java.security.Permission;
Expand Down Expand Up @@ -60,20 +59,16 @@ public boolean implies(final ProtectionDomain domain, final Permission permissio
}

public synchronized void pushInUdf() {
if (validateCaller()) {
if (UDF_IS_EXECUTING.get() == null) {
UDF_IS_EXECUTING.set(new Stack<>());
}
UDF_IS_EXECUTING.get().push(true);
if (UDF_IS_EXECUTING.get() == null) {
UDF_IS_EXECUTING.set(new Stack<>());
}
UDF_IS_EXECUTING.get().push(true);
}

public void popOutUdf() {
if (validateCaller()) {
final Stack<Boolean> stack = UDF_IS_EXECUTING.get();
if (stack != null && !stack.isEmpty()) {
stack.pop();
}
final Stack<Boolean> stack = UDF_IS_EXECUTING.get();
if (stack != null && !stack.isEmpty()) {
stack.pop();
}
}

Expand All @@ -93,18 +88,8 @@ public void checkExec(final String cmd) {
super.checkExec(cmd);
}


private boolean inUdfExecution() {
final Stack<Boolean> executing = UDF_IS_EXECUTING.get();
return executing != null && !executing.isEmpty();
}

/**
* Check if the caller is a PluggableUdf. It will be the third
* item in the class array.
* @return true if caller is allowed
*/
private boolean validateCaller() {
return getClassContext()[2].equals(PluggableUdf.class);
}
}
Loading

0 comments on commit a37688c

Please sign in to comment.