-
Notifications
You must be signed in to change notification settings - Fork 320
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: merge AIOS DAG SQL * feat: mergeDAGSQL * add AIOSUtil * feat: add AIOS merge SQL test case * feat: split margeDAGSQL and validateSQLInRequest
- Loading branch information
Showing
7 changed files
with
814 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
230 changes: 230 additions & 0 deletions
230
java/openmldb-jdbc/src/main/java/com/_4paradigm/openmldb/sdk/utils/AIOSUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
|
||
|
||
/* | ||
* Copyright 2021 4Paradigm | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com._4paradigm.openmldb.sdk.utils; | ||
|
||
import com._4paradigm.openmldb.sdk.Column; | ||
import com._4paradigm.openmldb.sdk.DAGNode; | ||
import com._4paradigm.openmldb.sdk.Schema; | ||
|
||
import com.google.gson.Gson; | ||
|
||
import java.sql.SQLException; | ||
import java.sql.Types; | ||
|
||
import java.util.ArrayList; | ||
import java.util.LinkedList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Queue; | ||
import java.util.Map; | ||
|
||
public class AIOSUtil { | ||
private static class AIOSDAGNode { | ||
public String uuid; | ||
public String script; | ||
public ArrayList<String> parents = new ArrayList<>(); | ||
public ArrayList<String> inputTables = new ArrayList<>(); | ||
public Map<String, String> tableNameMap = new HashMap<>(); | ||
} | ||
|
||
private static class AIOSDAGColumn { | ||
public String name; | ||
public String type; | ||
} | ||
|
||
private static class AIOSDAGSchema { | ||
public String prn; | ||
public List<AIOSDAGColumn> cols = new ArrayList<>(); | ||
} | ||
|
||
private static class AIOSDAG { | ||
public List<AIOSDAGNode> nodes = new ArrayList<>(); | ||
public List<AIOSDAGSchema> schemas = new ArrayList<>(); | ||
} | ||
|
||
private static int parseType(String type) { | ||
switch (type.toLowerCase()) { | ||
case "smallint": | ||
case "int16": | ||
return Types.SMALLINT; | ||
case "int32": | ||
case "i32": | ||
case "int": | ||
return Types.INTEGER; | ||
case "int64": | ||
case "bigint": | ||
return Types.BIGINT; | ||
case "float": | ||
return Types.FLOAT; | ||
case "double": | ||
return Types.DOUBLE; | ||
case "bool": | ||
case "boolean": | ||
return Types.BOOLEAN; | ||
case "string": | ||
return Types.VARCHAR; | ||
case "timestamp": | ||
return Types.TIMESTAMP; | ||
case "date": | ||
return Types.DATE; | ||
default: | ||
throw new RuntimeException("Unknown type: " + type); | ||
} | ||
} | ||
|
||
private static DAGNode buildAIOSDAG(Map<String, String> sqls, Map<String, Map<String, String>> dag) { | ||
Queue<String> queue = new LinkedList<>(); | ||
Map<String, List<String>> childrenMap = new HashMap<>(); | ||
Map<String, Integer> degreeMap = new HashMap<>(); | ||
Map<String, DAGNode> nodeMap = new HashMap<>(); | ||
for (String uuid: sqls.keySet()) { | ||
Map<String, String> parents = dag.get(uuid); | ||
int degree = 0; | ||
if (parents != null) { | ||
for (String parent : parents.values()) { | ||
if (dag.get(parent) != null) { | ||
degree += 1; | ||
if (childrenMap.get(parent) == null) { | ||
childrenMap.put(parent, new ArrayList<>()); | ||
} | ||
childrenMap.get(parent).add(uuid); | ||
} | ||
} | ||
} | ||
degreeMap.put(uuid, degree); | ||
if (degree == 0) { | ||
queue.offer(uuid); | ||
} | ||
} | ||
|
||
ArrayList<DAGNode> targets = new ArrayList<>(); | ||
while (!queue.isEmpty()) { | ||
String uuid = queue.poll(); | ||
String sql = sqls.get(uuid); | ||
if (sql == null) { | ||
continue; | ||
} | ||
|
||
DAGNode node = new DAGNode(uuid, sql, new ArrayList<DAGNode>()); | ||
Map<String, String> parents = dag.get(uuid); | ||
for (Map.Entry<String, String> parent : parents.entrySet()) { | ||
DAGNode producer = nodeMap.get(parent.getValue()); | ||
if (producer != null) { | ||
node.producers.add(new DAGNode(parent.getKey(), producer.sql, producer.producers)); | ||
} | ||
} | ||
nodeMap.put(uuid, node); | ||
List<String> children = childrenMap.get(uuid); | ||
if (children == null || children.size() == 0) { | ||
targets.add(node); | ||
} else { | ||
for (String child : children) { | ||
degreeMap.put(child, degreeMap.get(child) - 1); | ||
if (degreeMap.get(child) == 0) { | ||
queue.offer(child); | ||
} | ||
} | ||
} | ||
} | ||
|
||
if (targets.size() == 0) { | ||
throw new RuntimeException("Invalid DAG: target node not found"); | ||
} else if (targets.size() > 1) { | ||
throw new RuntimeException("Invalid DAG: target node is not unique"); | ||
} | ||
return targets.get(0); | ||
} | ||
|
||
public static DAGNode parseAIOSDAG(String json) throws SQLException { | ||
Gson gson = new Gson(); | ||
AIOSDAG graph = gson.fromJson(json, AIOSDAG.class); | ||
Map<String, String> sqls = new HashMap<>(); | ||
Map<String, Map<String, String>> dag = new HashMap<>(); | ||
|
||
for (AIOSDAGNode node : graph.nodes) { | ||
if (sqls.get(node.uuid) != null) { | ||
throw new RuntimeException("Duplicate 'uuid': " + node.uuid); | ||
} | ||
if (node.parents.size() != node.inputTables.size()) { | ||
throw new RuntimeException("Size of 'parents' and 'inputTables' mismatch: " + node.uuid); | ||
} | ||
Map<String, String> parents = new HashMap<String, String>(); | ||
for (int i = 0; i < node.parents.size(); i++) { | ||
String table = node.inputTables.get(i); | ||
if (parents.get(table) != null) { | ||
throw new RuntimeException("Ambiguous name '" + table + "': " + node.uuid); | ||
} | ||
parents.put(table, node.parents.get(i)); | ||
} | ||
sqls.put(node.uuid, node.script); | ||
dag.put(node.uuid, parents); | ||
} | ||
return buildAIOSDAG(sqls, dag); | ||
} | ||
|
||
public static Map<String, Map<String, Schema>> parseAIOSTableSchema(String json, String usedDB) { | ||
Gson gson = new Gson(); | ||
AIOSDAG graph = gson.fromJson(json, AIOSDAG.class); | ||
Map<String, String> sqls = new HashMap<>(); | ||
for (AIOSDAGNode node : graph.nodes) { | ||
sqls.put(node.uuid, node.script); | ||
} | ||
|
||
Map<String, Schema> schemaMap = new HashMap<>(); | ||
for (AIOSDAGSchema schema : graph.schemas) { | ||
List<Column> columns = new ArrayList<>(); | ||
for (AIOSDAGColumn column : schema.cols) { | ||
try { | ||
columns.add(new Column(column.name, parseType(column.type))); | ||
} catch (Exception e) { | ||
throw new RuntimeException("Unknown SQL type: " + column.type); | ||
} | ||
} | ||
schemaMap.put(schema.prn, new Schema(columns)); | ||
} | ||
|
||
Map<String, Schema> tableSchema0 = new HashMap<>(); | ||
for (AIOSDAGNode node : graph.nodes) { | ||
for (int i = 0; i < node.parents.size(); i++) { | ||
String table = node.inputTables.get(i); | ||
if (sqls.get(node.parents.get(i)) == null) { | ||
String prn = node.tableNameMap.get(table); | ||
if (prn == null) { | ||
throw new RuntimeException("Table not found in 'tableNameMap': " + | ||
node.uuid + " " + table); | ||
} | ||
Schema schema = schemaMap.get(prn); | ||
if (schema == null) { | ||
throw new RuntimeException("Schema not found: " + prn); | ||
} | ||
if (tableSchema0.get(table) != null) { | ||
if (tableSchema0.get(table) != schema) { | ||
throw new RuntimeException("Table name conflict: " + table); | ||
} | ||
} | ||
tableSchema0.put(table, schema); | ||
} | ||
} | ||
} | ||
|
||
Map<String, Map<String, Schema>> tableSchema = new HashMap<>(); | ||
tableSchema.put(usedDB, tableSchema0); | ||
return tableSchema; | ||
} | ||
} |
108 changes: 108 additions & 0 deletions
108
java/openmldb-jdbc/src/test/data/aiosdagsql/error1.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
{ | ||
"nodes": [ | ||
{ | ||
"id": -1, | ||
"uuid": "8a41c2a7-5259-4dbd-9423-66f9d24f0194", | ||
"type": "FeatureCompute", | ||
"script": "select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id \u003d t3.id", | ||
"isDebug": false, | ||
"isCurrent": false, | ||
"parents": [ | ||
"15810afc-b62f-4165-a027-a198f7e5a375", | ||
"f84bb5fe-b247-4b43-8ae0-9c865c80052e" | ||
], | ||
"inputTables": [ | ||
"t1", | ||
"t3" | ||
], | ||
"tableNameMap": { | ||
"t1": "modelIDE/train-QueryExec-1715152021-021413.table", | ||
"t3": "modelIDE/train-QueryExec-1715152182-85b06d.table" | ||
}, | ||
"outputTables": [], | ||
"instanceType": null, | ||
"tables": {}, | ||
"loader": null, | ||
"originConfig": null, | ||
"enablePrn": true | ||
} | ||
], | ||
"schemas": [ | ||
{ | ||
"uuid": null, | ||
"prn": "modelIDE/train-QueryExec-1715152021-021413.table", | ||
"cols": [{ | ||
"name": "id", | ||
"type": "Int" | ||
}, | ||
{ | ||
"name": "y", | ||
"type": "Int" | ||
}, | ||
{ | ||
"name": "f1_bool", | ||
"type": "Boolean" | ||
}, | ||
{ | ||
"name": "f2_sint", | ||
"type": "SmallInt" | ||
}, | ||
{ | ||
"name": "f3_int", | ||
"type": "Int" | ||
}, | ||
{ | ||
"name": "f4_bint", | ||
"type": "BigInt" | ||
}, | ||
{ | ||
"name": "f5_float", | ||
"type": "Float" | ||
}, | ||
{ | ||
"name": "f6_double", | ||
"type": "Double" | ||
}, | ||
{ | ||
"name": "f7_date", | ||
"type": "Date" | ||
}, | ||
{ | ||
"name": "f8_ts", | ||
"type": "Timestamp" | ||
}, | ||
{ | ||
"name": "f9_str", | ||
"type": "String" | ||
} | ||
], | ||
"isOutput": null | ||
}, | ||
{ | ||
"uuid": null, | ||
"prn": "modelIDE/train-QueryExec-1715152182-85b06d.table", | ||
"cols": [{ | ||
"name": "id", | ||
"type": "Int" | ||
}, | ||
{ | ||
"name": "age", | ||
"type": "Int" | ||
}, | ||
{ | ||
"name": "job", | ||
"type": "String" | ||
}, | ||
{ | ||
"name": "marital", | ||
"type": "String" | ||
}, | ||
{ | ||
"name": "education", | ||
"type": "String" | ||
} | ||
], | ||
"isOutput": null | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
select t3.*, csv(regression_label(t3.job)) from t1 last join t3 on t1.id = t3.id |
Oops, something went wrong.