Fix tabnet predictor #2643

Merged: 1 commit, Jun 9, 2023.

This PR fixes inference with the TabNet model. The TabNet block now returns its sparsity loss with the batch axis intact instead of pre-reduced to a scalar, the TabNet losses apply the matching mean() reduction, TabularTranslator reads the first element of the output list instead of requiring a singleton, TabularDataset gains a getRowDirect helper for fetching raw rows, and the training example runs a prediction after training to exercise the inference path.
TabNetClassificationLoss.java

@@ -42,6 +42,6 @@ public TabNetClassificationLoss(String name) {
     public NDArray evaluate(NDList labels, NDList predictions) {
         return Loss.softmaxCrossEntropyLoss()
                 .evaluate(labels, new NDList(predictions.get(0)))
-                .add(predictions.get(1));
+                .add(predictions.get(1).mean());
     }
 }
TabNetRegressionLoss.java

@@ -46,6 +46,6 @@ public NDArray evaluate(NDList labels, NDList predictions) {
                 .sub(predictions.get(0))
                 .square()
                 .mean()
-                .add(predictions.get(1));
+                .add(predictions.get(1).mean());
     }
 }
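Both loss changes compensate for the TabNet block change at the end of this diff: the block's sparsity term (predictions.get(1)) is no longer pre-reduced to a scalar, so each loss reduces it with mean() before adding it to the task loss, keeping the final loss value a scalar. Below is a minimal sketch of the resulting contract; the shapes and the no-argument TabNetRegressionLoss constructor are illustrative assumptions, not part of this PR:

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.training.loss.TabNetRegressionLoss;

    public class TabNetLossDemo {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray labels = manager.ones(new Shape(4, 1));
                NDArray output = manager.randomUniform(0f, 1f, new Shape(4, 1));
                // Illustrative: the sparse term now keeps its full shape,
                // e.g. (batchSize, numFeatures), instead of being a scalar.
                NDArray sparseTerm = manager.randomUniform(0f, 1f, new Shape(4, 5));

                NDArray loss =
                        new TabNetRegressionLoss()
                                .evaluate(new NDList(labels), new NDList(output, sparseTerm));
                // The mean() inside the loss reduces the sparse term,
                // so the total loss is still a scalar.
                System.out.println(loss.getShape());
            }
        }
    }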
TabularDataset.java

@@ -105,6 +105,21 @@ public Record get(NDManager manager, long index) {
         return new Record(data, label);
     }
 
+    /**
+     * Returns the direct designated features (either data or label features) from a row.
+     *
+     * @param index the index of the requested data item
+     * @param selected the features to pull from the row
+     * @return the direct features
+     */
+    public List<String> getRowDirect(long index, List<Feature> selected) {
+        List<String> results = new ArrayList<>(selected.size());
+        for (Feature feature : selected) {
+            results.add(getCell(index, feature.getName()));
+        }
+        return results;
+    }
+
     /**
      * Returns the designated features (either data or label features) from a row.
      *
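getRowDirect complements get(NDManager, long): instead of featurized NDArrays it returns the raw string cells for the selected features, which is the form a translator-based predictor expects as input. A hedged usage sketch; the addAllFeatures() builder helper and the sampling values are assumptions for illustration:

    import ai.djl.basicdataset.tabular.AirfoilRandomAccess;
    import ai.djl.training.util.ProgressBar;
    import ai.djl.translate.TranslateException;

    import java.io.IOException;
    import java.util.List;

    public class RowDirectDemo {
        public static void main(String[] args) throws IOException, TranslateException {
            AirfoilRandomAccess dataset =
                    AirfoilRandomAccess.builder().addAllFeatures().setSampling(32, true).build();
            dataset.prepare(new ProgressBar());
            // Raw string cells of row 0, in feature order; no featurization applied.
            List<String> rawRow = dataset.getRowDirect(0, dataset.getFeatures());
            System.out.println(rawRow);
        }
    }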
TabularTranslator.java

@@ -62,7 +62,7 @@ public TabularTranslator(Model model, Map<String, ?> arguments) {
     @Override
     public TabularResults processOutput(TranslatorContext ctx, NDList list) throws Exception {
         List<TabularResult> results = new ArrayList<>(labels.size());
-        float[] data = list.singletonOrThrow().toFloatArray();
+        float[] data = list.head().toFloatArray();
         int dataIndex = 0;
         for (Feature label : labels) {
             Featurizer featurizer = label.getFeaturizer();
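This one-word change matters because TabNet's forward output is an NDList of two arrays, the prediction and the sparsity term, so singletonOrThrow(), which requires exactly one element, fails at inference time, while head() simply takes the first element. A minimal illustration (shapes are illustrative):

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;

    public class HeadVsSingletonDemo {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                NDArray prediction = manager.ones(new Shape(1, 1));
                NDArray sparseTerm = manager.ones(new Shape(1, 5));
                NDList output = new NDList(prediction, sparseTerm);

                NDArray first = output.head(); // fine: returns the prediction
                // output.singletonOrThrow();  // would throw: the list has two elements
                System.out.println(first.getShape());
            }
        }
    }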
TabNet training example

@@ -14,9 +14,13 @@
 import ai.djl.Model;
 import ai.djl.basicdataset.tabular.AirfoilRandomAccess;
+import ai.djl.basicdataset.tabular.ListFeatures;
+import ai.djl.basicdataset.tabular.TabularDataset;
+import ai.djl.basicdataset.tabular.TabularResults;
 import ai.djl.basicmodelzoo.tabular.TabNet;
 import ai.djl.engine.Engine;
 import ai.djl.examples.training.util.Arguments;
+import ai.djl.inference.Predictor;
 import ai.djl.metric.Metrics;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.nn.Block;

@@ -31,6 +35,7 @@
 import ai.djl.training.loss.TabNetRegressionLoss;
 import ai.djl.training.util.ProgressBar;
 import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
 
 import java.io.IOException;

@@ -54,9 +59,10 @@ public static TrainingResult runExample(String[] args) throws IOException, TranslateException {
         model.setBlock(tabNet);
 
         // get the training and validation dataset
-        RandomAccessDataset[] randomAccessDatasets = getDataset(arguments);
-        RandomAccessDataset trainingSet = randomAccessDatasets[0];
-        RandomAccessDataset validateSet = randomAccessDatasets[1];
+        TabularDataset dataset = getDataset(arguments);
+        RandomAccessDataset[] split = dataset.randomSplit(8, 2);
+        RandomAccessDataset trainingSet = split[0];
+        RandomAccessDataset validateSet = split[1];
 
         // setup training configuration
         DefaultTrainingConfig config = setupTrainingConfig(arguments);

@@ -71,6 +77,16 @@
         EasyTrain.fit(trainer, arguments.getEpoch(), trainingSet, validateSet);
 
+        Translator<ListFeatures, TabularResults> translator =
+                dataset.matchingTranslatorOptions()
+                        .option(ListFeatures.class, TabularResults.class);
+        try (Predictor<ListFeatures, TabularResults> predictor =
+                model.newPredictor(translator)) {
+            ListFeatures input =
+                    new ListFeatures(dataset.getRowDirect(3, dataset.getFeatures()));
+            predictor.predict(input);
+        }
+
         return trainer.getTrainingResult();
     }
 }

@@ -92,7 +108,7 @@ private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
                 .addTrainingListeners(listener);
     }
 
-    private static RandomAccessDataset[] getDataset(Arguments arguments)
+    private static TabularDataset getDataset(Arguments arguments)
             throws IOException, TranslateException {
         AirfoilRandomAccess.Builder airfoilBuilder = AirfoilRandomAccess.builder();

@@ -106,6 +122,6 @@ private static RandomAccessDataset[] getDataset(Arguments arguments)
         AirfoilRandomAccess airfoilRandomAccess = airfoilBuilder.build();
         airfoilRandomAccess.prepare(new ProgressBar());
         // split the dataset into
-        return airfoilRandomAccess.randomSplit(8, 2);
+        return airfoilRandomAccess;
     }
 }
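Keeping the whole TabularDataset around, rather than only the split halves, is what enables the post-training smoke test: matchingTranslatorOptions() yields a translator that featurizes raw string rows the same way the training pipeline did, so a row fetched with getRowDirect can be fed straight to the predictor. As a follow-on sketch, batch prediction should now work the same way inside the try block above; Predictor.batchPredict is standard DJL API, the row indices are illustrative, and the java.util imports are elided:

    List<ListFeatures> batch = new ArrayList<>();
    batch.add(new ListFeatures(dataset.getRowDirect(0, dataset.getFeatures())));
    batch.add(new ListFeatures(dataset.getRowDirect(1, dataset.getFeatures())));
    List<TabularResults> results = predictor.batchPredict(batch);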
TabNet.java

@@ -444,8 +444,7 @@ protected NDList forwardInternal(
         NDArray sparseLoss =
                 mask.singletonOrThrow()
                         .mul(-1)
-                        .mul(NDArrays.add(mask.singletonOrThrow(), 1e-10).log())
-                        .mean();
+                        .mul(NDArrays.add(mask.singletonOrThrow(), 1e-10).log());
         NDList x1 = featureTransformer.forward(parameterStore, new NDList(x), training);
         return new NDList(x1.singletonOrThrow(), sparseLoss);
     }
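Removing mean() here is the core of the fix: the sparsity term, -mask * log(mask + 1e-10), now keeps its batch axis in the output NDList (a plausible motivation being that a 0-dimensional scalar cannot be split per example when the predictor un-batchifies outputs), while the loss classes above restore the scalar reduction during training. A small sketch of the term's shape behavior, with illustrative shapes:

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;

    public class SparsityTermDemo {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                // mask entries lie in (0, 1); shape (batchSize, numFeatures) is illustrative
                NDArray mask = manager.randomUniform(0f, 1f, new Shape(4, 5));
                // Entropy-style sparsity penalty: -mask * log(mask + 1e-10)
                NDArray sparse = mask.mul(-1).mul(mask.add(1e-10).log());
                System.out.println(sparse.getShape()); // (4, 5): the batch axis is preserved
            }
        }
    }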