Skip to content

Commit

Permalink
feat: Apply the ExtensionSecurityManager to UDAFs
Browse files Browse the repository at this point in the history
Addresses #8662
  • Loading branch information
jnh5y committed Mar 15, 2022
1 parent 0c48959 commit 4437a44
Show file tree
Hide file tree
Showing 10 changed files with 682 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.confluent.ksql.schema.ksql.SqlTypeParser;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -262,13 +263,16 @@ private static SqlType invokeSchemaProviderMethod(
final String functionName
) {
try {
ExtensionSecurityManager.INSTANCE.pushInUdf();
return (SqlType) m.invoke(instance, args);
} catch (IllegalAccessException
| InvocationTargetException e) {
throw new KsqlException(String.format("Cannot invoke the schema provider "
+ "method %s for UDF %s. ",
m.getName(), functionName
), e);
} finally {
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.SqlTypeParser;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
Expand Down Expand Up @@ -87,12 +88,14 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr
final List<SqlArgument> argTypeList) {
final Object[] factoryArgs = initArgs.args().toArray();
try {
ExtensionSecurityManager.INSTANCE.pushInUdf();
final Udaf udaf = (Udaf)method.invoke(null, factoryArgs);

udaf.initializeTypeArguments(argTypeList);

if (udaf instanceof Configurable) {
((Configurable) udaf).configure(initArgs.config());
}
ExtensionSecurityManager.INSTANCE.popOutUdf();

final SqlType aggregateSqlType = (SqlType) udaf.getAggregateSqlType()
.orElseGet(() -> SchemaConverters.functionToSqlConverter().toSqlType(aggregateArgType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.confluent.ksql.function.udf.UdfMetadata;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.SqlTypeParser;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -186,8 +187,10 @@ private Function<KsqlConfig, Kudf> getUdfFactory(
final Object actualUdf = FunctionLoaderUtils.instantiateFunctionInstance(
method.getDeclaringClass(), udfDescriptionAnnotation.name());
if (actualUdf instanceof Configurable) {
ExtensionSecurityManager.INSTANCE.pushInUdf();
((Configurable) actualUdf)
.configure(ksqlConfig.getKsqlFunctionsConfigProps(functionName));
ExtensionSecurityManager.INSTANCE.popOutUdf();
}
final PluggableUdf theUdf = new PluggableUdf(invoker, actualUdf);
return metrics.<Kudf>map(m -> new UdfMetricProducer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

package io.confluent.ksql.security;

import io.confluent.ksql.function.FunctionLoaderUtils;
import io.confluent.ksql.function.UdfLoader;
import io.confluent.ksql.function.udf.PluggableUdf;
import java.lang.reflect.ReflectPermission;
import java.security.AllPermission;
import java.security.CodeSource;
import java.security.Permission;
Expand Down Expand Up @@ -93,6 +96,19 @@ public void checkExec(final String cmd) {
super.checkExec(cmd);
}

@Override
public void checkPermission(final Permission perm) {
System.out.println("Checking permission " + perm);
if (inUdfExecution()) {
if (perm instanceof ReflectPermission) {
throw new SecurityException("A UDF attempted to use reflection.");
}
if (perm instanceof RuntimePermission) {
throw new SecurityException("A UDF attempted to make a system call.");
}
}
super.checkPermission(perm);
}

private boolean inUdfExecution() {
final Stack<Boolean> executing = UDF_IS_EXECUTING.get();
Expand All @@ -105,6 +121,11 @@ private boolean inUdfExecution() {
* @return true if caller is allowed
*/
private boolean validateCaller() {
return getClassContext()[2].equals(PluggableUdf.class);
final Class caller = getClassContext()[2];
System.out.println("Caller is " + caller);
return caller.equals(PluggableUdf.class)
|| caller.equals(FunctionLoaderUtils.class)
|| caller.equals(UdfLoader.class)
|| caller.getName().equals("io.confluent.ksql.function.UdafFactoryInvoker");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import io.confluent.ksql.function.udf.UdfParameter;
import io.confluent.ksql.function.udf.UdfSchemaProvider;
import io.confluent.ksql.metastore.TypeRegistry;
import io.confluent.ksql.metrics.MetricCollectors;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.SqlArgument;
import io.confluent.ksql.schema.ksql.SqlTypeParser;
Expand All @@ -62,6 +61,7 @@
import io.confluent.ksql.schema.ksql.types.SqlStruct;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.test.util.KsqlTestFolder;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
Expand Down Expand Up @@ -138,6 +138,44 @@ public void shouldLoadFunctionsInKsqlEngine() {
assertThat(substring2.evaluate("foo", 2, 1), equalTo("o"));
}

@Test
public void shouldLoadBadFunctionButNotLetItExit() {
final List<SqlArgument> argList = Arrays.asList(SqlArgument.of(SqlTypes.STRING));
// We do need to set up the ExtensionSecurityManager for our test.
// This is controlled by a feature flag and in this test, we just directly enable it.
SecurityManager manager = System.getSecurityManager();
System.setSecurityManager(ExtensionSecurityManager.INSTANCE);

final UdfFactory function = FUNC_REG.getUdfFactory(FunctionName.of("test_udf"));
assertThat(function, not(nullValue()));

KsqlScalarFunction ksqlScalarFunction = function.getFunction(argList);
final Kudf badFunction = ksqlScalarFunction.newInstance(ksqlConfig);

final Exception e0 = assertThrows(
java.lang.SecurityException.class,
() -> FUNC_REG.getUdfFactory(FunctionName.of("bad_test_udf"))
.getFunction(argList).newInstance(ksqlConfig)
);
assertThat(e0.getMessage(), containsString("A UDF attempted to call System.exit"));

final Exception e1 = assertThrows(
KsqlException.class,
() -> ksqlScalarFunction.getReturnType(argList)
);
assertThat(e1.getMessage(), containsString(
"Cannot invoke the schema provider method exit for UDF test_udf."));

final Exception e2 = assertThrows(
KsqlFunctionException.class,
() -> badFunction.evaluate("foo")
);
assertThat(e2.getMessage(), containsString(
"Failed to invoke function public org.apache.kafka.connect.data.Struct "
+ "io.confluent.ksql.function.udf.TestUdf.returnList(java.lang.String)"));
System.setSecurityManager(manager);
}

@SuppressWarnings("unchecked")
@Test
public void shouldLoadUdafs() {
Expand Down Expand Up @@ -180,6 +218,100 @@ public void shouldLoadStructUdafs() {
equalTo(new Struct(schema).put("A", 1).put("B", 2)));
}

@Test
public void shouldNotLetBadUdafsExit() {
// We do need to set up the ExtensionSecurityManager for our test.
// This is controlled by a feature flag and in this test, we just directly enable it.
SecurityManager manager = System.getSecurityManager();
System.setSecurityManager(ExtensionSecurityManager.INSTANCE);

// This will exit via create
final Exception e1 = assertThrows(
KsqlException.class,
() ->
((KsqlAggregateFunction) FUNC_REG
.getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.array(SqlTypes.INTEGER),
AggregateFunctionInitArguments.EMPTY_ARGS)).aggregate("foo", 2L)
);
assertThat(e1.getMessage(), containsString("Failed to invoke UDAF factory method"));

// This will exit via configure
final Exception e2 = assertThrows(
KsqlException.class,
() ->
((Configurable)FUNC_REG
.getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.INTEGER,
AggregateFunctionInitArguments.EMPTY_ARGS)).configure(Collections.EMPTY_MAP)
);
assertThat(e2.getMessage(), containsString("Failed to invoke UDAF factory method"));


// This will exit via initialize
final Exception e3 = assertThrows(
SecurityException.class,
() ->
FUNC_REG
.getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.DOUBLE,
AggregateFunctionInitArguments.EMPTY_ARGS).getInitialValueSupplier().get()
);
assertThat(e3.getMessage(), containsString("A UDF attempted to call System.exit"));

// This will exit via map
final Exception e4 = assertThrows(
SecurityException.class,
() ->
((KsqlAggregateFunction) FUNC_REG
.getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.BOOLEAN,
AggregateFunctionInitArguments.EMPTY_ARGS)).getResultMapper().apply(true)
);
assertThat(e4.getMessage(), containsString("A UDF attempted to call System.exit"));


// This will exit via merge
final Schema schema = SchemaBuilder.struct()
.field("A", Schema.OPTIONAL_INT32_SCHEMA)
.field("B", Schema.OPTIONAL_INT32_SCHEMA)
.optional()
.build();
final SqlStruct sqlSchema = SqlTypes.struct()
.field("A", SqlTypes.INTEGER)
.field("B", SqlTypes.INTEGER)
.build();
final Struct input = new Struct(schema).put("A", 0).put("B", 0);
final Exception e5 = assertThrows(
SecurityException.class,
() ->
((KsqlAggregateFunction) FUNC_REG.getAggregateFunction(FunctionName.of("bad_test_udaf"),
sqlSchema,
AggregateFunctionInitArguments.EMPTY_ARGS)).getMerger().apply(null, input, input)
);
assertThat(e5.getMessage(), containsString("A UDF attempted to call System.exit"));


// This will exit via aggregate
final Exception e6 = assertThrows(
SecurityException.class,
() ->
((KsqlAggregateFunction) FUNC_REG
.getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.STRING,
AggregateFunctionInitArguments.EMPTY_ARGS)).aggregate("foo", 2L)
);
assertThat(e6.getMessage(), containsString("A UDF attempted to call System.exit"));

// This will exit via undo.
final Exception e7 = assertThrows(
SecurityException.class,
() ->
((TableAggregationFunction) FUNC_REG
.getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.BIGINT,
AggregateFunctionInitArguments.EMPTY_ARGS)).undo(1L, 1L)
);
assertThat(e7.getMessage(), containsString("A UDF attempted to call System.exit"));

System.setSecurityManager(manager);
}


@Test
public void shouldLoadDecimalUdfs() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
import io.confluent.ksql.schema.ksql.SqlTypeParser;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.security.ExtensionSecurityManager;
import io.confluent.ksql.util.KsqlException;
import java.io.File;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -227,6 +229,61 @@ public void shouldLoadVarArgsMethod() {
assertThat(function.getReturnType(args), equalTo(STRUCT_SCHEMA));
}

@Test
public void shouldNotBadUdtfsExit() {
// We do need to set up the ExtensionSecurityManager for our test.
// This is controlled by a feature flag and in this test, we just directly enable it.
SecurityManager manager = System.getSecurityManager();
System.setSecurityManager(ExtensionSecurityManager.INSTANCE);

final Exception e1 = assertThrows(
KsqlException.class,
() ->
FUNC_REG.getTableFunction(
FunctionName.of("bad_test_udtf"),
Collections.singletonList(SqlArgument.of(SqlTypes.decimal(2,0))))
.getReturnType(ImmutableList.of(SqlArgument.of(SqlTypes.DOUBLE)))
);
assertThat(e1.getMessage(), containsString(
"Cannot invoke the schema provider method provideSchema for UDF bad_test_udtf."));

final Exception e2 = assertThrows(
KsqlFunctionException.class,
() ->
FUNC_REG.getTableFunction(
FunctionName.of("bad_test_udtf"),
Collections.singletonList(SqlArgument.of(SqlTypes.STRING))).apply("foo")
);
assertThat(e2.getMessage(), containsString(
"Failed to invoke function public java.util.List "
+ "io.confluent.ksql.function.udf.BadTestUdtf.listStringReturn(java.lang.String)"));

// Stop reflection
final Exception e3 = assertThrows(
KsqlFunctionException.class,
() ->
FUNC_REG.getTableFunction(
FunctionName.of("bad_test_udtf"),
Collections.singletonList(SqlArgument.of(SqlTypes.BOOLEAN))).apply(true)
);
assertThat(e3.getMessage(), containsString(
"Failed to invoke function public java.util.List "
+ "io.confluent.ksql.function.udf.BadTestUdtf.listBooleanReturn(boolean)"));

final Exception e4 = assertThrows(
KsqlFunctionException.class,
() ->
FUNC_REG.getTableFunction(
FunctionName.of("bad_test_udtf"),
Collections.singletonList(SqlArgument.of(SqlTypes.DOUBLE))).apply(1.234)
);
assertThat(e4.getMessage(), containsString(
"Failed to invoke function public java.util.List "
+ "io.confluent.ksql.function.udf.BadTestUdtf.listDoubleReturn(double)"));

System.setSecurityManager(manager);
}

@Test
public void shouldNotLoadUdtfWithWrongReturnValue() {
// Given:
Expand Down
Loading

0 comments on commit 4437a44

Please sign in to comment.