Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: merge dag sql #3911

Merged
merged 6 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -691,4 +695,40 @@

return new DAGNode(dag.getName(), dag.getSql(), convertedProducers);
}

private static String mergeDAGSQLMemo(DAGNode dag, Map<DAGNode, String> memo, Set<DAGNode> visiting) {
if (visiting.contains(dag)) {
throw new RuntimeException("Invalid DAG: found circle");

Check warning on line 701 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java#L701

Added line #L701 was not covered by tests
}

String merged = memo.get(dag);
if (merged != null) {
return merged;

Check warning on line 706 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/impl/SqlClusterExecutor.java#L706

Added line #L706 was not covered by tests
}

visiting.add(dag);
StringBuilder with = new StringBuilder();
for (DAGNode node : dag.producers) {
String sql = mergeDAGSQLMemo(node, memo, visiting);
if (with.length() == 0) {
with.append("WITH ");
} else {
with.append(",\n");
}
with.append(node.name).append(" as (\n");
with.append(sql).append("\n").append(")");
}
if (with.length() == 0) {
merged = dag.sql;
} else {
merged = with.append("\n").append(dag.sql).toString();
}
visiting.remove(dag);
memo.put(dag, merged);
return merged;
}

public static String mergeDAGSQL(DAGNode dag) {
return mergeDAGSQLMemo(dag, new HashMap<DAGNode, String>(), new HashSet<DAGNode>());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@


/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.sdk.utils;

import com._4paradigm.openmldb.sdk.Column;
import com._4paradigm.openmldb.sdk.DAGNode;
import com._4paradigm.openmldb.sdk.Schema;

import com.google.gson.Gson;

import java.sql.SQLException;
import java.sql.Types;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.HashMap;
import java.util.List;
import java.util.Queue;
import java.util.Map;

public class AIOSUtil {

Check warning on line 37 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L37

Added line #L37 was not covered by tests
private static class AIOSDAGNode {
public String uuid;
public String script;
public ArrayList<String> parents = new ArrayList<>();
public ArrayList<String> inputTables = new ArrayList<>();
public Map<String, String> tableNameMap = new HashMap<>();
}

private static class AIOSDAGColumn {
public String name;
public String type;
}

private static class AIOSDAGSchema {
public String prn;
public List<AIOSDAGColumn> cols = new ArrayList<>();
}

private static class AIOSDAG {
public List<AIOSDAGNode> nodes = new ArrayList<>();
public List<AIOSDAGSchema> schemas = new ArrayList<>();
}

private static int parseType(String type) {
switch (type.toLowerCase()) {
case "smallint":
case "int16":
return Types.SMALLINT;
case "int32":
case "i32":
case "int":
return Types.INTEGER;
case "int64":
case "bigint":
return Types.BIGINT;
case "float":
return Types.FLOAT;
case "double":
return Types.DOUBLE;
case "bool":
case "boolean":
return Types.BOOLEAN;
case "string":
return Types.VARCHAR;
case "timestamp":
return Types.TIMESTAMP;
case "date":
return Types.DATE;
default:
throw new RuntimeException("Unknown type: " + type);

Check warning on line 87 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L87

Added line #L87 was not covered by tests
}
}

private static DAGNode buildAIOSDAG(Map<String, String> sqls, Map<String, Map<String, String>> dag) {
Queue<String> queue = new LinkedList<>();
Map<String, List<String>> childrenMap = new HashMap<>();
Map<String, Integer> degreeMap = new HashMap<>();
Map<String, DAGNode> nodeMap = new HashMap<>();
for (String uuid: sqls.keySet()) {
Map<String, String> parents = dag.get(uuid);
int degree = 0;
if (parents != null) {
for (String parent : parents.values()) {
if (dag.get(parent) != null) {
degree += 1;
if (childrenMap.get(parent) == null) {
childrenMap.put(parent, new ArrayList<>());
}
childrenMap.get(parent).add(uuid);
}
}
}
degreeMap.put(uuid, degree);
if (degree == 0) {
queue.offer(uuid);
}
}

ArrayList<DAGNode> targets = new ArrayList<>();
while (!queue.isEmpty()) {
String uuid = queue.poll();
String sql = sqls.get(uuid);
if (sql == null) {
continue;

Check warning on line 121 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L121

Added line #L121 was not covered by tests
}

DAGNode node = new DAGNode(uuid, sql, new ArrayList<DAGNode>());
Map<String, String> parents = dag.get(uuid);
for (Map.Entry<String, String> parent : parents.entrySet()) {
DAGNode producer = nodeMap.get(parent.getValue());
if (producer != null) {
node.producers.add(new DAGNode(parent.getKey(), producer.sql, producer.producers));
}
}
nodeMap.put(uuid, node);
List<String> children = childrenMap.get(uuid);
if (children == null || children.size() == 0) {
targets.add(node);
} else {
for (String child : children) {
degreeMap.put(child, degreeMap.get(child) - 1);
if (degreeMap.get(child) == 0) {
queue.offer(child);
}
}
}
}

if (targets.size() == 0) {
throw new RuntimeException("Invalid DAG: target node not found");

Check warning on line 147 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L147

Added line #L147 was not covered by tests
} else if (targets.size() > 1) {
throw new RuntimeException("Invalid DAG: target node is not unique");

Check warning on line 149 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L149

Added line #L149 was not covered by tests
}
return targets.get(0);
}

public static DAGNode parseAIOSDAG(String json) throws SQLException {
Gson gson = new Gson();
AIOSDAG graph = gson.fromJson(json, AIOSDAG.class);
Map<String, String> sqls = new HashMap<>();
Map<String, Map<String, String>> dag = new HashMap<>();

for (AIOSDAGNode node : graph.nodes) {
if (sqls.get(node.uuid) != null) {
throw new RuntimeException("Duplicate 'uuid': " + node.uuid);

Check warning on line 162 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L162

Added line #L162 was not covered by tests
}
if (node.parents.size() != node.inputTables.size()) {
throw new RuntimeException("Size of 'parents' and 'inputTables' mismatch: " + node.uuid);

Check warning on line 165 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L165

Added line #L165 was not covered by tests
}
Map<String, String> parents = new HashMap<String, String>();
for (int i = 0; i < node.parents.size(); i++) {
String table = node.inputTables.get(i);
if (parents.get(table) != null) {
throw new RuntimeException("Ambiguous name '" + table + "': " + node.uuid);

Check warning on line 171 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L171

Added line #L171 was not covered by tests
}
parents.put(table, node.parents.get(i));
}
sqls.put(node.uuid, node.script);
dag.put(node.uuid, parents);
}
return buildAIOSDAG(sqls, dag);
}

public static Map<String, Map<String, Schema>> parseAIOSTableSchema(String json, String usedDB) {
Gson gson = new Gson();
AIOSDAG graph = gson.fromJson(json, AIOSDAG.class);
Map<String, String> sqls = new HashMap<>();
for (AIOSDAGNode node : graph.nodes) {
sqls.put(node.uuid, node.script);
}

Map<String, Schema> schemaMap = new HashMap<>();
for (AIOSDAGSchema schema : graph.schemas) {
List<Column> columns = new ArrayList<>();
for (AIOSDAGColumn column : schema.cols) {
try {
columns.add(new Column(column.name, parseType(column.type)));
} catch (Exception e) {
throw new RuntimeException("Unknown SQL type: " + column.type);

Check warning on line 196 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L195-L196

Added lines #L195 - L196 were not covered by tests
}
}
schemaMap.put(schema.prn, new Schema(columns));
}

Map<String, Schema> tableSchema0 = new HashMap<>();
for (AIOSDAGNode node : graph.nodes) {
for (int i = 0; i < node.parents.size(); i++) {
String table = node.inputTables.get(i);
if (sqls.get(node.parents.get(i)) == null) {
String prn = node.tableNameMap.get(table);
if (prn == null) {
throw new RuntimeException("Table not found in 'tableNameMap': " +

Check warning on line 209 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L209

Added line #L209 was not covered by tests
node.uuid + " " + table);
}
Schema schema = schemaMap.get(prn);
if (schema == null) {
throw new RuntimeException("Schema not found: " + prn);

Check warning on line 214 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L214

Added line #L214 was not covered by tests
}
if (tableSchema0.get(table) != null) {
if (tableSchema0.get(table) != schema) {
throw new RuntimeException("Table name conflict: " + table);

Check warning on line 218 in java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java

View check run for this annotation

Codecov / codecov/patch

java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java#L218

Added line #L218 was not covered by tests
}
}
tableSchema0.put(table, schema);
}
}
}

Map<String, Map<String, Schema>> tableSchema = new HashMap<>();
tableSchema.put(usedDB, tableSchema0);
return tableSchema;
}
}
108 changes: 108 additions & 0 deletions java/openmldb-jdbc/src/test/data/aiosdagsql/error1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"nodes": [
{
"id": -1,
"uuid": "8a41c2a7-5259-4dbd-9423-66f9d24f0194",
"type": "FeatureCompute",
"script": "select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id \u003d t3.id",
"isDebug": false,
"isCurrent": false,
"parents": [
"15810afc-b62f-4165-a027-a198f7e5a375",
"f84bb5fe-b247-4b43-8ae0-9c865c80052e"
],
"inputTables": [
"t1",
"t3"
],
"tableNameMap": {
"t1": "modelIDE/train-QueryExec-1715152021-021413.table",
"t3": "modelIDE/train-QueryExec-1715152182-85b06d.table"
},
"outputTables": [],
"instanceType": null,
"tables": {},
"loader": null,
"originConfig": null,
"enablePrn": true
}
],
"schemas": [
{
"uuid": null,
"prn": "modelIDE/train-QueryExec-1715152021-021413.table",
"cols": [{
"name": "id",
"type": "Int"
},
{
"name": "y",
"type": "Int"
},
{
"name": "f1_bool",
"type": "Boolean"
},
{
"name": "f2_sint",
"type": "SmallInt"
},
{
"name": "f3_int",
"type": "Int"
},
{
"name": "f4_bint",
"type": "BigInt"
},
{
"name": "f5_float",
"type": "Float"
},
{
"name": "f6_double",
"type": "Double"
},
{
"name": "f7_date",
"type": "Date"
},
{
"name": "f8_ts",
"type": "Timestamp"
},
{
"name": "f9_str",
"type": "String"
}
],
"isOutput": null
},
{
"uuid": null,
"prn": "modelIDE/train-QueryExec-1715152182-85b06d.table",
"cols": [{
"name": "id",
"type": "Int"
},
{
"name": "age",
"type": "Int"
},
{
"name": "job",
"type": "String"
},
{
"name": "marital",
"type": "String"
},
{
"name": "education",
"type": "String"
}
],
"isOutput": null
}
]
}
1 change: 1 addition & 0 deletions java/openmldb-jdbc/src/test/data/aiosdagsql/error1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id = t3.id
Loading
Loading