From a37688cdd33428c1f5c4e67bbc776835f72b0403 Mon Sep 17 00:00:00 2001 From: James Hughes Date: Thu, 24 Mar 2022 14:22:59 -0400 Subject: [PATCH] fix: Apply the ExtensionSecurityManager to UDAFs (#8776) * fix: Apply the ExtensionSecurityManager to UDAFs Addresses https://github.com/confluentinc/ksql/issues/8662 --- .../ksql/function/BaseAggregateFunction.java | 16 +- .../ksql/function/FunctionLoaderUtils.java | 4 + .../ksql/function/UdafAggregateFunction.java | 4 + .../ksql/function/UdafFactoryInvoker.java | 5 +- .../function/UdafTableAggregateFunction.java | 8 +- .../io/confluent/ksql/function/UdfLoader.java | 10 +- .../security/ExtensionSecurityManager.java | 27 +- .../ksql/function/UdfLoaderTest.java | 296 ++++++++++++++++- .../ksql/function/UdtfLoaderTest.java | 59 +++- .../ksql/function/udaf/BadTestUdaf.java | 310 ++++++++++++++++++ .../ksql/function/udf/BadTestUdf.java | 44 +++ .../ksql/function/udf/BadTestUdtf.java | 147 +++++++++ 12 files changed, 896 insertions(+), 34 deletions(-) create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/BadTestUdaf.java create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdf.java create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdtf.java diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/BaseAggregateFunction.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/BaseAggregateFunction.java index 1e0fc010b521..0fdcec704d84 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/BaseAggregateFunction.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/BaseAggregateFunction.java @@ -21,6 +21,7 @@ import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.ksql.SchemaConverters; import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.security.ExtensionSecurityManager; import io.confluent.ksql.util.KsqlException; import java.util.List; import java.util.Objects; @@ -55,12 +56,17 @@ public BaseAggregateFunction( ) { this.argIndexInValue = argIndexInValue; this.initialValueSupplier = () -> { - final A val = initialValueSupplier.get(); - if (val instanceof Struct && !((Struct) val).schema().isOptional()) { - throw new KsqlException("Initialize function for " + functionName - + " must return struct with optional schema"); + ExtensionSecurityManager.INSTANCE.pushInUdf(); + try { + final A val = initialValueSupplier.get(); + if (val instanceof Struct && !((Struct) val).schema().isOptional()) { + throw new KsqlException("Initialize function for " + functionName + + " must return struct with optional schema"); + } + return val; + } finally { + ExtensionSecurityManager.INSTANCE.popOutUdf(); } - return val; }; this.aggregateSchema = Objects.requireNonNull(aggregateType, "aggregateType"); this.outputSchema = Objects.requireNonNull(outputType, "outputType"); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java index 8cac1a9eb61d..a5bce7ff7108 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java @@ -29,6 +29,7 @@ import io.confluent.ksql.schema.ksql.SqlTypeParser; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.security.ExtensionSecurityManager; import io.confluent.ksql.util.KsqlException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; @@ -262,6 +263,7 @@ private static SqlType invokeSchemaProviderMethod( final String functionName ) { try { + ExtensionSecurityManager.INSTANCE.pushInUdf(); return (SqlType) m.invoke(instance, args); } catch (IllegalAccessException | InvocationTargetException e) { @@ -269,6 +271,8 @@ private static SqlType invokeSchemaProviderMethod( + "method %s for UDF %s. ", m.getName(), functionName ), e); + } finally { + ExtensionSecurityManager.INSTANCE.popOutUdf(); } } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunction.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunction.java index 9135e6b4bc69..8415c5963475 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunction.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafAggregateFunction.java @@ -18,6 +18,7 @@ import io.confluent.ksql.GenericKey; import io.confluent.ksql.function.udaf.Udaf; import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.security.ExtensionSecurityManager; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -106,9 +107,12 @@ private static Optional getSensor( private static T timed(final Optional maybeSensor, final Supplier task) { final long start = Time.SYSTEM.nanoseconds(); try { + // Since the timed() function wraps the calls to Udafs, we use it to protect the calls. + ExtensionSecurityManager.INSTANCE.pushInUdf(); return task.get(); } finally { maybeSensor.ifPresent(sensor -> sensor.record(Time.SYSTEM.nanoseconds() - start)); + ExtensionSecurityManager.INSTANCE.popOutUdf(); } } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafFactoryInvoker.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafFactoryInvoker.java index 8eaf7bc290c0..aff82d44c3f8 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafFactoryInvoker.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafFactoryInvoker.java @@ -24,6 +24,7 @@ import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.SqlTypeParser; import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.security.ExtensionSecurityManager; import io.confluent.ksql.util.KsqlException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; @@ -89,8 +90,8 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr final List argTypeList) { final Object[] factoryArgs = initArgs.args().toArray(); try { + ExtensionSecurityManager.INSTANCE.pushInUdf(); final Udaf udaf = (Udaf)method.invoke(null, factoryArgs); - udaf.initializeTypeArguments(argTypeList); if (udaf instanceof Configurable) { ((Configurable) udaf).configure(initArgs.config()); @@ -134,6 +135,8 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr } catch (final Exception e) { LOG.error("Failed to invoke UDAF factory method", e); throw new KsqlException("Failed to invoke UDAF factory method", e); + } finally { + ExtensionSecurityManager.INSTANCE.popOutUdf(); } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafTableAggregateFunction.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafTableAggregateFunction.java index 2e1d49dc6456..cfc81dabed26 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafTableAggregateFunction.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdafTableAggregateFunction.java @@ -19,6 +19,7 @@ import io.confluent.ksql.function.udaf.TableUdaf; import io.confluent.ksql.function.udaf.Udaf; import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.security.ExtensionSecurityManager; import java.util.List; import java.util.Optional; import org.apache.kafka.common.metrics.Metrics; @@ -42,6 +43,11 @@ public UdafTableAggregateFunction( @Override public A undo(final I valueToUndo, final A aggregateValue) { - return ((TableUdaf)udaf).undo(valueToUndo, aggregateValue); + ExtensionSecurityManager.INSTANCE.pushInUdf(); + try { + return ((TableUdaf)udaf).undo(valueToUndo, aggregateValue); + } finally { + ExtensionSecurityManager.INSTANCE.popOutUdf(); + } } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java index 440cbcd1d131..8584623e8263 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java @@ -24,6 +24,7 @@ import io.confluent.ksql.function.udf.UdfMetadata; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.ksql.SqlTypeParser; +import io.confluent.ksql.security.ExtensionSecurityManager; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; import java.lang.reflect.Method; @@ -186,8 +187,13 @@ private Function getUdfFactory( final Object actualUdf = FunctionLoaderUtils.instantiateFunctionInstance( method.getDeclaringClass(), udfDescriptionAnnotation.name()); if (actualUdf instanceof Configurable) { - ((Configurable) actualUdf) - .configure(ksqlConfig.getKsqlFunctionsConfigProps(functionName)); + ExtensionSecurityManager.INSTANCE.pushInUdf(); + try { + ((Configurable) actualUdf) + .configure(ksqlConfig.getKsqlFunctionsConfigProps(functionName)); + } finally { + ExtensionSecurityManager.INSTANCE.popOutUdf(); + } } final PluggableUdf theUdf = new PluggableUdf(invoker, actualUdf); return metrics.map(m -> new UdfMetricProducer( diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/security/ExtensionSecurityManager.java b/ksqldb-engine/src/main/java/io/confluent/ksql/security/ExtensionSecurityManager.java index 49e87706d869..316fb7fae62a 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/security/ExtensionSecurityManager.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/security/ExtensionSecurityManager.java @@ -15,7 +15,6 @@ package io.confluent.ksql.security; -import io.confluent.ksql.function.udf.PluggableUdf; import java.security.AllPermission; import java.security.CodeSource; import java.security.Permission; @@ -60,20 +59,16 @@ public boolean implies(final ProtectionDomain domain, final Permission permissio } public synchronized void pushInUdf() { - if (validateCaller()) { - if (UDF_IS_EXECUTING.get() == null) { - UDF_IS_EXECUTING.set(new Stack<>()); - } - UDF_IS_EXECUTING.get().push(true); + if (UDF_IS_EXECUTING.get() == null) { + UDF_IS_EXECUTING.set(new Stack<>()); } + UDF_IS_EXECUTING.get().push(true); } public void popOutUdf() { - if (validateCaller()) { - final Stack stack = UDF_IS_EXECUTING.get(); - if (stack != null && !stack.isEmpty()) { - stack.pop(); - } + final Stack stack = UDF_IS_EXECUTING.get(); + if (stack != null && !stack.isEmpty()) { + stack.pop(); } } @@ -93,18 +88,8 @@ public void checkExec(final String cmd) { super.checkExec(cmd); } - private boolean inUdfExecution() { final Stack executing = UDF_IS_EXECUTING.get(); return executing != null && !executing.isEmpty(); } - - /** - * Check if the caller is a PluggableUdf. It will be the third - * item in the class array. - * @return true if caller is allowed - */ - private boolean validateCaller() { - return getClassContext()[2].equals(PluggableUdf.class); - } } diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index 2b3ffacaa3ab..8289cb704a4b 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -51,7 +51,6 @@ import io.confluent.ksql.function.udf.UdfParameter; import io.confluent.ksql.function.udf.UdfSchemaProvider; import io.confluent.ksql.metastore.TypeRegistry; -import io.confluent.ksql.metrics.MetricCollectors; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.schema.ksql.SqlArgument; import io.confluent.ksql.schema.ksql.SqlTypeParser; @@ -62,6 +61,7 @@ import io.confluent.ksql.schema.ksql.types.SqlStruct; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.security.ExtensionSecurityManager; import io.confluent.ksql.test.util.KsqlTestFolder; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; @@ -90,6 +90,7 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import org.junit.function.ThrowingRunnable; import org.junit.rules.TemporaryFolder; /** @@ -138,6 +139,62 @@ public void shouldLoadFunctionsInKsqlEngine() { assertThat(substring2.evaluate("foo", 2, 1), equalTo("o")); } + @Test + public void shouldLoadBadFunctionButNotLetItExit() { + // Given: + final List argList = Arrays.asList(SqlArgument.of(SqlTypes.STRING)); + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + final UdfFactory function = FUNC_REG.getUdfFactory(FunctionName.of("bad_test_udf")); + assertThat(function, not(nullValue())); + + KsqlScalarFunction ksqlScalarFunction = function.getFunction(argList); + + // When: + final Exception e1 = assertThrows( + KsqlException.class, + () -> ksqlScalarFunction.getReturnType(argList) + ); + + // Then: + assertThat(e1.getMessage(), containsString( + "Cannot invoke the schema provider method exit for UDF bad_test_udf.")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldLoadBadFunctionButNotLetItExit2() { + // Given: + final List argList = Arrays.asList(SqlArgument.of(SqlTypes.STRING)); + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + final UdfFactory function = FUNC_REG.getUdfFactory(FunctionName.of("bad_test_udf")); + assertThat(function, not(nullValue())); + + KsqlScalarFunction ksqlScalarFunction = function.getFunction(argList); + final Kudf badFunction = ksqlScalarFunction.newInstance(ksqlConfig); + + // Given: + final Exception e2 = assertThrows( + KsqlFunctionException.class, + () -> badFunction.evaluate("foo") + ); + + // Then: + assertThat(e2.getMessage(), containsString( + "Failed to invoke function public org.apache.kafka.connect.data.Struct " + + "io.confluent.ksql.function.udf.BadTestUdf.returnList(java.lang.String)")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + @SuppressWarnings("unchecked") @Test public void shouldLoadUdafs() { @@ -180,6 +237,243 @@ public void shouldLoadStructUdafs() { equalTo(new Struct(schema).put("A", 1).put("B", 2))); } + @Test + public void shouldNotLetBadUdafsExitWithBadCreate() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit via create + final Exception e1 = assertThrows( + KsqlException.class, + () -> { + + KsqlAggregateFunction function = ((KsqlAggregateFunction) FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.array(SqlTypes.INTEGER), + AggregateFunctionInitArguments.EMPTY_ARGS)); + function.aggregate("foo", 2L); + } + ); + + // Then: + assertThat(e1.getMessage(), containsString("Failed to invoke UDAF factory method")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdafsExitWithBadConfigure() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit via configure + final Exception e2 = assertThrows( + KsqlException.class, + () -> + ((Configurable)FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.INTEGER, + AggregateFunctionInitArguments.EMPTY_ARGS)).configure(Collections.EMPTY_MAP) + ); + + // Then: + assertThat(e2.getMessage(), containsString("Failed to invoke UDAF factory method")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdafsExitWithBadInitialize() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit via initialize + final Exception e3 = assertThrows( + SecurityException.class, + new ThrowingRunnable() { + @Override + public void run() throws Throwable { + FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.DOUBLE, + AggregateFunctionInitArguments.EMPTY_ARGS).getInitialValueSupplier().get(); + } + } + ); + + // Then: + assertThat(e3.getMessage(), containsString("A UDF attempted to call System.exit")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdafsExitWithBadMap() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit via map + final Exception e4 = assertThrows( + SecurityException.class, + () -> + ((KsqlAggregateFunction) FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.BOOLEAN, + AggregateFunctionInitArguments.EMPTY_ARGS)).getResultMapper().apply(true) + ); + + // Then: + assertThat(e4.getMessage(), containsString("A UDF attempted to call System.exit")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + + @Test + public void shouldNotLetBadUdafsExitWithBadMerge() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit via merge + final Schema schema = SchemaBuilder.struct() + .field("A", Schema.OPTIONAL_INT32_SCHEMA) + .field("B", Schema.OPTIONAL_INT32_SCHEMA) + .optional() + .build(); + final SqlStruct sqlSchema = SqlTypes.struct() + .field("A", SqlTypes.INTEGER) + .field("B", SqlTypes.INTEGER) + .build(); + final Struct input = new Struct(schema).put("A", 0).put("B", 0); + final Exception e5 = assertThrows( + SecurityException.class, + () -> + ((KsqlAggregateFunction) FUNC_REG.getAggregateFunction(FunctionName.of("bad_test_udaf"), + sqlSchema, + AggregateFunctionInitArguments.EMPTY_ARGS)).getMerger().apply(null, input, input) + ); + + // Then: + assertThat(e5.getMessage(), containsString("A UDF attempted to call System.exit")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdafsExitWithBadAggregate() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit via aggregate + final Exception e6 = assertThrows( + SecurityException.class, + () -> + ((KsqlAggregateFunction) FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.STRING, + AggregateFunctionInitArguments.EMPTY_ARGS)).aggregate("foo", 2L) + ); + + // Then: + assertThat(e6.getMessage(), containsString("A UDF attempted to call System.exit")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdatsExitWithBadUnfo() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit via undo. + final Exception error = assertThrows( + SecurityException.class, + () -> + ((TableAggregationFunction) FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.BIGINT, + AggregateFunctionInitArguments.EMPTY_ARGS)).undo(1L, 1L) + ); + + // Then: + assertThat(error.getMessage(), containsString("A UDF attempted to call System.exit")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdafsExitWithBadGetAggregateSqlType() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit due to a bad getAggregateSqlType. + final Exception error = assertThrows( + KsqlException.class, + () -> { + KsqlAggregateFunction func = ((KsqlAggregateFunction) FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), + SqlTypes.array(SqlTypes.BIGINT), + AggregateFunctionInitArguments.EMPTY_ARGS)); + } + ); + + // Then: + assertThat(error.getCause().getMessage(), containsString("A UDF attempted to call System.exit")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdafsExitWithBadGetReturnSqlType() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + // This will exit due to a bad getReturnSqlType. + final Exception error = assertThrows( + KsqlException.class, + () -> { + KsqlAggregateFunction func = ((KsqlAggregateFunction) FUNC_REG + .getAggregateFunction(FunctionName.of("bad_test_udaf"), SqlTypes.array(SqlTypes.BOOLEAN), + AggregateFunctionInitArguments.EMPTY_ARGS)); + } + ); + + // Then: + assertThat(error.getCause().getMessage(), containsString("A UDF attempted to call System.exit")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + @Test public void shouldLoadDecimalUdfs() { // Given: diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java index 3316a3eea90b..f50fcd0631ad 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java @@ -22,6 +22,7 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableList; @@ -32,10 +33,12 @@ import io.confluent.ksql.schema.ksql.SqlTypeParser; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.security.ExtensionSecurityManager; import io.confluent.ksql.util.KsqlException; import java.io.File; import java.math.BigDecimal; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -57,10 +60,10 @@ public void shouldLoadSimpleParams() { SqlArgument.of(SqlTypes.INTEGER), SqlArgument.of(SqlTypes.BIGINT), SqlArgument.of(SqlTypes.DOUBLE), - SqlArgument.of( SqlTypes.BOOLEAN), + SqlArgument.of(SqlTypes.BOOLEAN), SqlArgument.of(SqlTypes.STRING), - SqlArgument.of( DECIMAL_SCHEMA), - SqlArgument.of( STRUCT_SCHEMA) + SqlArgument.of(DECIMAL_SCHEMA), + SqlArgument.of(STRUCT_SCHEMA) ); // When: @@ -227,6 +230,56 @@ public void shouldLoadVarArgsMethod() { assertThat(function.getReturnType(args), equalTo(STRUCT_SCHEMA)); } + @Test + public void shouldNotLetBadUdtfsExitViaBadSchemaProvider() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + final Exception error = assertThrows( + KsqlException.class, + () -> + FUNC_REG.getTableFunction( + FunctionName.of("bad_test_udtf"), + Collections.singletonList(SqlArgument.of(SqlTypes.decimal(2,0)))) + .getReturnType(ImmutableList.of(SqlArgument.of(SqlTypes.DOUBLE))) + ); + + // Then: + assertThat(error.getMessage(), containsString( + "Cannot invoke the schema provider method provideSchema for UDF bad_test_udtf.")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + + @Test + public void shouldNotLetBadUdtfsExit() { + // Given: + // We do need to set up the ExtensionSecurityManager for our test. + // This is controlled by a feature flag and in this test, we just directly enable it. + SecurityManager manager = System.getSecurityManager(); + System.setSecurityManager(ExtensionSecurityManager.INSTANCE); + + // When: + final Exception error = assertThrows( + KsqlFunctionException.class, + () -> + FUNC_REG.getTableFunction( + FunctionName.of("bad_test_udtf"), + Collections.singletonList(SqlArgument.of(SqlTypes.STRING))).apply("foo") + ); + + // Then: + assertThat(error.getMessage(), containsString( + "Failed to invoke function public java.util.List " + + "io.confluent.ksql.function.udf.BadTestUdtf.listStringReturn(java.lang.String)")); + System.setSecurityManager(manager); + assertEquals(System.getSecurityManager(), manager); + } + @Test public void shouldNotLoadUdtfWithWrongReturnValue() { // Given: diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/BadTestUdaf.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/BadTestUdaf.java new file mode 100644 index 000000000000..ffcb99ce353c --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udaf/BadTestUdaf.java @@ -0,0 +1,310 @@ +/* + * Copyright 2022 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udaf; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.util.KsqlConstants; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; + +@UdafDescription( + name = "bad_test_udaf", + description = "bad_test_udaf", + author = KsqlConstants.CONFLUENT_AUTHOR +) +public final class BadTestUdaf { + + private BadTestUdaf() { + } + + @SuppressFBWarnings("DM_EXIT") + private static void runBadCode() { + System.exit(-1); + } + + @UdafFactory(description = "sums longs with bad 'undo' method") + public static TableUdaf createSumLong() { + return new TableUdaf() { + @Override + public Long undo(final Long valueToUndo, final Long aggregateValue) { + runBadCode(); + return aggregateValue - valueToUndo; + } + + @Override + public Long initialize() { + return 0L; + } + + @Override + public Long aggregate(final Long value, final Long aggregate) { + return aggregate + value; + } + + @Override + public Long merge(final Long aggOne, final Long aggTwo) { + return aggOne + aggTwo; + } + + @Override + public Long map(final Long agg) { + return agg; + } + }; + } + + @UdafFactory(description = "sums int with a bad factory call") + public static TableUdaf, Long, Long> createFactoryExiting() { + runBadCode(); + return null; + } + + @UdafFactory(description = "sums int") + public static TableUdaf createSumInt() { + return new SumIntUdaf(); + } + + @UdafFactory(description = "sums double with a bad initialize") + public static Udaf createSumDouble() { + return new Udaf() { + @Override + public Double initialize() { + runBadCode(); + return 0.0; + } + + @Override + public Double aggregate(final Double val, final Double aggregate) { + return aggregate + val; + } + + @Override + public Double merge(final Double aggOne, final Double aggTwo) { + return aggOne + aggTwo; + } + + @Override + public Double map(final Double agg) { + return agg; + } + }; + } + + @UdafFactory(description = "sums the length of strings with a bad aggregate") + public static Udaf createSumLengthString() { + return new Udaf() { + @Override + public Long initialize() { + return (long) "initial".length(); + } + + @Override + public Long aggregate(final String s, final Long aggregate) { + runBadCode(); + return aggregate + s.length(); + } + + @Override + public Long merge(final Long aggOne, final Long aggTwo) { + return aggOne + aggTwo; + } + + @Override + public Long map(final Long agg) { + return agg; + } + }; + } + + @UdafFactory( + description = "returns a struct with {SUM(in->A), SUM(in->B)} with a bad merger", + paramSchema = "STRUCT", + aggregateSchema = "STRUCT", + returnSchema = "STRUCT") + public static Udaf createStructUdaf() { + return new Udaf() { + + @Override + public Struct initialize() { + return new Struct(SchemaBuilder.struct() + .field("A", Schema.OPTIONAL_INT32_SCHEMA) + .field("B", Schema.OPTIONAL_INT32_SCHEMA) + .optional() + .build()) + .put("A", 0) + .put("B", 0); + } + + @Override + public Struct aggregate(final Struct current, final Struct aggregate) { + aggregate.put("A", current.getInt32("A") + aggregate.getInt32("A")); + aggregate.put("B", current.getInt32("B") + aggregate.getInt32("B")); + return aggregate; + } + + @Override + public Struct merge(final Struct aggOne, final Struct aggTwo) { + runBadCode(); + return aggregate(aggOne, aggTwo); + } + + @Override + public Struct map(final Struct agg) { + return agg; + } + }; + } + + // With a bad map method + static class SumIntUdaf implements TableUdaf, Configurable { + + public static final String INIT_CONFIG = "ksql.functions.test_udaf.init"; + private long init = 0L; + + @Override + public Long undo(final Integer valueToUndo, final Long aggregateValue) { + return aggregateValue - valueToUndo; + } + + @Override + public Long initialize() { + return init; + } + + @Override + public Long aggregate(final Integer current, final Long aggregate) { + return current + aggregate; + } + + @Override + public Long merge(final Long aggOne, final Long aggTwo) { + return aggOne + aggTwo; + } + + @Override + public Long map(final Long agg) { + runBadCode(); + return agg; + } + + @Override + public void configure(final Map map) { + runBadCode(); + final Object init = map.get(INIT_CONFIG); + this.init = (init == null) ? this.init : (long) init; + } + } + + @UdafFactory( + description = "bad map", + paramSchema = "BOOLEAN", + aggregateSchema = "BOOLEAN", + returnSchema = "BOOLEAN") + public static Udaf createBadMapUdaf() { + return new Udaf() { + + @Override + public Boolean initialize() { + return Boolean.FALSE; + } + + @Override + public Boolean aggregate(Boolean current, Boolean aggregate) { + return Boolean.FALSE; + } + + @Override + public Boolean merge(Boolean aggOne, Boolean aggTwo) { + return Boolean.FALSE; + } + + @Override + public Boolean map(Boolean agg) { + runBadCode(); + return Boolean.FALSE; + } + }; + } + + @UdafFactory(description = "sums the length of strings with a bad aggregate") + public static Udaf, Long, Long> createBadAggregateTypeUdaf() { + return new Udaf, Long, Long>() { + @Override + public Long initialize() { + return null; + } + + @Override + public Long aggregate(List current, Long aggregate) { + return null; + } + + @Override + public Long merge(Long aggOne, Long aggTwo) { + return null; + } + + @Override + public Long map(Long agg) { + return null; + } + + @Override + public Optional getAggregateSqlType() { + runBadCode(); + return Optional.empty(); + } + }; + } + + @UdafFactory(description = "sums the length of strings with a bad aggregate") + public static Udaf, Long, Long> createBadReturnTypeUdaf() { + return new Udaf, Long, Long>() { + @Override + public Long initialize() { + return null; + } + + @Override + public Long aggregate(List current, Long aggregate) { + return null; + } + + @Override + public Long merge(Long aggOne, Long aggTwo) { + return null; + } + + @Override + public Long map(Long agg) { + return null; + } + + @Override + public Optional getReturnSqlType() { + runBadCode(); + return Optional.empty(); + } + }; + } + +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdf.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdf.java new file mode 100644 index 000000000000..34abec19aa74 --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdf.java @@ -0,0 +1,44 @@ +/* + * Copyright 2022 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udf; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import io.confluent.ksql.schema.ksql.types.SqlStruct; +import io.confluent.ksql.schema.ksql.types.SqlType; +import io.confluent.ksql.schema.ksql.types.SqlTypes; +import java.util.List; +import org.apache.kafka.connect.data.Struct; + +@UdfDescription(name="bad_test_udf", description = "test") +@SuppressWarnings("unused") +public class BadTestUdf { + private static final SqlStruct RETURN = + SqlStruct.builder().field("A", SqlTypes.STRING).build(); + + @SuppressFBWarnings("DM_EXIT") + @Udf(description = "Sample Bad", schemaProvider = "exit") + public Struct returnList(String string) { + System.exit(-1); + return null; + } + + @SuppressFBWarnings("DM_EXIT") + @UdfSchemaProvider + public SqlType exit(final List params) { + System.exit(-3); + return RETURN; + } +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdtf.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdtf.java new file mode 100644 index 000000000000..0316eb85152c --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/BadTestUdtf.java @@ -0,0 +1,147 @@ +/* + * Copyright 2022 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udf; + +import com.google.common.collect.ImmutableList; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import io.confluent.ksql.function.udtf.Udtf; +import io.confluent.ksql.function.udtf.UdtfDescription; +import io.confluent.ksql.schema.ksql.types.SqlDecimal; +import io.confluent.ksql.schema.ksql.types.SqlType; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.util.List; +import java.util.Map; +import org.apache.kafka.connect.data.Struct; + +@UdtfDescription(name = "bad_test_udtf", description = "test") +@SuppressWarnings("unused") +public class BadTestUdtf { + + @SuppressFBWarnings("DM_EXIT") + private static void runBadCode() { + System.exit(-1); + } + + @Udtf + public List standardParams( + final int i, final long l, final double d, final boolean b, final String s, + final BigDecimal bd, @UdfParameter(schema = "STRUCT") final Struct struct + ) { + return ImmutableList.of(String.valueOf(i), String.valueOf(l), String.valueOf(d), + String.valueOf(b), s, bd.toString(), struct.toString() + ); + } + + @Udtf + public List parameterizedListParams( + final List i, final List l, final List d, final List b, final List s, + final List bd, @UdfParameter(schema = "ARRAY>") final List struct + ) { + return ImmutableList + .of(String.valueOf(i.get(0)), String.valueOf(l.get(0)), String.valueOf(d.get(0)), + String.valueOf(b.get(0)), s.get(0), bd.get(0).toString(), struct.get(0).toString() + ); + } + + @Udtf + public List parameterizedMapParams( + final Map i, + final Map l, + final Map d, + final Map b, + final Map s, + final Map bd, + @UdfParameter(schema = "MAP>") final Map struct + ) { + return ImmutableList + .of( + String.valueOf(i.values().iterator().next()), + String.valueOf(l.values().iterator().next()), + String.valueOf(d.values().iterator().next()), + String.valueOf(b.values().iterator().next()), + s.values().iterator().next(), + bd.values().iterator().next().toString(), + struct.values().iterator().next().toString() + ); + } + + @Udtf + public List parameterizedMapParams2( + final Map i, + final Map l, + final Map d, + final Map b, + final Map s, + final Map bd, + @UdfParameter(schema = "MAP>") final Map struct + ) { + return ImmutableList + .of( + String.valueOf(i.values().iterator().next()), + String.valueOf(l.values().iterator().next()), + String.valueOf(d.values().iterator().next()), + String.valueOf(b.values().iterator().next()), + s.values().iterator().next(), + bd.values().iterator().next().toString(), + struct.values().iterator().next().toString() + ); + } + + @Udtf + public List listIntegerReturn(final int i) { + return ImmutableList.of(i); + } + + @Udtf + public List listLongReturn(final long l) { + return ImmutableList.of(l); + } + + @Udtf + public List listBooleanReturn(final boolean b) + throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException { + Class shutdown = Class.forName("java.lang.Shutdown"); + Method method = shutdown.getDeclaredMethod("exit", int.class); + method.setAccessible(true); + method.invoke(shutdown, -10); + return ImmutableList.of(b); + } + + @Udtf + public List listStringReturn(final String s) { + runBadCode(); + return ImmutableList.of(s); + } + + @Udtf(schemaProvider = "provideSchema") + public List listBigDecimalReturnWithSchemaProvider(final BigDecimal bd) { + return ImmutableList.of(bd); + } + + @Udtf(schema = "STRUCT") + public List listStructReturn(@UdfParameter(schema = "STRUCT") final Struct struct) { + return ImmutableList.of(struct); + } + + @UdfSchemaProvider + public SqlType provideSchema(final List params) { + runBadCode(); + return SqlDecimal.of(30, 10); + } + +}