diff --git a/api/src/main/java/ai/djl/training/loss/TabNetClassificationLoss.java b/api/src/main/java/ai/djl/training/loss/TabNetClassificationLoss.java index 58756054dd1..c556358e75f 100644 --- a/api/src/main/java/ai/djl/training/loss/TabNetClassificationLoss.java +++ b/api/src/main/java/ai/djl/training/loss/TabNetClassificationLoss.java @@ -42,6 +42,6 @@ public TabNetClassificationLoss(String name) { public NDArray evaluate(NDList labels, NDList predictions) { return Loss.softmaxCrossEntropyLoss() .evaluate(labels, new NDList(predictions.get(0))) - .add(predictions.get(1)); + .add(predictions.get(1).mean()); } } diff --git a/api/src/main/java/ai/djl/training/loss/TabNetRegressionLoss.java b/api/src/main/java/ai/djl/training/loss/TabNetRegressionLoss.java index 21f2901366f..b3776bfb7cb 100644 --- a/api/src/main/java/ai/djl/training/loss/TabNetRegressionLoss.java +++ b/api/src/main/java/ai/djl/training/loss/TabNetRegressionLoss.java @@ -46,6 +46,6 @@ public NDArray evaluate(NDList labels, NDList predictions) { .sub(predictions.get(0)) .square() .mean() - .add(predictions.get(1)); + .add(predictions.get(1).mean()); } } diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularDataset.java index 15351c844cc..bdd4dbecd2c 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularDataset.java @@ -105,6 +105,21 @@ public Record get(NDManager manager, long index) { return new Record(data, label); } + /** + * Returns the direct designated features (either data or label features) from a row. + * + * @param index the index of the requested data item + * @param selected the features to pull from the row + * @return the direct features + */ + public List getRowDirect(long index, List selected) { + List results = new ArrayList<>(selected.size()); + for (Feature feature : selected) { + results.add(getCell(index, feature.getName())); + } + return results; + } + /** * Returns the designated features (either data or label features) from a row. * diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularTranslator.java b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularTranslator.java index 49d740afdcf..d1796377948 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularTranslator.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/tabular/TabularTranslator.java @@ -62,7 +62,7 @@ public TabularTranslator(Model model, Map arguments) { @Override public TabularResults processOutput(TranslatorContext ctx, NDList list) throws Exception { List results = new ArrayList<>(labels.size()); - float[] data = list.singletonOrThrow().toFloatArray(); + float[] data = list.head().toFloatArray(); int dataIndex = 0; for (Feature label : labels) { Featurizer featurizer = label.getFeaturizer(); diff --git a/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java b/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java index 77e3fae0ca3..06cf66b2a6f 100644 --- a/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java +++ b/examples/src/main/java/ai/djl/examples/training/TrainAirfoilWithTabNet.java @@ -14,9 +14,13 @@ import ai.djl.Model; import ai.djl.basicdataset.tabular.AirfoilRandomAccess; +import ai.djl.basicdataset.tabular.ListFeatures; +import ai.djl.basicdataset.tabular.TabularDataset; +import ai.djl.basicdataset.tabular.TabularResults; import ai.djl.basicmodelzoo.tabular.TabNet; import ai.djl.engine.Engine; import ai.djl.examples.training.util.Arguments; +import ai.djl.inference.Predictor; import ai.djl.metric.Metrics; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; @@ -31,6 +35,7 @@ import ai.djl.training.loss.TabNetRegressionLoss; import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; import java.io.IOException; @@ -54,9 +59,10 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans model.setBlock(tabNet); // get the training and validation dataset - RandomAccessDataset[] randomAccessDatasets = getDataset(arguments); - RandomAccessDataset trainingSet = randomAccessDatasets[0]; - RandomAccessDataset validateSet = randomAccessDatasets[1]; + TabularDataset dataset = getDataset(arguments); + RandomAccessDataset[] split = dataset.randomSplit(8, 2); + RandomAccessDataset trainingSet = split[0]; + RandomAccessDataset validateSet = split[1]; // setup training configuration DefaultTrainingConfig config = setupTrainingConfig(arguments); @@ -71,6 +77,16 @@ public static TrainingResult runExample(String[] args) throws IOException, Trans EasyTrain.fit(trainer, arguments.getEpoch(), trainingSet, validateSet); + Translator translator = + dataset.matchingTranslatorOptions() + .option(ListFeatures.class, TabularResults.class); + try (Predictor predictor = + model.newPredictor(translator)) { + ListFeatures input = + new ListFeatures(dataset.getRowDirect(3, dataset.getFeatures())); + predictor.predict(input); + } + return trainer.getTrainingResult(); } } @@ -92,7 +108,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) { .addTrainingListeners(listener); } - private static RandomAccessDataset[] getDataset(Arguments arguments) + private static TabularDataset getDataset(Arguments arguments) throws IOException, TranslateException { AirfoilRandomAccess.Builder airfoilBuilder = AirfoilRandomAccess.builder(); @@ -106,6 +122,6 @@ private static RandomAccessDataset[] getDataset(Arguments arguments) AirfoilRandomAccess airfoilRandomAccess = airfoilBuilder.build(); airfoilRandomAccess.prepare(new ProgressBar()); // split the dataset into - return airfoilRandomAccess.randomSplit(8, 2); + return airfoilRandomAccess; } } diff --git a/model-zoo/src/main/java/ai/djl/basicmodelzoo/tabular/TabNet.java b/model-zoo/src/main/java/ai/djl/basicmodelzoo/tabular/TabNet.java index be2f8d00706..3ebf404a4c4 100644 --- a/model-zoo/src/main/java/ai/djl/basicmodelzoo/tabular/TabNet.java +++ b/model-zoo/src/main/java/ai/djl/basicmodelzoo/tabular/TabNet.java @@ -444,8 +444,7 @@ protected NDList forwardInternal( NDArray sparseLoss = mask.singletonOrThrow() .mul(-1) - .mul(NDArrays.add(mask.singletonOrThrow(), 1e-10).log()) - .mean(); + .mul(NDArrays.add(mask.singletonOrThrow(), 1e-10).log()); NDList x1 = featureTransformer.forward(parameterStore, new NDList(x), training); return new NDList(x1.singletonOrThrow(), sparseLoss); }