Skip to content

Commit

Permalink
feat: merge dag sql (#3911)
Browse files Browse the repository at this point in the history
* feat: merge AIOS DAG SQL

* feat: mergeDAGSQL

* add AIOSUtil

* feat: add AIOS merge SQL test case

* feat: split mergeDAGSQL and validateSQLInRequest
  • Loading branch information
wyl4pd authored May 14, 2024
1 parent 6569b42 commit 673ab1d
Show file tree
Hide file tree
Showing 7 changed files with 814 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -691,4 +695,40 @@ private static DAGNode convertDAG(com._4paradigm.openmldb.DAGNode dag) {

return new DAGNode(dag.getName(), dag.getSql(), convertedProducers);
}

/**
 * Recursively flattens a DAG of SQL nodes into one statement, turning every producer of a node
 * into a common table expression named after that producer.
 *
 * @param dag      node to flatten
 * @param memo     merged SQL already computed for nodes, reused on repeated visits
 * @param visiting nodes on the current recursion path, used to detect cycles
 * @return {@code dag.sql} when the node has no producers, otherwise
 *         {@code WITH <name> as (<sql>), ... <dag.sql>}
 * @throws RuntimeException if {@code dag} is already on the current recursion path
 */
private static String mergeDAGSQLMemo(DAGNode dag, Map<DAGNode, String> memo, Set<DAGNode> visiting) {
    // Reaching a node that is still being expanded means the graph loops back on itself.
    if (visiting.contains(dag)) {
        throw new RuntimeException("Invalid DAG: found circle");
    }

    String cached = memo.get(dag);
    if (cached != null) {
        return cached;
    }

    visiting.add(dag);
    // One "<name> as (\n<sql>\n)" fragment per producer, in producer order.
    List<String> ctes = new ArrayList<>();
    for (DAGNode producer : dag.producers) {
        String producerSql = mergeDAGSQLMemo(producer, memo, visiting);
        ctes.add(producer.name + " as (\n" + producerSql + "\n)");
    }

    String result;
    if (ctes.isEmpty()) {
        result = dag.sql;
    } else {
        result = "WITH " + String.join(",\n", ctes) + "\n" + dag.sql;
    }

    visiting.remove(dag);
    memo.put(dag, result);
    return result;
}

/**
 * Merges the SQL of a DAG into a single statement.
 *
 * @param dag the node whose SQL closes the merged statement
 * @return the merged SQL, with producers emitted as {@code WITH} clauses
 * @throws RuntimeException if the DAG contains a cycle
 */
public static String mergeDAGSQL(DAGNode dag) {
    // Fresh bookkeeping per call: nothing is shared between invocations.
    Map<DAGNode, String> mergedByNode = new HashMap<DAGNode, String>();
    Set<DAGNode> inProgress = new HashSet<DAGNode>();
    return mergeDAGSQLMemo(dag, mergedByNode, inProgress);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@


/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.sdk.utils;

import com._4paradigm.openmldb.sdk.Column;
import com._4paradigm.openmldb.sdk.DAGNode;
import com._4paradigm.openmldb.sdk.Schema;

import com.google.gson.Gson;

import java.sql.SQLException;
import java.sql.Types;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Queue;

/**
 * Converts an AIOS DAG description (JSON) into OpenMLDB SDK structures: a {@link DAGNode} tree
 * for SQL merging and a table schema map for SQL validation.
 */
public class AIOSUtil {
    /** Shared parser; Gson instances are documented as thread-safe and reusable. */
    private static final Gson GSON = new Gson();

    /** One node of the AIOS graph. Fields are filled by Gson from the JSON keys of the same name. */
    private static class AIOSDAGNode {
        public String uuid;
        public String script;
        // parents.get(i) is the uuid that feeds the table named inputTables.get(i).
        public ArrayList<String> parents = new ArrayList<>();
        public ArrayList<String> inputTables = new ArrayList<>();
        // Input table name -> schema "prn" key.
        public Map<String, String> tableNameMap = new HashMap<>();
    }

    /** One column of an AIOS schema: a name plus a textual type understood by {@link #parseType}. */
    private static class AIOSDAGColumn {
        public String name;
        public String type;
    }

    /** A schema identified by its {@code prn} key. */
    private static class AIOSDAGSchema {
        public String prn;
        public List<AIOSDAGColumn> cols = new ArrayList<>();
    }

    /** Root of the JSON document. */
    private static class AIOSDAG {
        public List<AIOSDAGNode> nodes = new ArrayList<>();
        public List<AIOSDAGSchema> schemas = new ArrayList<>();
    }

    /**
     * Parses the JSON document into an {@link AIOSDAG}.
     *
     * @throws RuntimeException if the document is null or empty (Gson yields {@code null} then)
     */
    private static AIOSDAG parseGraph(String json) {
        AIOSDAG graph = GSON.fromJson(json, AIOSDAG.class);
        if (graph == null) {
            throw new RuntimeException("Invalid DAG: empty json");
        }
        return graph;
    }

    /**
     * Maps a textual column type (case-insensitive) to a {@link Types} constant.
     *
     * @param type type name such as {@code "Int"}, {@code "bigint"} or {@code "String"}
     * @return the matching {@code java.sql.Types} value
     * @throws RuntimeException if the name is not one of the recognised types
     */
    private static int parseType(String type) {
        // Locale.ROOT: the default-locale overload turns 'I' into a dotless i under Turkish/Azeri
        // locales, which would make "Int"/"BigInt"/"SmallInt" fall through to the default branch.
        switch (type.toLowerCase(Locale.ROOT)) {
            case "smallint":
            case "int16":
                return Types.SMALLINT;
            case "int32":
            case "i32":
            case "int":
                return Types.INTEGER;
            case "int64":
            case "bigint":
                return Types.BIGINT;
            case "float":
                return Types.FLOAT;
            case "double":
                return Types.DOUBLE;
            case "bool":
            case "boolean":
                return Types.BOOLEAN;
            case "string":
                return Types.VARCHAR;
            case "timestamp":
                return Types.TIMESTAMP;
            case "date":
                return Types.DATE;
            default:
                throw new RuntimeException("Unknown type: " + type);
        }
    }

    /**
     * Builds the {@link DAGNode} tree rooted at the single node that no other node consumes.
     *
     * <p>Nodes are visited in dependency order: a node is processed only after every parent that
     * is itself a graph node has been processed.
     *
     * @param sqls node uuid -> SQL script
     * @param dag  node uuid -> (input table name -> parent uuid)
     * @return the unique target node, with producers named after the input tables
     * @throws RuntimeException if no target node or more than one target node is found
     */
    private static DAGNode buildAIOSDAG(Map<String, String> sqls, Map<String, Map<String, String>> dag) {
        Queue<String> queue = new LinkedList<>();
        Map<String, List<String>> childrenMap = new HashMap<>();
        Map<String, Integer> degreeMap = new HashMap<>();
        Map<String, DAGNode> nodeMap = new HashMap<>();

        // Count, per node, the parents that are graph nodes; parents absent from 'dag' are
        // external tables and do not block the node.
        for (String uuid : sqls.keySet()) {
            Map<String, String> parents = dag.get(uuid);
            int degree = 0;
            if (parents != null) {
                for (String parent : parents.values()) {
                    if (dag.get(parent) != null) {
                        degree += 1;
                        childrenMap.computeIfAbsent(parent, k -> new ArrayList<>()).add(uuid);
                    }
                }
            }
            degreeMap.put(uuid, degree);
            if (degree == 0) {
                queue.offer(uuid);
            }
        }

        List<DAGNode> targets = new ArrayList<>();
        while (!queue.isEmpty()) {
            String uuid = queue.poll();
            String sql = sqls.get(uuid);
            if (sql == null) {
                // Nodes without a script are skipped; their children are never released.
                continue;
            }

            DAGNode node = new DAGNode(uuid, sql, new ArrayList<DAGNode>());
            Map<String, String> parents = dag.get(uuid);
            // Same null guard as in the degree pass above.
            if (parents != null) {
                for (Map.Entry<String, String> parent : parents.entrySet()) {
                    DAGNode producer = nodeMap.get(parent.getValue());
                    if (producer != null) {
                        // Re-wrap the producer under the input table name this node refers to it by.
                        node.producers.add(new DAGNode(parent.getKey(), producer.sql, producer.producers));
                    }
                }
            }
            nodeMap.put(uuid, node);

            List<String> children = childrenMap.get(uuid);
            if (children == null || children.isEmpty()) {
                // Nothing consumes this node, so it is a candidate final output.
                targets.add(node);
            } else {
                for (String child : children) {
                    int remaining = degreeMap.get(child) - 1;
                    degreeMap.put(child, remaining);
                    if (remaining == 0) {
                        queue.offer(child);
                    }
                }
            }
        }

        if (targets.isEmpty()) {
            throw new RuntimeException("Invalid DAG: target node not found");
        } else if (targets.size() > 1) {
            throw new RuntimeException("Invalid DAG: target node is not unique");
        }
        return targets.get(0);
    }

    /**
     * Parses an AIOS DAG JSON document into a {@link DAGNode} tree.
     *
     * @param json AIOS DAG description with a {@code nodes} array
     * @return the unique target node of the graph
     * @throws RuntimeException if the document is empty, a uuid is duplicated, the sizes of
     *         {@code parents} and {@code inputTables} differ, an input table name repeats within
     *         a node, or the graph has no unique target
     * @throws SQLException declared for callers; not thrown by the code in this method
     */
    public static DAGNode parseAIOSDAG(String json) throws SQLException {
        AIOSDAG graph = parseGraph(json);
        Map<String, String> sqls = new HashMap<>();
        Map<String, Map<String, String>> dag = new HashMap<>();

        for (AIOSDAGNode node : graph.nodes) {
            // containsKey, not get() != null: an earlier node with a null script is still a duplicate.
            if (sqls.containsKey(node.uuid)) {
                throw new RuntimeException("Duplicate 'uuid': " + node.uuid);
            }
            if (node.parents.size() != node.inputTables.size()) {
                throw new RuntimeException("Size of 'parents' and 'inputTables' mismatch: " + node.uuid);
            }
            Map<String, String> parents = new HashMap<String, String>();
            for (int i = 0; i < node.parents.size(); i++) {
                String table = node.inputTables.get(i);
                // containsKey so a repeated table is caught even when its first parent was null.
                if (parents.containsKey(table)) {
                    throw new RuntimeException("Ambiguous name '" + table + "': " + node.uuid);
                }
                parents.put(table, node.parents.get(i));
            }
            sqls.put(node.uuid, node.script);
            dag.put(node.uuid, parents);
        }
        return buildAIOSDAG(sqls, dag);
    }

    /**
     * Extracts the schemas of the external input tables referenced by an AIOS DAG.
     *
     * <p>An input table is treated as external when its parent uuid does not belong to a node
     * with a script in the same document.
     *
     * @param json   AIOS DAG description with {@code nodes} and {@code schemas} arrays
     * @param usedDB database name used as the single key of the returned map
     * @return {@code usedDB -> (table name -> schema)}
     * @throws RuntimeException if the document is empty, a column type is unknown, a table is
     *         missing from {@code tableNameMap}, a schema is missing, or one table name maps to
     *         two different schemas
     */
    public static Map<String, Map<String, Schema>> parseAIOSTableSchema(String json, String usedDB) {
        AIOSDAG graph = parseGraph(json);
        Map<String, String> sqls = new HashMap<>();
        for (AIOSDAGNode node : graph.nodes) {
            sqls.put(node.uuid, node.script);
        }

        Map<String, Schema> schemaMap = new HashMap<>();
        for (AIOSDAGSchema schema : graph.schemas) {
            List<Column> columns = new ArrayList<>();
            for (AIOSDAGColumn column : schema.cols) {
                try {
                    columns.add(new Column(column.name, parseType(column.type)));
                } catch (Exception e) {
                    // Keep the original failure as the cause so the stack trace is not lost.
                    throw new RuntimeException("Unknown SQL type: " + column.type, e);
                }
            }
            schemaMap.put(schema.prn, new Schema(columns));
        }

        Map<String, Schema> tableSchema0 = new HashMap<>();
        for (AIOSDAGNode node : graph.nodes) {
            for (int i = 0; i < node.parents.size(); i++) {
                String table = node.inputTables.get(i);
                if (sqls.get(node.parents.get(i)) != null) {
                    // Fed by another node of the graph: not an external table.
                    continue;
                }
                String prn = node.tableNameMap.get(table);
                if (prn == null) {
                    throw new RuntimeException("Table not found in 'tableNameMap': " +
                            node.uuid + " " + table);
                }
                Schema schema = schemaMap.get(prn);
                if (schema == null) {
                    throw new RuntimeException("Schema not found: " + prn);
                }
                // Reference comparison on purpose: the same table name must resolve to the very
                // same schemaMap entry everywhere it is used.
                Schema existing = tableSchema0.get(table);
                if (existing != null && existing != schema) {
                    throw new RuntimeException("Table name conflict: " + table);
                }
                tableSchema0.put(table, schema);
            }
        }

        Map<String, Map<String, Schema>> tableSchema = new HashMap<>();
        tableSchema.put(usedDB, tableSchema0);
        return tableSchema;
    }
}
108 changes: 108 additions & 0 deletions java/openmldb-jdbc/src/test/data/aiosdagsql/error1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"nodes": [
{
"id": -1,
"uuid": "8a41c2a7-5259-4dbd-9423-66f9d24f0194",
"type": "FeatureCompute",
"script": "select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id \u003d t3.id",
"isDebug": false,
"isCurrent": false,
"parents": [
"15810afc-b62f-4165-a027-a198f7e5a375",
"f84bb5fe-b247-4b43-8ae0-9c865c80052e"
],
"inputTables": [
"t1",
"t3"
],
"tableNameMap": {
"t1": "modelIDE/train-QueryExec-1715152021-021413.table",
"t3": "modelIDE/train-QueryExec-1715152182-85b06d.table"
},
"outputTables": [],
"instanceType": null,
"tables": {},
"loader": null,
"originConfig": null,
"enablePrn": true
}
],
"schemas": [
{
"uuid": null,
"prn": "modelIDE/train-QueryExec-1715152021-021413.table",
"cols": [{
"name": "id",
"type": "Int"
},
{
"name": "y",
"type": "Int"
},
{
"name": "f1_bool",
"type": "Boolean"
},
{
"name": "f2_sint",
"type": "SmallInt"
},
{
"name": "f3_int",
"type": "Int"
},
{
"name": "f4_bint",
"type": "BigInt"
},
{
"name": "f5_float",
"type": "Float"
},
{
"name": "f6_double",
"type": "Double"
},
{
"name": "f7_date",
"type": "Date"
},
{
"name": "f8_ts",
"type": "Timestamp"
},
{
"name": "f9_str",
"type": "String"
}
],
"isOutput": null
},
{
"uuid": null,
"prn": "modelIDE/train-QueryExec-1715152182-85b06d.table",
"cols": [{
"name": "id",
"type": "Int"
},
{
"name": "age",
"type": "Int"
},
{
"name": "job",
"type": "String"
},
{
"name": "marital",
"type": "String"
},
{
"name": "education",
"type": "String"
}
],
"isOutput": null
}
]
}
1 change: 1 addition & 0 deletions java/openmldb-jdbc/src/test/data/aiosdagsql/error1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id = t3.id
Loading

0 comments on commit 673ab1d

Please sign in to comment.