Skip to content

Commit

Permalink
Merge pull request #38 from traindb-project/issue-12
Browse files Browse the repository at this point in the history
Feat: Train models on remote model servers
  • Loading branch information
taewhi authored Jul 20, 2023
2 parents 5a0f797 + ce4027b commit b125b42
Show file tree
Hide file tree
Showing 15 changed files with 542 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import traindb.catalog.pm.MSynopsis;
import traindb.catalog.pm.MTable;
import traindb.catalog.pm.MTask;
import traindb.catalog.pm.MTrainingStatus;

public interface CatalogContext {

Expand Down Expand Up @@ -62,6 +63,11 @@ Collection<MModel> getInferenceModels(String baseSchema, String baseTable)

MModel getModel(String name);

Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns)
throws CatalogException;

void updateTrainingStatus(String modelName, String status) throws CatalogException;

/* Synopsis */
MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows, Double ratio)
throws CatalogException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package traindb.catalog;

import com.google.common.collect.ImmutableMap;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import javax.jdo.PersistenceManager;
Expand All @@ -32,6 +35,7 @@
import traindb.catalog.pm.MSynopsis;
import traindb.catalog.pm.MTable;
import traindb.catalog.pm.MTask;
import traindb.catalog.pm.MTrainingStatus;
import traindb.common.TrainDBLogger;

public final class JDOCatalogContext implements CatalogContext {
Expand Down Expand Up @@ -139,10 +143,18 @@ public MModel trainModel(
}
}

MModeltype mModeltype = getModeltype(modeltypeName);
MModel mModel = new MModel(
getModeltype(modeltypeName), modelName, schemaName, tableName, columnNames,
mModeltype, modelName, schemaName, tableName, columnNames,
baseTableRows, trainedRows, options == null ? "" : options);
pm.makePersistent(mModel);

if (mModeltype.getLocation().equals("REMOTE")) {
MTrainingStatus mTrainingStatus = new MTrainingStatus(modelName, "TRAINING",
new Timestamp(System.currentTimeMillis()), mModel);
pm.makePersistent(mTrainingStatus);
}

return mModel;
} catch (RuntimeException e) {
throw new CatalogException("failed to train model '" + modelName + "'", e);
Expand Down Expand Up @@ -178,6 +190,13 @@ public void dropModel(String name) throws CatalogException {
tx.commit();
}

Collection<MTrainingStatus> trainingStatus =
getTrainingStatus(ImmutableMap.of("model_name", name));
if (trainingStatus != null && trainingStatus.size() > 0) {
tx.begin();
pm.deletePersistentAll(trainingStatus);
tx.commit();
}
} catch (RuntimeException e) {
throw new CatalogException("failed to drop model '" + name + "'", e);
} finally {
Expand Down Expand Up @@ -228,6 +247,33 @@ public boolean modelExists(String name) {
return null;
}

@Override
public Collection<MTrainingStatus> getTrainingStatus(Map<String, Object> filterPatterns)
throws CatalogException {
try {
Query query = pm.newQuery(MTrainingStatus.class);
setFilterPatterns(query, filterPatterns);
return (List<MTrainingStatus>) query.execute();
} catch (RuntimeException e) {
throw new CatalogException("failed to get training status", e);
}
}

@Override
public void updateTrainingStatus(String modelName, String status) throws CatalogException {
try {
Query query = pm.newQuery(MTrainingStatus.class);
setFilterPatterns(query, ImmutableMap.of("model_name", modelName));
List<MTrainingStatus> trainingStatus = (List<MTrainingStatus>) query.execute();
Comparator<MTrainingStatus> comparator = Comparator.comparing(MTrainingStatus::getStartTime);
MTrainingStatus latestStatus = trainingStatus.stream().max(comparator).get();
latestStatus.setTrainingStatus(status);
pm.makePersistent(latestStatus);
} catch (RuntimeException e) {
throw new CatalogException("failed to get training status", e);
}
}

@Override
public MSynopsis createSynopsis(String synopsisName, String modelName, Integer rows,
@Nullable Double ratio) throws CatalogException {
Expand Down
18 changes: 18 additions & 0 deletions traindb-catalog/src/main/java/traindb/catalog/pm/MModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package traindb.catalog.pm;

import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import javax.jdo.annotations.Column;
import javax.jdo.annotations.IdGeneratorStrategy;
Expand Down Expand Up @@ -58,6 +60,9 @@ public final class MModel {
@Persistent
private byte[] model_options;

@Persistent(mappedBy = "model", dependentElement = "true")
private Collection<MTrainingStatus> training_status;

public MModel(
MModeltype modeltype, String modelName, String schemaName, String tableName,
List<String> columns, @Nullable Long baseTableRows, @Nullable Long trainedRows,
Expand Down Expand Up @@ -103,4 +108,17 @@ public long getTrainedRows() {
public String getModelOptions() {
return new String(model_options);
}

public Collection<MTrainingStatus> trainingStatus() {
return training_status;
}

public boolean isEnabled() {
if (training_status.isEmpty() || training_status.size() == 0) {
return true;
}
Comparator<MTrainingStatus> comparator = Comparator.comparing(MTrainingStatus::getStartTime);
MTrainingStatus latestStatus = training_status.stream().max(comparator).get();
return latestStatus.getTrainingStatus().equals("FINISHED");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package traindb.catalog.pm;

import java.sql.Timestamp;
import javax.jdo.annotations.Column;
import javax.jdo.annotations.IdGeneratorStrategy;
import javax.jdo.annotations.Index;
import javax.jdo.annotations.PersistenceCapable;
import javax.jdo.annotations.Persistent;
import javax.jdo.annotations.PrimaryKey;
import traindb.catalog.CatalogConstants;

@PersistenceCapable
@Index(name="TRAINING_STATUS_IDX", members={"model_name", "start_time"})
public final class MTrainingStatus {
@PrimaryKey
@Persistent(valueStrategy = IdGeneratorStrategy.INCREMENT)
private long id;

@Persistent
@Column(length = CatalogConstants.IDENTIFIER_MAX_LENGTH)
private String model_name;

@Persistent
private Timestamp start_time;

@Persistent
@Column(length = 9)
// Status: TRAINING, FINISHED
private String training_status;

@Persistent(dependent = "false")
private MModel model;

public MTrainingStatus(String modelName, String status, Timestamp startTime, MModel model) {
this.model_name = modelName;
this.training_status = status;
this.start_time = startTime;
this.model = model;
}

public String getModelName() {
return model_name;
}

public Timestamp getStartTime() {
return start_time;
}

public String getTrainingStatus() {
return training_status;
}

public MModel getModel() {
return model;
}

public void setTrainingStatus(String status) {
this.training_status = status;
}
}
2 changes: 2 additions & 0 deletions traindb-core/src/main/antlr4/traindb/sql/TrainDBSql.g4
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ showTargets
| K_SCHEMAS
| K_TABLES
| K_HYPERPARAMETERS
| K_TRAININGS
| K_QUERYLOGS
| K_TASKS
;
Expand Down Expand Up @@ -241,6 +242,7 @@ K_SYNOPSIS : S Y N O P S I S ;
K_TABLES : T A B L E S ;
K_TASKS : T A S K S ;
K_TRAIN : T R A I N ;
K_TRAININGS : T R A I N I N G S;
K_USE : U S E ;
K_WHERE : W H E R E ;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,30 @@ public abstract String infer(String aggregateExpression, String groupByColumn,

public abstract String listHyperparameters(String className, String uri) throws Exception;

public boolean checkAvailable(String modelName) throws Exception {
return true;
}

public Path getModelPath() {
return Paths.get(TrainDBConfiguration.getTrainDBPrefixPath(), "models",
modeltypeName, modelName);
}

public static AbstractTrainDBModelRunner createModelRunner(
TrainDBConnectionImpl conn, CatalogContext catalogContext, TrainDBConfiguration config,
String modeltypeName, String modelName, String location) {
if (location.equals("REMOTE")) {
return new TrainDBFastApiModelRunner(conn, catalogContext, modeltypeName, modelName);
}
// location.equals("LOCAL")
if (config.getModelRunner().equals("py4j")) {
return new TrainDBPy4JModelRunner(conn, catalogContext, modeltypeName, modelName);
}

return new TrainDBFileModelRunner(conn, catalogContext, modeltypeName, modelName);
}


protected String buildSelectTrainingDataQuery(String schemaName, String tableName,
List<String> columnNames) {
StringBuilder sb = new StringBuilder();
Expand Down
Loading

0 comments on commit b125b42

Please sign in to comment.