Skip to content

Commit

Permalink
fix: allow implicit cast of number literals to decimals on insert/se…
Browse files Browse the repository at this point in the history
…lect (#6005)
  • Loading branch information
spena authored Aug 18, 2020
1 parent 6c08e5d commit 2bc15dd
Show file tree
Hide file tree
Showing 16 changed files with 704 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.schema.ksql.types.SqlDecimal;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.utils.FormatOptions;
import io.confluent.ksql.util.DecimalUtil;

import java.util.Objects;

/**
Expand Down Expand Up @@ -111,6 +114,20 @@ public int index() {
return index;
}

/**
 * Checks whether this column's type can be implicitly cast to {@code toType}.
 * <p>
 * Decimal-to-decimal casts delegate to {@link DecimalUtil#canImplicitlyCast};
 * every other combination is only castable when the types are exactly equal.
 *
 * @param toType the target SQL type
 * @return true if an implicit cast to {@code toType} is allowed
 */
public boolean canImplicitlyCast(final SqlType toType) {
  final boolean bothDecimals = type instanceof SqlDecimal && toType instanceof SqlDecimal;
  if (bothDecimals) {
    return DecimalUtil.canImplicitlyCast((SqlDecimal) type, (SqlDecimal) toType);
  }

  return type.equals(toType);
}

/**
 * Compares this column with {@code that} on index, namespace and name only,
 * deliberately ignoring the column's SQL type.
 *
 * @param that the column to compare against
 * @return true if both columns match on everything except type
 */
public boolean equalsIgnoreType(final Column that) {
  if (!Objects.equals(index, that.index)) {
    return false;
  }
  if (!Objects.equals(namespace, that.namespace)) {
    return false;
  }
  return Objects.equals(name, that.name);
}

@Override
public boolean equals(final Object o) {
if (this == o) {
Expand All @@ -120,10 +137,7 @@ public boolean equals(final Object o) {
return false;
}
final Column that = (Column) o;
return Objects.equals(index, that.index)
&& Objects.equals(namespace, that.namespace)
&& Objects.equals(type, that.type)
&& Objects.equals(name, that.name);
return equalsIgnoreType(that) && Objects.equals(type, that.type);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,27 @@ public boolean isKeyColumn(final ColumnName columnName) {
.isPresent();
}

/**
 * Checks whether this schema is compatible with the {@code other} schema.
 * <p>
 * Two schemas are compatible when they have the same number of columns and each
 * positional pair of columns matches on everything except type, with this
 * schema's column type being implicitly castable to the other's type
 * (e.g. a decimal that fits into a wider decimal).
 *
 * @param other the schema to compare against
 * @return true if every column of this schema is compatible with {@code other}
 */
public boolean compatibleSchema(final LogicalSchema other) {
  final int columnCount = columns().size();
  if (columnCount != other.columns().size()) {
    return false;
  }

  for (int idx = 0; idx < columnCount; idx++) {
    final Column ownColumn = columns().get(idx);
    final Column otherColumn = other.columns().get(idx);

    // Columns must match positionally on name/namespace/index, and the type
    // must be equal or implicitly castable to the other column's type.
    if (!ownColumn.equalsIgnoreType(otherColumn)
        || !ownColumn.canImplicitlyCast(otherColumn.type())) {
      return false;
    }
  }

  return true;
}

@Override
public boolean equals(final Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,21 @@ public static SqlDecimal toSqlDecimal(final SqlType schema) {
}
}

/**
 * Returns {@code true} if {@code s1} can be implicitly cast to {@code s2}.
 * <p>
 * A decimal {@code s1} can be implicitly cast if its precision/scale fits into the {@code s2}
 * precision/scale:
 * <ul>
 *   <li>{@code s1} scale &lt;= {@code s2} scale</li>
 *   <li>{@code s1} left digits (precision - scale) &lt;= {@code s2} left digits</li>
 * </ul>
 *
 * @param s1 the source decimal type
 * @param s2 the target decimal type
 * @return true if {@code s1} fits into {@code s2} without losing digits
 */
public static boolean canImplicitlyCast(final SqlDecimal s1, final SqlDecimal s2) {
  return s1.getScale() <= s2.getScale()
      && (s1.getPrecision() - s1.getScale()) <= (s2.getPrecision() - s2.getScale());
}

public static BigDecimal cast(final long value, final int precision, final int scale) {
validateParameters(precision, scale);
final BigDecimal decimal = new BigDecimal(value, new MathContext(precision));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.schema.ksql.Column.Namespace;
import io.confluent.ksql.schema.ksql.LogicalSchema.Builder;
import io.confluent.ksql.schema.ksql.types.SqlDecimal;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.schema.utils.FormatOptions;
import io.confluent.ksql.util.KsqlException;
Expand Down Expand Up @@ -746,4 +747,69 @@ private static org.apache.kafka.connect.data.Field connectField(
) {
return new org.apache.kafka.connect.data.Field(fieldName, index, schema);
}

// Renamed from the ungrammatical "shouldSchemaNoCompatibleWithDifferentSizes"
// to state the expected outcome clearly.
@Test
public void shouldNotBeCompatibleWithDifferentSizes() {
  // Given: two schemas with a different number of columns
  final LogicalSchema schema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(F1, BIGINT)
      .build();
  final LogicalSchema otherSchema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(F1, BIGINT)
      .valueColumn(V1, BIGINT)
      .build();

  // Then: a column-count mismatch makes the schemas incompatible
  assertThat(schema.compatibleSchema(otherSchema), is(false));
}

// Renamed from the ungrammatical "shouldSchemaNoCompatibleOnDifferentColumnName"
// to state the expected outcome clearly.
@Test
public void shouldNotBeCompatibleOnDifferentColumnName() {
  // Given: same column count and types, but the second column name differs
  final LogicalSchema schema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(F1, BIGINT)
      .build();
  final LogicalSchema otherSchema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(V1, BIGINT)
      .build();

  // Then: mismatched column names make the schemas incompatible
  assertThat(schema.compatibleSchema(otherSchema), is(false));
}

// Renamed from the ungrammatical "shouldSchemaNoCompatibleWhenCannotCastType"
// to state the expected outcome clearly.
@Test
public void shouldNotBeCompatibleWhenCannotCastType() {
  // Given: same column names, but BIGINT cannot be implicitly cast to INTEGER
  final LogicalSchema schema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(F1, BIGINT)
      .build();
  final LogicalSchema otherSchema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(F1, INTEGER)
      .build();

  // Then: non-castable column types make the schemas incompatible
  assertThat(schema.compatibleSchema(otherSchema), is(false));
}

// Renamed from the awkward "shouldSchemaCompatibleWithImplicitlyCastType"
// to state the expected outcome clearly.
@Test
public void shouldBeCompatibleWithImplicitlyCastType() {
  // Given: decimal(5,2) fits into decimal(6,3) (same or more scale and left digits)
  final LogicalSchema schema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(F1, SqlDecimal.of(5, 2))
      .build();
  final LogicalSchema otherSchema = LogicalSchema.builder()
      .valueColumn(F0, STRING)
      .valueColumn(F1, SqlDecimal.of(6, 3))
      .build();

  // Then: the implicitly castable decimal keeps the schemas compatible
  assertThat(schema.compatibleSchema(otherSchema), is(true));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -510,4 +510,56 @@ public void shouldNotCastStringNonNumber() {
() -> cast("abc", 2, 1)
);
}

@Test
public void shouldAllowImplicitlyCastOnEqualSchema() {
  // Given: source and target decimals with identical precision and scale
  final SqlDecimal from = SqlTypes.decimal(5, 2);
  final SqlDecimal to = SqlTypes.decimal(5, 2);

  // When:
  final boolean result = DecimalUtil.canImplicitlyCast(from, to);

  // Then: identical decimal types are always castable
  assertThat(result, is(true));
}

@Test
public void shouldAllowImplicitlyCastOnHigherPrecisionAndScale() {
  // Given: the target decimal is wider in both precision and scale
  final SqlDecimal from = SqlTypes.decimal(5, 2);
  final SqlDecimal to = SqlTypes.decimal(6, 3);

  // When:
  final boolean result = DecimalUtil.canImplicitlyCast(from, to);

  // Then: widening both precision and scale is allowed
  assertThat(result, is(true));
}

// Renamed from "shouldAllowImplicitlyCastOnHigherScale": the test asserts the
// cast is NOT allowed, so the old name contradicted the assertion.
@Test
public void shouldNotAllowImplicitlyCastOnHigherScale() {
  // Given: scale grows (1 -> 2) but precision does not, so left digits shrink (1 -> 0)
  final SqlDecimal s1 = SqlTypes.decimal(2, 1);
  final SqlDecimal s2 = SqlTypes.decimal(2, 2);

  // When:
  final boolean compatible = DecimalUtil.canImplicitlyCast(s1, s2);

  // Then: losing integer-part digits disallows the implicit cast
  assertThat(compatible, is(false));
}

// Renamed from "shouldAllowImplicitlyCastOnLowerPrecision": the test asserts
// the cast is NOT allowed, so the old name contradicted the assertion.
@Test
public void shouldNotAllowImplicitlyCastOnLowerPrecision() {
  // Given: the target has lower precision (2 -> 1), so left digits shrink (1 -> 0)
  final SqlDecimal s1 = SqlTypes.decimal(2, 1);
  final SqlDecimal s2 = SqlTypes.decimal(1, 1);

  // When:
  final boolean compatible = DecimalUtil.canImplicitlyCast(s1, s2);

  // Then: narrowing the precision disallows the implicit cast
  assertThat(compatible, is(false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ private void validateExistingSink(
final LogicalSchema resultSchema = outputNode.getSchema();
final LogicalSchema existingSchema = existing.getSchema();

if (!resultSchema.equals(existingSchema)) {
if (!resultSchema.compatibleSchema(existingSchema)) {
throw new KsqlException("Incompatible schema between results and sink."
+ System.lineSeparator()
+ "Result schema is " + resultSchema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
import io.confluent.ksql.execution.streams.timestamp.TimestampExtractionPolicyFactory;
import io.confluent.ksql.execution.timestamp.TimestampColumn;
import io.confluent.ksql.execution.util.ExpressionTypeManager;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.udf.AsValue;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.parser.NodeLocation;
Expand Down Expand Up @@ -93,21 +93,21 @@ public class LogicalPlanner {

private final KsqlConfig ksqlConfig;
private final RewrittenAnalysis analysis;
private final FunctionRegistry functionRegistry;
private final MetaStore metaStore;
private final AggregateAnalyzer aggregateAnalyzer;
private final ColumnReferenceRewriter refRewriter;

public LogicalPlanner(
final KsqlConfig ksqlConfig,
final ImmutableAnalysis analysis,
final FunctionRegistry functionRegistry
final MetaStore metaStore
) {
this.ksqlConfig = Objects.requireNonNull(ksqlConfig, "ksqlConfig");
this.refRewriter =
new ColumnReferenceRewriter(analysis.getFromSourceSchemas(false).isJoin());
this.analysis = new RewrittenAnalysis(analysis, refRewriter::process);
this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
this.aggregateAnalyzer = new AggregateAnalyzer(functionRegistry);
this.metaStore = Objects.requireNonNull(metaStore, "metaStore");
this.aggregateAnalyzer = new AggregateAnalyzer(metaStore);
}

// CHECKSTYLE_RULES.OFF: CyclomaticComplexity
Expand Down Expand Up @@ -227,13 +227,20 @@ private Optional<TimestampColumn> getTimestampColumn(
return timestampColumn;
}

/**
 * Resolves the schema of the pre-existing sink this query writes into, if any.
 * <p>
 * Empty when the query has no INTO target, or when the target is being created
 * by this statement (and so has no existing schema to look up).
 */
private Optional<LogicalSchema> getTargetSchema() {
  return analysis.getInto()
      .filter(sink -> !sink.isCreate())
      .map(sink -> metaStore.getSource(sink.getName()))
      .map(source -> source.getSchema());
}

private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) {
final GroupBy groupBy = analysis.getGroupBy()
.orElseThrow(IllegalStateException::new);

final List<SelectExpression> projectionExpressions = SelectionUtil.buildSelectExpressions(
sourcePlanNode,
analysis.getSelectItems()
analysis.getSelectItems(),
getTargetSchema()
);

final LogicalSchema schema =
Expand All @@ -247,7 +254,7 @@ private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) {
if (analysis.getHavingExpression().isPresent()) {
final FilterTypeValidator validator = new FilterTypeValidator(
sourcePlanNode.getSchema(),
functionRegistry,
metaStore,
FilterType.HAVING);

validator.validateFilterExpression(analysis.getHavingExpression().get());
Expand All @@ -258,7 +265,7 @@ private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) {
sourcePlanNode,
schema,
groupBy,
functionRegistry,
metaStore,
analysis,
aggregateAnalysis,
projectionExpressions,
Expand All @@ -271,8 +278,8 @@ private ProjectNode buildUserProjectNode(final PlanNode parentNode) {
new PlanNodeId("Project"),
parentNode,
analysis.getSelectItems(),
analysis.getInto().isPresent(),
functionRegistry
analysis.getInto(),
metaStore
);
}

Expand All @@ -294,7 +301,7 @@ private FilterNode buildFilterNode(
) {
final FilterTypeValidator validator = new FilterTypeValidator(
sourcePlanNode.getSchema(),
functionRegistry,
metaStore,
FilterType.WHERE);

validator.validateFilterExpression(filterExpression);
Expand Down Expand Up @@ -342,7 +349,7 @@ private RepartitionNode buildInternalRepartitionNode(
}

private FlatMapNode buildFlatMapNode(final PlanNode sourcePlanNode) {
return new FlatMapNode(new PlanNodeId("FlatMap"), sourcePlanNode, functionRegistry, analysis);
return new FlatMapNode(new PlanNodeId("FlatMap"), sourcePlanNode, metaStore, analysis);
}

private PlanNode buildSourceForJoin(
Expand Down Expand Up @@ -511,7 +518,7 @@ private LogicalSchema buildAggregateSchema(
sourceSchema
.withPseudoAndKeyColsInValue(analysis.getWindowExpression().isPresent()),
projectionExpressions,
functionRegistry
metaStore
);

final List<Expression> groupByExps = groupBy.getGroupingExpressions();
Expand Down Expand Up @@ -555,7 +562,7 @@ private LogicalSchema buildAggregateSchema(
);
} else {
final ExpressionTypeManager typeManager =
new ExpressionTypeManager(sourceSchema, functionRegistry);
new ExpressionTypeManager(sourceSchema, metaStore);

final Expression expression = groupByExps.get(0);

Expand Down Expand Up @@ -607,7 +614,7 @@ private LogicalSchema buildRepartitionedSchema(
return PartitionByParamsFactory.buildSchema(
sourceSchema,
partitionBy,
functionRegistry
metaStore
);
}

Expand Down
Loading

0 comments on commit 2bc15dd

Please sign in to comment.