diff --git a/wren-base/src/main/java/io/wren/base/dto/JoinType.java b/wren-base/src/main/java/io/wren/base/dto/JoinType.java index cbc20d0a7..0a2b813dd 100644 --- a/wren-base/src/main/java/io/wren/base/dto/JoinType.java +++ b/wren-base/src/main/java/io/wren/base/dto/JoinType.java @@ -16,7 +16,6 @@ import static io.wren.base.dto.JoinType.GenericJoinType.TO_MANY; import static io.wren.base.dto.JoinType.GenericJoinType.TO_ONE; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; public enum JoinType @@ -41,15 +40,12 @@ public enum GenericJoinType public static JoinType reverse(JoinType joinType) { - switch (joinType) { - case ONE_TO_ONE: - return ONE_TO_ONE; - case ONE_TO_MANY: - return MANY_TO_ONE; - case MANY_TO_ONE: - return ONE_TO_MANY; - } - throw new IllegalArgumentException(format("Invalid join type %s", joinType)); + return switch (joinType) { + case ONE_TO_ONE -> ONE_TO_ONE; + case ONE_TO_MANY -> MANY_TO_ONE; + case MANY_TO_ONE -> ONE_TO_MANY; + case MANY_TO_MANY -> MANY_TO_MANY; + }; } public GenericJoinType getType() diff --git a/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java b/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java index 1ccdc21cc..d10cb9a6e 100644 --- a/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java +++ b/wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java @@ -192,6 +192,61 @@ LEFT JOIN ( """); } + @Test + public void testSetManyToMany() + { + Manifest manifest = Manifest.builder() + .setCatalog("wrenai") + .setSchema("tpch") + .setModels(List.of( + model("Customer", "SELECT * FROM tpch.customer", + List.of(column("custkey", "integer", null, false, "c_custkey"), + column("name", "varchar", null, false, "c_name"))), + model("Orders", "SELECT * FROM tpch.orders", + List.of(column("orderkey", "integer", null, false, "o_orderkey"), + column("custkey", "integer", null, false, "o_custkey"), + column("customer", "Customer", "CustomerOrders", false), + caluclatedColumn("customer_name", "varchar", "customer.name")), + "orderkey"))) + .setRelationships(List.of(relationship("CustomerOrders", List.of("Customer", "Orders"), JoinType.MANY_TO_MANY, "Customer.custkey = Orders.custkey"))) + .build(); + + String manifestStr = base64Encode(toJson(manifest)); + DryPlanDtoV2 dryPlanDto = new DryPlanDtoV2(manifestStr, "select orderkey from Orders limit 200"); + String dryPlan = dryPlanV2(dryPlanDto); + assertThat(dryPlan).isEqualTo(""" + WITH + "Orders" AS ( + SELECT + "Orders"."orderkey" "orderkey" + , "Orders"."custkey" "custkey" + FROM + ( + SELECT + "Orders"."orderkey" "orderkey" + , "Orders"."custkey" "custkey" + FROM + ( + SELECT + o_orderkey "orderkey" + , o_custkey "custkey" + FROM + ( + SELECT * + FROM + tpch.orders + ) "Orders" + ) "Orders" + ) "Orders" + )\s + SELECT orderkey + FROM + Orders + LIMIT 200 + """); + + } + private String toJson(Manifest manifest) { return MANIFEST_JSON_CODEC.toJson(manifest);