From e9ea3f8e3f6f550b1150b86217912a1d98fc5198 Mon Sep 17 00:00:00 2001 From: Taewhi Lee Date: Tue, 18 Jul 2023 15:04:38 +0900 Subject: [PATCH 1/7] Update the traindb-model submodule for ModelServer --- traindb-model | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/traindb-model b/traindb-model index 038e00f..8b15f94 160000 --- a/traindb-model +++ b/traindb-model @@ -1 +1 @@ -Subproject commit 038e00faac80a5ef05567970dc88c884609a2773 +Subproject commit 8b15f9444296ec694723df4f1fdf894378569360 From 20976b4fc6d3c5a7ea1c69a95e0681e39363cf29 Mon Sep 17 00:00:00 2001 From: Taewhi Lee Date: Tue, 18 Jul 2023 15:27:14 +0900 Subject: [PATCH 2/7] Feat: train models on remote TrainDBModelServer (in traindb-model) --- .../engine/AbstractTrainDBModelRunner.java | 15 ++ .../engine/TrainDBFastApiModelRunner.java | 229 ++++++++++++++++++ .../traindb/engine/TrainDBQueryEngine.java | 19 +- .../java/traindb/planner/TrainDBPlanner.java | 4 + .../rules/ApproxAggregateInferenceRule.java | 10 +- 5 files changed, 261 insertions(+), 16 deletions(-) create mode 100644 traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java diff --git a/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java b/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java index 7df5f28..1fb9f31 100644 --- a/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java +++ b/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java @@ -58,6 +58,21 @@ public Path getModelPath() { modeltypeName, modelName); } + public static AbstractTrainDBModelRunner createModelRunner( + TrainDBConnectionImpl conn, CatalogContext catalogContext, TrainDBConfiguration config, + String modeltypeName, String modelName, String location) { + if (location.equals("REMOTE")) { + return new TrainDBFastApiModelRunner(conn, catalogContext, modeltypeName, modelName); + } + // location.equals("LOCAL") + if (config.getModelRunner().equals("py4j")) { + return new TrainDBPy4JModelRunner(conn, catalogContext, modeltypeName, modelName); + } + + return new TrainDBFileModelRunner(conn, catalogContext, modeltypeName, modelName); + } + + protected String buildSelectTrainingDataQuery(String schemaName, String tableName, List columnNames) { StringBuilder sb = new StringBuilder(); diff --git a/traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java b/traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java new file mode 100644 index 0000000..4e8e6de --- /dev/null +++ b/traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java @@ -0,0 +1,229 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.engine; + +import java.io.BufferedReader; +import java.io.DataOutputStream; +import java.io.FileOutputStream; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.commons.dbcp2.BasicDataSource; +import org.apache.commons.text.StringEscapeUtils; +import org.json.simple.JSONObject; +import traindb.catalog.CatalogContext; +import traindb.catalog.pm.MModeltype; +import traindb.common.TrainDBException; +import traindb.jdbc.TrainDBConnectionImpl; +import traindb.schema.TrainDBTable; + +public class TrainDBFastApiModelRunner extends AbstractTrainDBModelRunner { + + private static final String BOUNDARY = "*****"; + private static final String DOUBLE_HYPHEN = "--"; + private static final String CRLF = "\r\n"; + + public TrainDBFastApiModelRunner( + TrainDBConnectionImpl conn, CatalogContext catalogContext, String modeltypeName, + String modelName) { + super(conn, catalogContext, modeltypeName, modelName); + } + + private static String checkTrailingSlash(String uri) { + return uri.endsWith("/") ? uri : uri + "/"; + } + + private void addString(DataOutputStream request, String key, String value) throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append(DOUBLE_HYPHEN).append(BOUNDARY).append(CRLF); + sb.append("Content-Disposition: form-data; name=\""+ key +"\"").append(CRLF); + sb.append("Content-Type: plain/text").append(CRLF); + sb.append(CRLF).append(value).append(CRLF); + request.writeBytes(sb.toString()); + } + + private void addMetadataFile(DataOutputStream request, JSONObject metadata) throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append(DOUBLE_HYPHEN).append(BOUNDARY).append(CRLF); + sb.append("Content-Disposition: form-data; "); + sb.append("name=\"metadata_file\"; filename=\"metadata.json\"").append(CRLF); + sb.append("Content-Type: application/json").append(CRLF); + sb.append(CRLF).append(metadata.toJSONString()).append(CRLF); + request.writeBytes(sb.toString()); + } + + private void finishMultipartRequest(DataOutputStream request) throws Exception { + request.writeBytes(DOUBLE_HYPHEN + BOUNDARY + DOUBLE_HYPHEN + CRLF); + request.flush(); + request.close(); + } + + @Override + public void trainModel(TrainDBTable table, List columnNames, + Map trainOptions, JavaTypeFactory typeFactory) + throws Exception { + MModeltype mModeltype = catalogContext.getModeltype(modeltypeName); + URL url = new URL(checkTrailingSlash(mModeltype.getUri()) + + "modeltype/" + mModeltype.getClassName() + "/train"); + HttpURLConnection httpConn = (HttpURLConnection) url.openConnection(); + httpConn.setRequestMethod("POST"); + httpConn.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + BOUNDARY); + httpConn.setDoOutput(true); + + BasicDataSource ds = conn.getDataSource(); + String schemaName = table.getSchema().getName(); + String tableName = table.getName(); + String sql = buildSelectTrainingDataQuery(schemaName, tableName, columnNames); + + JSONObject tableMetadata = buildTableMetadata(schemaName, tableName, columnNames, trainOptions, + table.getRowType(typeFactory)); + + OutputStream outputStream = httpConn.getOutputStream(); + DataOutputStream request = new DataOutputStream(outputStream); + + addString(request, "modeltype_class", mModeltype.getClassName()); + addString(request, "model_name", modelName); + addString(request, "jdbc_driver_class", ds.getDriverClassName()); + addString(request, "db_url", ds.getUrl()); + addString(request, "db_user", ds.getUsername()); + addString(request, "db_pwd", ds.getPassword()); + addString(request, "select_training_data_sql", sql); + addMetadataFile(request, tableMetadata); + finishMultipartRequest(request); + + if (httpConn.getResponseCode() != HttpURLConnection.HTTP_OK) { + throw new TrainDBException("failed to train model"); + } + + StringBuilder response = new StringBuilder(); + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(httpConn.getInputStream(), StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + response.append(line); + } + } + System.out.println(response); + } + + @Override + public void generateSynopsis(String outputPath, int rows) throws Exception { + MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype(); + URL url = new URL(checkTrailingSlash(mModeltype.getUri()) + "model/" + modelName + "/synopsis"); + + HttpURLConnection httpConn = (HttpURLConnection) url.openConnection(); + httpConn.setRequestMethod("POST"); + httpConn.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + BOUNDARY); + httpConn.setDoOutput(true); + + OutputStream outputStream = httpConn.getOutputStream(); + DataOutputStream request = new DataOutputStream(outputStream); + + addString(request, "model_name", modelName); + addString(request, "modeltype_class", mModeltype.getClassName()); + addString(request, "rows", String.valueOf(rows)); + finishMultipartRequest(request); + + if (httpConn.getResponseCode() != HttpURLConnection.HTTP_OK) { + throw new TrainDBException("failed to create synopsis"); + } + + Files.createDirectories(Paths.get(outputPath).getParent()); + FileOutputStream fos = new FileOutputStream(outputPath); + InputStream is = httpConn.getInputStream(); + int read; + byte[] buf = new byte[32768]; + while ((read = is.read(buf)) > 0) { + fos.write(buf, 0, read); + } + fos.close(); + is.close(); + } + + @Override + public String infer(String aggregateExpression, String groupByColumn, String whereCondition) + throws Exception { + MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype(); + URL url = new URL(checkTrailingSlash(mModeltype.getUri()) + "model/" + modelName + "/infer"); + + HttpURLConnection httpConn = (HttpURLConnection) url.openConnection(); + httpConn.setRequestMethod("POST"); + httpConn.setRequestProperty("Content-Type", "multipart/form-data; boundary=" + BOUNDARY); + httpConn.setDoOutput(true); + + OutputStream outputStream = httpConn.getOutputStream(); + DataOutputStream request = new DataOutputStream(outputStream); + + addString(request, "model_name", modelName); + addString(request, "modeltype_class", mModeltype.getClassName()); + addString(request, "agg_expr", aggregateExpression); + addString(request, "group_by_column", groupByColumn); + addString(request, "where_condition", whereCondition); + finishMultipartRequest(request); + + if (httpConn.getResponseCode() != HttpURLConnection.HTTP_OK) { + throw new TrainDBException("failed to infer '" + aggregateExpression + "'"); + } + + String modelPath = getModelPath().toString(); + UUID queryId = UUID.randomUUID(); + String outputPath = modelPath + "/infer" + queryId + ".csv"; + + Files.createDirectories(Paths.get(outputPath).getParent()); + FileOutputStream fos = new FileOutputStream(outputPath); + InputStream is = httpConn.getInputStream(); + int read; + byte[] buf = new byte[32768]; + while ((read = is.read(buf)) > 0) { + fos.write(buf, 0, read); + } + fos.close(); + is.close(); + + return outputPath; + } + + @Override + public String listHyperparameters(String className, String uri) throws Exception { + URL url = new URL(checkTrailingSlash(uri) + "modeltype/" + className + "/hyperparams"); + HttpURLConnection httpConn = (HttpURLConnection) url.openConnection(); + httpConn.setRequestMethod("GET"); + + if (httpConn.getResponseCode() != HttpURLConnection.HTTP_OK) { + throw new TrainDBException("failed to list hyperparameters"); + } + + StringBuilder response = new StringBuilder(); + BufferedReader reader = new BufferedReader( + new InputStreamReader(httpConn.getInputStream(), StandardCharsets.UTF_8)); + String line; + while ((line = reader.readLine()) != null) { + response.append(line); + } + + // remove beginning/ending double quotes and unescape + return StringEscapeUtils.unescapeJava(response.toString().replaceAll("^\"|\"$", "")); + } + +} \ No newline at end of file diff --git a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java index 190fb5e..8778f3a 100644 --- a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java +++ b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java @@ -80,7 +80,7 @@ public void createModeltype(String name, String category, String location, T_tracer.closeTaskTime("SUCCESS"); T_tracer.openTaskTime("create modeltype"); - AbstractTrainDBModelRunner runner = createModelRunner(name, ""); + AbstractTrainDBModelRunner runner = createModelRunner(name, "", location); String hyperparamsInfo = runner.listHyperparameters(className, uri); catalogContext.createModeltype(name, category, location, className, uri, hyperparamsInfo); T_tracer.closeTaskTime("SUCCESS"); @@ -150,7 +150,8 @@ public void trainModel( Long trainedRows = baseTableRows; // TODO T_tracer.openTaskTime("train model"); - AbstractTrainDBModelRunner runner = createModelRunner(modeltypeName, modelName); + AbstractTrainDBModelRunner runner = createModelRunner( + modeltypeName, modelName, catalogContext.getModeltype(modeltypeName).getLocation()); runner.trainModel(table, columnNames, trainOptions, conn.getTypeFactory()); T_tracer.closeTaskTime("SUCCESS"); @@ -271,13 +272,10 @@ private void loadSynopsisIntoTable(String synopsisName, MModel mModel, } } - private AbstractTrainDBModelRunner createModelRunner(String modeltypeName, String modelName) { - String modelrunner = conn.cfg.getModelRunner(); - if (modelrunner.equals("py4j")) { - return new TrainDBPy4JModelRunner(conn, catalogContext, modeltypeName, modelName); - } - - return new TrainDBFileModelRunner(conn, catalogContext, modeltypeName, modelName); + private AbstractTrainDBModelRunner createModelRunner(String modeltypeName, String modelName, + String location) { + return AbstractTrainDBModelRunner.createModelRunner( + conn, catalogContext, conn.cfg, modeltypeName, modelName, location); } @Override @@ -312,7 +310,8 @@ public void createSynopsis(String synopsisName, String modelName, int limitNumbe MModel mModel = catalogContext.getModel(modelName); MModeltype mModeltype = mModel.getModeltype(); - AbstractTrainDBModelRunner runner = createModelRunner(mModeltype.getModeltypeName(), modelName); + AbstractTrainDBModelRunner runner = createModelRunner( + mModeltype.getModeltypeName(), modelName, mModeltype.getLocation()); String outputPath = runner.getModelPath().toString() + '/' + synopsisName + ".csv"; runner.generateSynopsis(outputPath, limitNumber); T_tracer.closeTaskTime("SUCCESS"); diff --git a/traindb-core/src/main/java/traindb/planner/TrainDBPlanner.java b/traindb-core/src/main/java/traindb/planner/TrainDBPlanner.java index ffead04..4711708 100644 --- a/traindb-core/src/main/java/traindb/planner/TrainDBPlanner.java +++ b/traindb-core/src/main/java/traindb/planner/TrainDBPlanner.java @@ -93,6 +93,10 @@ public void initPlanner() { Hook.PLANNER.run(this); // allow test to add or remove rules } + public TrainDBConfiguration getConfig() { + return (TrainDBConfiguration) catalogReader.getConfig(); + } + private CaqpExecutionTimePolicy createCaqpExecutionTimePolicy(CalciteConnectionConfig config) throws Exception { CaqpExecutionTimePolicyType policy; diff --git a/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java b/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java index d3102c5..065a1c6 100644 --- a/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java +++ b/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java @@ -42,7 +42,6 @@ import traindb.adapter.python.PythonRel; import traindb.catalog.pm.MModel; import traindb.engine.AbstractTrainDBModelRunner; -import traindb.engine.TrainDBFileModelRunner; import traindb.planner.TrainDBPlanner; @Value.Enclosing @@ -116,11 +115,10 @@ public void onMatch(RelOptRuleCall call) { // TODO choose the best inference model final MModel bestInferenceModel = candidateModels.iterator().next(); - - AbstractTrainDBModelRunner runner = - new TrainDBFileModelRunner(null, planner.getCatalogContext(), - bestInferenceModel.getModeltype().getModeltypeName(), - bestInferenceModel.getModelName()); + AbstractTrainDBModelRunner runner = AbstractTrainDBModelRunner.createModelRunner( + null, planner.getCatalogContext(), planner.getConfig(), + bestInferenceModel.getModeltype().getModeltypeName(), + bestInferenceModel.getModelName(), bestInferenceModel.getModeltype().getLocation()); PythonMLAggregateModel modelTable = new PythonMLAggregateModel(runner, aggregate.getAggCallList(), aggregate.getGroupSet(), aggregate.getInput().getRowType(), From 8dacb2174f09004d66a3852f62de45b2e6e67342 Mon Sep 17 00:00:00 2001 From: Taewhi Lee Date: Wed, 19 Jul 2023 19:13:24 +0900 Subject: [PATCH 3/7] Feat: add catalog functions to manage model training status --- .../java/traindb/catalog/CatalogContext.java | 6 ++ .../traindb/catalog/JDOCatalogContext.java | 48 ++++++++++++- .../main/java/traindb/catalog/pm/MModel.java | 18 +++++ .../traindb/catalog/pm/MTrainingStatus.java | 70 +++++++++++++++++++ 4 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java diff --git a/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java b/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java index 1f18698..02ed38c 100644 --- a/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java +++ b/traindb-catalog/src/main/java/traindb/catalog/CatalogContext.java @@ -26,6 +26,7 @@ import traindb.catalog.pm.MSynopsis; import traindb.catalog.pm.MTable; import traindb.catalog.pm.MTask; +import traindb.catalog.pm.MTrainingStatus; public interface CatalogContext { @@ -62,6 +63,11 @@ Collection getInferenceModels(String baseSchema, String baseTable) MModel getModel(String name); + Collection getTrainingStatus(Map filterPatterns) + throws CatalogException; + + void updateTrainingStatus(String modelName, String status) throws CatalogException; + /* Synopsis */ MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Double ratio) throws CatalogException; diff --git a/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java b/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java index 23595c5..5c4db8e 100644 --- a/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java +++ b/traindb-catalog/src/main/java/traindb/catalog/JDOCatalogContext.java @@ -15,7 +15,10 @@ package traindb.catalog; import com.google.common.collect.ImmutableMap; +import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Collection; +import java.util.Comparator; import java.util.List; import java.util.Map; import javax.jdo.PersistenceManager; @@ -32,6 +35,7 @@ import traindb.catalog.pm.MSynopsis; import traindb.catalog.pm.MTable; import traindb.catalog.pm.MTask; +import traindb.catalog.pm.MTrainingStatus; import traindb.common.TrainDBLogger; public final class JDOCatalogContext implements CatalogContext { @@ -139,10 +143,18 @@ public MModel trainModel( } } + MModeltype mModeltype = getModeltype(modeltypeName); MModel mModel = new MModel( - getModeltype(modeltypeName), modelName, schemaName, tableName, columnNames, + mModeltype, modelName, schemaName, tableName, columnNames, baseTableRows, trainedRows, options == null ? "" : options); pm.makePersistent(mModel); + + if (mModeltype.getLocation().equals("REMOTE")) { + MTrainingStatus mTrainingStatus = new MTrainingStatus(modelName, "TRAINING", + new Timestamp(System.currentTimeMillis()), mModel); + pm.makePersistent(mTrainingStatus); + } + return mModel; } catch (RuntimeException e) { throw new CatalogException("failed to train model '" + modelName + "'", e); @@ -178,6 +190,13 @@ public void dropModel(String name) throws CatalogException { tx.commit(); } + Collection trainingStatus = + getTrainingStatus(ImmutableMap.of("model_name", name)); + if (trainingStatus != null && trainingStatus.size() > 0) { + tx.begin(); + pm.deletePersistentAll(trainingStatus); + tx.commit(); + } } catch (RuntimeException e) { throw new CatalogException("failed to drop model '" + name + "'", e); } finally { @@ -228,6 +247,33 @@ public boolean modelExists(String name) { return null; } + @Override + public Collection getTrainingStatus(Map filterPatterns) + throws CatalogException { + try { + Query query = pm.newQuery(MTrainingStatus.class); + setFilterPatterns(query, filterPatterns); + return (List) query.execute(); + } catch (RuntimeException e) { + throw new CatalogException("failed to get training status", e); + } + } + + @Override + public void updateTrainingStatus(String modelName, String status) throws CatalogException { + try { + Query query = pm.newQuery(MTrainingStatus.class); + setFilterPatterns(query, ImmutableMap.of("model_name", modelName)); + List trainingStatus = (List) query.execute(); + Comparator comparator = Comparator.comparing(MTrainingStatus::getStartTime); + MTrainingStatus latestStatus = trainingStatus.stream().max(comparator).get(); + latestStatus.setTrainingStatus(status); + pm.makePersistent(latestStatus); + } catch (RuntimeException e) { + throw new CatalogException("failed to get training status", e); + } + } + @Override public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, @Nullable Double ratio) throws CatalogException { diff --git a/traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java b/traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java index a19a20c..2ef623c 100644 --- a/traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java +++ b/traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java @@ -14,6 +14,8 @@ package traindb.catalog.pm; +import java.util.Collection; +import java.util.Comparator; import java.util.List; import javax.jdo.annotations.Column; import javax.jdo.annotations.IdGeneratorStrategy; @@ -58,6 +60,9 @@ public final class MModel { @Persistent private byte[] model_options; + @Persistent(mappedBy = "model", dependentElement = "true") + private Collection training_status; + public MModel( MModeltype modeltype, String modelName, String schemaName, String tableName, List columns, @Nullable Long baseTableRows, @Nullable Long trainedRows, @@ -103,4 +108,17 @@ public long getTrainedRows() { public String getModelOptions() { return new String(model_options); } + + public Collection trainingStatus() { + return training_status; + } + + public boolean isEnabled() { + if (training_status.isEmpty() || training_status.size() == 0) { + return true; + } + Comparator comparator = Comparator.comparing(MTrainingStatus::getStartTime); + MTrainingStatus latestStatus = training_status.stream().max(comparator).get(); + return latestStatus.getTrainingStatus().equals("FINISHED"); + } } diff --git a/traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java b/traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java new file mode 100644 index 0000000..b75f906 --- /dev/null +++ b/traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package traindb.catalog.pm; + +import java.sql.Timestamp; +import javax.jdo.annotations.Column; +import javax.jdo.annotations.IdGeneratorStrategy; +import javax.jdo.annotations.Index; +import javax.jdo.annotations.PersistenceCapable; +import javax.jdo.annotations.Persistent; +import javax.jdo.annotations.PrimaryKey; +import traindb.catalog.CatalogConstants; + +@PersistenceCapable +@Index(name="TRAINING_STATUS_IDX", members={"model_name", "start_time"}) +public final class MTrainingStatus { + @PrimaryKey + @Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT) + private long id; + + @Persistent + @Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH) + private String model_name; + + @Persistent + private Timestamp start_time; + + @Persistent + @Column(length = 9) + // Status: TRAINING, FINISHED + private String training_status; + + @Persistent(dependent = "false") + private MModel model; + + public MTrainingStatus(String modelName, String status, Timestamp startTime, MModel model) { + this.model_name = modelName; + this.training_status = status; + this.start_time = startTime; + this.model = model; + } + + public Timestamp getStartTime() { + return start_time; + } + + public String getTrainingStatus() { + return training_status; + } + + public MModel getModel() { + return model; + } + + public void setTrainingStatus(String status) { + this.training_status = status; + } +} From 4bac7ea69ec1c72af08f7f7a4c0dc693066945eb Mon Sep 17 00:00:00 2001 From: Taewhi Lee Date: Wed, 19 Jul 2023 19:19:29 +0900 Subject: [PATCH 4/7] Feat: add an api to return status for the specified model --- .../engine/AbstractTrainDBModelRunner.java | 4 +++ .../engine/TrainDBFastApiModelRunner.java | 33 +++++++++++++++++-- traindb-model | 2 +- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java b/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java index 1fb9f31..c6b1b0f 100644 --- a/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java +++ b/traindb-core/src/main/java/traindb/engine/AbstractTrainDBModelRunner.java @@ -53,6 +53,10 @@ public abstract String infer(String aggregateExpression, String groupByColumn, public abstract String listHyperparameters(String className, String uri) throws Exception; + public boolean checkAvailable(String modelName) throws Exception { + return true; + } + public Path getModelPath() { return Paths.get(TrainDBConfiguration.getTrainDBPrefixPath(), "models", modeltypeName, modelName); diff --git a/traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java b/traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java index 4e8e6de..079adbb 100644 --- a/traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java +++ b/traindb-core/src/main/java/traindb/engine/TrainDBFastApiModelRunner.java @@ -204,6 +204,11 @@ public String infer(String aggregateExpression, String groupByColumn, String whe return outputPath; } + private String unescapeString(String s) { + // remove beginning/ending double quotes and unescape + return StringEscapeUtils.unescapeJava(s.replaceAll("^\"|\"$", "")); + } + @Override public String listHyperparameters(String className, String uri) throws Exception { URL url = new URL(checkTrailingSlash(uri) + "modeltype/" + className + "/hyperparams"); @@ -222,8 +227,32 @@ public String listHyperparameters(String className, String uri) throws Exception response.append(line); } - // remove beginning/ending double quotes and unescape - return StringEscapeUtils.unescapeJava(response.toString().replaceAll("^\"|\"$", "")); + return unescapeString(response.toString()); + } + + @Override + public boolean checkAvailable(String modelName) throws Exception { + MModeltype mModeltype = catalogContext.getModel(modelName).getModeltype(); + URL url = new URL(checkTrailingSlash(mModeltype.getUri()) + "model/" + modelName + "/status"); + HttpURLConnection httpConn = (HttpURLConnection) url.openConnection(); + httpConn.setRequestMethod("GET"); + + if (httpConn.getResponseCode() != HttpURLConnection.HTTP_OK) { + throw new TrainDBException("failed to get model status"); + } + + StringBuilder response = new StringBuilder(); + BufferedReader reader = new BufferedReader( + new InputStreamReader(httpConn.getInputStream(), StandardCharsets.UTF_8)); + String line; + while ((line = reader.readLine()) != null) { + response.append(line); + } + String res = unescapeString(response.toString()); + if (res.equalsIgnoreCase("FINISHED")) { + return true; + } + return false; } } \ No newline at end of file diff --git a/traindb-model b/traindb-model index 8b15f94..28b826b 160000 --- a/traindb-model +++ b/traindb-model @@ -1 +1 @@ -Subproject commit 8b15f9444296ec694723df4f1fdf894378569360 +Subproject commit 28b826b36d36080a68ee17d4d346d322e99c55b8 From f182da55678141bd6a5dc87e6a2fef0bbd45935e Mon Sep 17 00:00:00 2001 From: Taewhi Lee Date: Wed, 19 Jul 2023 19:21:37 +0900 Subject: [PATCH 5/7] Feat: check model status when creating synopsis --- .../src/main/java/traindb/engine/TrainDBQueryEngine.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java index 8778f3a..b2b8e08 100644 --- a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java +++ b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java @@ -312,6 +312,14 @@ public void createSynopsis(String synopsisName, String modelName, int limitNumbe AbstractTrainDBModelRunner runner = createModelRunner( mModeltype.getModeltypeName(), modelName, mModeltype.getLocation()); + + if (!mModel.isEnabled()) { // remote model + if (!runner.checkAvailable(modelName)) { + throw new TrainDBException( + "model '" + modelName + "' is not available (training is not finished)"); + } + catalogContext.updateTrainingStatus(modelName, "FINISHED"); + } String outputPath = runner.getModelPath().toString() + '/' + synopsisName + ".csv"; runner.generateSynopsis(outputPath, limitNumber); T_tracer.closeTaskTime("SUCCESS"); From 3503247d98cbc1acfbce436a14c32a9c8e0220cb Mon Sep 17 00:00:00 2001 From: Taewhi Lee Date: Wed, 19 Jul 2023 20:25:24 +0900 Subject: [PATCH 6/7] Feat: check model status when choosing inference model --- .../rules/ApproxAggregateInferenceRule.java | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java b/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java index 065a1c6..acb5d2b 100644 --- a/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java +++ b/traindb-core/src/main/java/traindb/planner/rules/ApproxAggregateInferenceRule.java @@ -17,6 +17,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import org.apache.calcite.plan.RelOptCluster; @@ -41,6 +42,7 @@ import traindb.adapter.python.PythonMLAggregateModelScan; import traindb.adapter.python.PythonRel; import traindb.catalog.pm.MModel; +import traindb.common.TrainDBException; import traindb.engine.AbstractTrainDBModelRunner; import traindb.planner.TrainDBPlanner; @@ -114,8 +116,32 @@ public void onMatch(RelOptRuleCall call) { } // TODO choose the best inference model - final MModel bestInferenceModel = candidateModels.iterator().next(); - AbstractTrainDBModelRunner runner = AbstractTrainDBModelRunner.createModelRunner( + // It is currently assumed that the candidate models are already sorted by cost + AbstractTrainDBModelRunner runner; + MModel bestInferenceModel = null; + for (Iterator iter = candidateModels.iterator(); iter.hasNext(); ) { + bestInferenceModel = iter.next(); + if (bestInferenceModel.isEnabled()) { + break; + } + try { + runner = AbstractTrainDBModelRunner.createModelRunner( + null, planner.getCatalogContext(), planner.getConfig(), + bestInferenceModel.getModeltype().getModeltypeName(), + bestInferenceModel.getModelName(), bestInferenceModel.getModeltype().getLocation()); + if (runner.checkAvailable(bestInferenceModel.getModelName())) { + planner.getCatalogContext().updateTrainingStatus(bestInferenceModel.getModelName(), "FINISHED"); + break; + } + } catch (Exception e) { + // ignore + } + } + if (bestInferenceModel == null || !bestInferenceModel.isEnabled()) { + return; + } + + runner = AbstractTrainDBModelRunner.createModelRunner( null, planner.getCatalogContext(), planner.getConfig(), bestInferenceModel.getModeltype().getModeltypeName(), bestInferenceModel.getModelName(), bestInferenceModel.getModeltype().getLocation()); From ce4027b845b16deacce37709c61a0ffd92671db4 Mon Sep 17 00:00:00 2001 From: Taewhi Lee Date: Wed, 19 Jul 2023 21:09:05 +0900 Subject: [PATCH 7/7] Feat: implement the SHOW TRAININGS command --- .../traindb/catalog/pm/MTrainingStatus.java | 4 ++ .../src/main/antlr4/traindb/sql/TrainDBSql.g4 | 2 + .../traindb/engine/TrainDBQueryEngine.java | 48 ++++++++++++++++++- .../src/main/java/traindb/sql/TrainDBSql.java | 5 ++ .../java/traindb/sql/TrainDBSqlCommand.java | 1 + .../java/traindb/sql/TrainDBSqlRunner.java | 2 + .../traindb/sql/TrainDBSqlShowCommand.java | 11 +++++ 7 files changed, 72 insertions(+), 1 deletion(-) diff --git a/traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java b/traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java index b75f906..016dd6e 100644 --- a/traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java +++ b/traindb-catalog/src/main/java/traindb/catalog/pm/MTrainingStatus.java @@ -52,6 +52,10 @@ public MTrainingStatus(String modelName, String status, Timestamp startTime, MMo this.model = model; } + public String getModelName() { + return model_name; + } + public Timestamp getStartTime() { return start_time; } diff --git a/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 b/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 index dd282a2..7796269 100644 --- a/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 +++ b/traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4 @@ -125,6 +125,7 @@ showTargets | K_SCHEMAS | K_TABLES | K_HYPERPARAMETERS + | K_TRAININGS | K_QUERYLOGS | K_TASKS ; @@ -241,6 +242,7 @@ K_SYNOPSIS : S Y N O P S I S ; K_TABLES : T A B L E S ; K_TASKS : T A S K S ; K_TRAIN : T R A I N ; +K_TRAININGS : T R A I N I N G S; K_USE : U S E ; K_WHERE : W H E R E ; diff --git a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java index b2b8e08..2a0f8ad 100644 --- a/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java +++ b/traindb-core/src/main/java/traindb/engine/TrainDBQueryEngine.java @@ -29,7 +29,6 @@ import org.codehaus.jackson.annotate.JsonIgnoreProperties; import org.codehaus.jackson.map.ObjectMapper; import org.json.simple.JSONObject; -import org.json.simple.parser.JSONParser; import traindb.adapter.TrainDBSqlDialect; import traindb.catalog.CatalogContext; import traindb.catalog.CatalogException; @@ -38,6 +37,8 @@ import traindb.catalog.pm.MQueryLog; import traindb.catalog.pm.MSynopsis; import traindb.catalog.pm.MTask; +import traindb.catalog.pm.MTrainingStatus; +import traindb.common.TrainDBConfiguration; import traindb.common.TrainDBException; import traindb.common.TrainDBLogger; import traindb.jdbc.TrainDBConnectionImpl; @@ -530,6 +531,41 @@ public TrainDBListResultSet showHyperparameters(Map filterPatter return new TrainDBListResultSet(header, hyperparamInfo); } + @Override + public TrainDBListResultSet showTrainings(Map filterPatterns) throws Exception { + List header = Arrays.asList("model_name", "model_server", "start_time", "status"); + checkShowWhereColumns(filterPatterns, header); + replacePatternFilterColumn(filterPatterns, "model_server", "model.modeltype.uri"); + + T_tracer.startTaskTracer("show trainings"); + T_tracer.openTaskTime("scan : training status"); + + List> trainingInfo = new ArrayList<>(); + for (MTrainingStatus mTraining : catalogContext.getTrainingStatus(filterPatterns)) { + if (mTraining.getTrainingStatus().equals("TRAINING")) { + AbstractTrainDBModelRunner runner = + AbstractTrainDBModelRunner.createModelRunner( + conn, catalogContext, conn.cfg, mTraining.getModel().getModeltype().getModeltypeName(), + mTraining.getModelName(), mTraining.getModel().getModeltype().getLocation()); + try { + if (runner.checkAvailable(mTraining.getModelName())) { + catalogContext.updateTrainingStatus(mTraining.getModelName(), "FINISHED"); + } + } catch (Exception e) { + // ignore + } + } + trainingInfo.add(Arrays.asList(mTraining.getModelName(), + mTraining.getModel().getModeltype().getUri(), mTraining.getStartTime(), + mTraining.getTrainingStatus())); + } + + T_tracer.closeTaskTime("SUCCESS"); + T_tracer.endTaskTracer(); + + return new TrainDBListResultSet(header, trainingInfo); + } + @Override public void useSchema(String schemaName) throws Exception { T_tracer.startTaskTracer("use " + schemaName); @@ -661,6 +697,16 @@ private void addPrefixToPatternFilter(Map patterns, List } } + private void replacePatternFilterColumn(Map patterns, + String before, String after) { + for (String key : patterns.keySet()) { + if (key.equals(before)) { + patterns.put(after, patterns.remove(before)); + return; + } + } + } + @JsonIgnoreProperties(ignoreUnknown = true) static class Hyperparameter { private String name; diff --git a/traindb-core/src/main/java/traindb/sql/TrainDBSql.java b/traindb-core/src/main/java/traindb/sql/TrainDBSql.java index b33aa81..417f229 100644 --- a/traindb-core/src/main/java/traindb/sql/TrainDBSql.java +++ b/traindb-core/src/main/java/traindb/sql/TrainDBSql.java @@ -107,6 +107,9 @@ public static TrainDBListResultSet run(TrainDBSqlCommand command, TrainDBSqlRunn case SHOW_HYPERPARAMETERS: TrainDBSqlShowCommand showHyperparams = (TrainDBSqlShowCommand) command; return runner.showHyperparameters(showHyperparams.getWhereExpressionMap()); + case SHOW_TRAININGS: + TrainDBSqlShowCommand showTrainings = (TrainDBSqlShowCommand) command; + return runner.showTrainings(showTrainings.getWhereExpressionMap()); case USE_SCHEMA: TrainDBSqlUseSchema useSchema = (TrainDBSqlUseSchema) command; runner.useSchema(useSchema.getSchemaName()); @@ -202,6 +205,8 @@ public void exitShowStmt(TrainDBSqlParser.ShowStmtContext ctx) { commands.add(new TrainDBSqlShowCommand.Tables(whereExprMap)); } else if (showTarget.equals("HYPERPARAMETERS")) { commands.add(new TrainDBSqlShowCommand.Hyperparameters(whereExprMap)); + } else if (showTarget.equals("TRAININGS")) { + commands.add(new TrainDBSqlShowCommand.Trainings(whereExprMap)); } else if (showTarget.equals("QUERYLOGS")) { commands.add(new TrainDBSqlShowCommand.QueryLogs(whereExprMap)); } else if (showTarget.equals("TASKS")) { diff --git a/traindb-core/src/main/java/traindb/sql/TrainDBSqlCommand.java b/traindb-core/src/main/java/traindb/sql/TrainDBSqlCommand.java index 2d20616..f6db543 100644 --- a/traindb-core/src/main/java/traindb/sql/TrainDBSqlCommand.java +++ b/traindb-core/src/main/java/traindb/sql/TrainDBSqlCommand.java @@ -22,6 +22,7 @@ public enum Type { DROP_MODELTYPE, SHOW_MODELTYPES, SHOW_MODELS, + SHOW_TRAININGS, TRAIN_MODEL, DROP_MODEL, CREATE_SYNOPSIS, diff --git a/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java b/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java index d0c091d..f47a869 100644 --- a/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java +++ b/traindb-core/src/main/java/traindb/sql/TrainDBSqlRunner.java @@ -46,6 +46,8 @@ void trainModel(String modeltypeName, String modelName, String schemaName, Strin TrainDBListResultSet showHyperparameters(Map filterPatterns) throws Exception; + TrainDBListResultSet showTrainings(Map filterPatterns) throws Exception; + void useSchema(String schemaName) throws Exception; TrainDBListResultSet describeTable(String schemaName, String tableName) throws Exception; diff --git a/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java b/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java index 4248fd4..9c7c993 100644 --- a/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java +++ b/traindb-core/src/main/java/traindb/sql/TrainDBSqlShowCommand.java @@ -94,6 +94,17 @@ public Type getType() { } } + static class Trainings extends TrainDBSqlShowCommand { + Trainings(Map whereExprMap) { + super(whereExprMap); + } + + @Override + public Type getType() { + return Type.SHOW_TRAININGS; + } + } + static class QueryLogs extends TrainDBSqlShowCommand { QueryLogs(Map whereExprMap) { super(whereExprMap);