Skip to content

Commit

Permalink
Adjust manifest string to base64 (#581)
Browse files Browse the repository at this point in the history
* Adjust manifest string to base64

* Pull up initDuckDB

* Add test for MDLResourceV2
  • Loading branch information
grieve54706 authored May 29, 2024
1 parent 1452089 commit dfd6d7a
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 39 deletions.
2 changes: 1 addition & 1 deletion ibis-server/app/model/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class IbisDTO(BaseModel):
sql: str
manifest_str: str = Field(alias="manifestStr", description="JSON string of manifest")
manifest_str: str = Field(alias="manifestStr", description="Base64 manifest")


class PostgresDTO(IbisDTO):
Expand Down
4 changes: 3 additions & 1 deletion wren-main/src/main/java/io/wren/main/web/MDLResourceV2.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
import jakarta.ws.rs.container.Suspended;

import java.io.IOException;
import java.util.Base64;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

import static io.wren.main.web.WrenExceptionMapper.bindAsyncResponse;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

@Path("/v2/mdl")
Expand Down Expand Up @@ -58,7 +60,7 @@ public void dryPlan(
.orElseThrow(() -> new IllegalArgumentException("Manifest is required")))
.thenApply(manifestStr -> {
try {
return WrenMDL.fromJson(manifestStr);
return WrenMDL.fromJson(new String(Base64.getDecoder().decode(manifestStr), UTF_8));
}
catch (IOException e) {
throw new RuntimeException(e);
Expand Down
36 changes: 36 additions & 0 deletions wren-tests/src/test/java/io/wren/testing/RequireWrenServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package io.wren.testing;

import com.google.common.io.Closer;
import com.google.common.io.Resources;
import com.google.inject.Key;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.HttpClientConfig;
Expand All @@ -30,11 +31,13 @@
import io.wren.base.dto.Manifest;
import io.wren.base.sqlrewrite.analyzer.decisionpoint.QueryAnalysis;
import io.wren.cache.TaskInfo;
import io.wren.main.connector.duckdb.DuckDBMetadata;
import io.wren.main.validation.ValidationResult;
import io.wren.main.web.dto.CheckOutputDto;
import io.wren.main.web.dto.ColumnLineageInputDto;
import io.wren.main.web.dto.DeployInputDto;
import io.wren.main.web.dto.DryPlanDto;
import io.wren.main.web.dto.DryPlanDtoV2;
import io.wren.main.web.dto.ErrorMessageDto;
import io.wren.main.web.dto.LineageResult;
import io.wren.main.web.dto.PreviewDto;
Expand Down Expand Up @@ -68,6 +71,7 @@
import static io.airlift.json.JsonCodec.listJsonCodec;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.SECONDS;

public abstract class RequireWrenServer
Expand All @@ -92,6 +96,7 @@ public abstract class RequireWrenServer
private static final JsonCodec<QueryResultDto> QUERY_RESULT_DTO_CODEC = jsonCodec(QueryResultDto.class);
private static final JsonCodec<List<Column>> COLUMN_LIST_CODEC = listJsonCodec(Column.class);
private static final JsonCodec<DryPlanDto> DRY_PLAN_DTO_CODEC = jsonCodec(DryPlanDto.class);
private static final JsonCodec<DryPlanDtoV2> DRY_PLAN_DTO_V2_CODEC = jsonCodec(DryPlanDtoV2.class);
private static final JsonCodec<List<ValidationResult>> VALIDATION_RESULT_LIST_CODEC = listJsonCodec(ValidationResult.class);
private static final JsonCodec<ValidateDto> VALIDATE_DTO_CODEC = jsonCodec(ValidateDto.class);
private static final JsonCodec<List<QueryAnalysisDto>> QUERY_ANALYSIS_DTO_LIST_CODEC = listJsonCodec(QueryAnalysisDto.class);
Expand All @@ -116,6 +121,22 @@ protected static JettyHttpClient createHttpClient()
protected abstract TestingWrenServer createWrenServer()
throws Exception;

protected void initDuckDB()
{
ClassLoader classLoader = getClass().getClassLoader();
String initSQL;
try {
initSQL = Resources.toString(requireNonNull(classLoader.getResource("duckdb/init.sql")).toURI().toURL(), UTF_8);
}
catch (Exception e) {
throw new RuntimeException(e);
}
initSQL = initSQL.replaceAll("basePath", requireNonNull(classLoader.getResource("tpch/data")).getPath());
DuckDBMetadata metadata = wrenServer.getInstance(Key.get(DuckDBMetadata.class));
metadata.setInitSQL(initSQL);
metadata.reload();
}

protected TestingWrenServer server()
{
return wrenServer;
Expand Down Expand Up @@ -215,6 +236,21 @@ protected String dryPlan(DryPlanDto dryPlanDto)
return response.getBody();
}

protected String dryPlanV2(DryPlanDtoV2 dryPlanDto)
{
Request request = prepareGet()
.setUri(server().getHttpServerBasedUrl().resolve("/v2/mdl/dry-plan"))
.setHeader(CONTENT_TYPE, "application/json")
.setBodyGenerator(jsonBodyGenerator(DRY_PLAN_DTO_V2_CODEC, dryPlanDto))
.build();

StringResponseHandler.StringResponse response = executeHttpRequest(request, createStringResponseHandler());
if (response.getStatusCode() != 200) {
getWebApplicationException(response);
}
return response.getBody();
}

protected void deployMDL(DeployInputDto dto)
{
Request request = preparePost()
Expand Down
19 changes: 4 additions & 15 deletions wren-tests/src/test/java/io/wren/testing/TestMDLResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
package io.wren.testing;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.Resources;
import com.google.inject.Key;
import io.wren.base.dto.Column;
import io.wren.base.dto.JoinType;
import io.wren.base.dto.Manifest;
import io.wren.base.type.IntegerType;
import io.wren.main.connector.duckdb.DuckDBMetadata;
import io.wren.main.validation.ColumnIsValid;
import io.wren.main.validation.ValidationResult;
import io.wren.main.web.dto.CheckOutputDto;
Expand Down Expand Up @@ -49,7 +46,6 @@
import static io.wren.base.dto.Relationship.relationship;
import static io.wren.main.validation.ColumnIsValid.COLUMN_IS_VALID;
import static io.wren.testing.WebApplicationExceptionAssert.assertWebApplicationException;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
Expand Down Expand Up @@ -93,22 +89,15 @@ protected TestingWrenServer createWrenServer()
.put(WREN_DATASOURCE_TYPE, DUCKDB.name())
.put(WREN_ENABLE_DYNAMIC_FIELDS, "true");

TestingWrenServer testing = TestingWrenServer.builder()
return TestingWrenServer.builder()
.setRequiredConfigs(properties.build())
.build();
initDuckDB(testing);
return testing;
}

protected void initDuckDB(TestingWrenServer wrenServer)
throws Exception
@Override
protected void prepare()
{
ClassLoader classLoader = getClass().getClassLoader();
String initSQL = Resources.toString(requireNonNull(classLoader.getResource("duckdb/init.sql")).toURI().toURL(), UTF_8);
initSQL = initSQL.replaceAll("basePath", requireNonNull(classLoader.getResource("tpch/data")).getPath());
DuckDBMetadata metadata = wrenServer.getInstance(Key.get(DuckDBMetadata.class));
metadata.setInitSQL(initSQL);
metadata.reload();
initDuckDB();
}

@Test
Expand Down
204 changes: 204 additions & 0 deletions wren-tests/src/test/java/io/wren/testing/TestMDLResourceV2.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.wren.testing;

import com.google.common.collect.ImmutableMap;
import io.airlift.json.JsonCodec;
import io.wren.base.dto.JoinType;
import io.wren.base.dto.Manifest;
import io.wren.main.web.dto.DryPlanDtoV2;
import org.testng.annotations.Test;

import java.nio.file.Files;
import java.util.Base64;
import java.util.List;

import static io.airlift.json.JsonCodec.jsonCodec;
import static io.wren.base.config.WrenConfig.DataSourceType.DUCKDB;
import static io.wren.base.config.WrenConfig.WREN_DATASOURCE_TYPE;
import static io.wren.base.config.WrenConfig.WREN_DIRECTORY;
import static io.wren.base.config.WrenConfig.WREN_ENABLE_DYNAMIC_FIELDS;
import static io.wren.base.dto.Column.caluclatedColumn;
import static io.wren.base.dto.Column.column;
import static io.wren.base.dto.Model.model;
import static io.wren.base.dto.Relationship.relationship;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;

public class TestMDLResourceV2
extends RequireWrenServer
{
private static final JsonCodec<Manifest> MANIFEST_JSON_CODEC = jsonCodec(Manifest.class);

@Override
protected TestingWrenServer createWrenServer()
throws Exception
{
ImmutableMap.Builder<String, String> properties = ImmutableMap.<String, String>builder()
.put(WREN_DIRECTORY, Files.createTempDirectory("mdl").toAbsolutePath().toString())
.put(WREN_DATASOURCE_TYPE, DUCKDB.name())
.put(WREN_ENABLE_DYNAMIC_FIELDS, "true");
TestingWrenServer testing = TestingWrenServer.builder()
.setRequiredConfigs(properties.build())
.build();
return testing;
}

@Override
protected void prepare()
{
initDuckDB();
}

@Test
public void testDryPlan()
{
Manifest manifest = Manifest.builder()
.setCatalog("wrenai")
.setSchema("tpch")
.setModels(List.of(
model("Customer", "SELECT * FROM tpch.customer",
List.of(column("custkey", "integer", null, false, "c_custkey"),
column("name", "varchar", null, false, "c_name"))),
model("Orders", "SELECT * FROM tpch.orders",
List.of(column("orderkey", "integer", null, false, "o_orderkey"),
column("custkey", "integer", null, false, "o_custkey"),
column("customer", "Customer", "CustomerOrders", false),
caluclatedColumn("customer_name", "varchar", "customer.name")),
"orderkey")))
.setRelationships(List.of(relationship("CustomerOrders", List.of("Customer", "Orders"), JoinType.ONE_TO_MANY, "Customer.custkey = Orders.custkey")))
.build();

String manifestStr = base64Encode(toJson(manifest));

DryPlanDtoV2 dryPlanDto = new DryPlanDtoV2(manifestStr, "select orderkey from Orders limit 200");
String dryPlan = dryPlanV2(dryPlanDto);
assertThat(dryPlan).isEqualTo("""
WITH
"Orders" AS (
SELECT
"Orders"."orderkey" "orderkey"
, "Orders"."custkey" "custkey"
FROM
(
SELECT
"Orders"."orderkey" "orderkey"
, "Orders"."custkey" "custkey"
FROM
(
SELECT
o_orderkey "orderkey"
, o_custkey "custkey"
FROM
(
SELECT *
FROM
tpch.orders
) "Orders"
) "Orders"
) "Orders"
)\s
SELECT orderkey
FROM
Orders
LIMIT 200
""");

dryPlanDto = new DryPlanDtoV2(manifestStr, "select customer_name from Orders limit 200");
dryPlan = dryPlanV2(dryPlanDto);
assertThat(dryPlan).isEqualTo("""
WITH
"Customer" AS (
SELECT
"Customer"."custkey" "custkey"
, "Customer"."name" "name"
FROM
(
SELECT
"Customer"."custkey" "custkey"
, "Customer"."name" "name"
FROM
(
SELECT
c_custkey "custkey"
, c_name "name"
FROM
(
SELECT *
FROM
tpch.customer
) "Customer"
) "Customer"
) "Customer"
)\s
, "Orders" AS (
SELECT
"Orders"."orderkey" "orderkey"
, "Orders"."custkey" "custkey"
, "Orders_relationsub"."customer_name" "customer_name"
FROM
((
SELECT
"Orders"."orderkey" "orderkey"
, "Orders"."custkey" "custkey"
FROM
(
SELECT
o_orderkey "orderkey"
, o_custkey "custkey"
FROM
(
SELECT *
FROM
tpch.orders
) "Orders"
) "Orders"
) "Orders"
LEFT JOIN (
SELECT
"Orders"."orderkey"
, "Customer"."name" "customer_name"
FROM
((
SELECT
o_orderkey "orderkey"
, o_custkey "custkey"
FROM
(
SELECT *
FROM
tpch.orders
) "Orders"
) "Orders"
LEFT JOIN "Customer" ON ("Customer"."custkey" = "Orders"."custkey"))
) "Orders_relationsub" ON ("Orders"."orderkey" = "Orders_relationsub"."orderkey"))
)\s
SELECT customer_name
FROM
Orders
LIMIT 200
""");
}

private String toJson(Manifest manifest)
{
return MANIFEST_JSON_CODEC.toJson(manifest);
}

private String base64Encode(String str)
{
return Base64.getEncoder().encodeToString(str.getBytes(UTF_8));
}
}
Loading

0 comments on commit dfd6d7a

Please sign in to comment.