Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages] Add getNumFeature method #6075

Merged
merged 4 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,17 @@ void saveRabitCheckpoint() throws XGBoostError {
version += 1;
}

/**
* Get number of model features.
* @return the number of features.
* @throws XGBoostError
*/
public long getNumFeature() throws XGBoostError {
long[] numFeature = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}

/**
* Internal initialization function.
* @param cacheMats The cached DMatrix.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ public final static native int XGBoosterDumpModelExWithFeatures(
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
public final static native int XGBoosterGetNumFeature(long handle, long[] feature);

// rabit functions
public final static native int RabitInit(String[] args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
.asScala.mapValues(_.doubleValue).toSeq: _*)
}

/**
* Get the number of model features.
*
* @return number of features
*/
@throws(classOf[XGBoostError])
def getNumFeature: Long = booster.getNumFeature

def getVersion: Int = booster.getVersion

def toByteArray: Array[Byte] = {
Expand Down
16 changes: 16 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveRabit
return XGBoosterSaveRabitCheckpoint(handle);
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterGetNumFeature
* Signature: (J[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong num_feature;
int ret = XGBoosterGetNumFeature(handle, &num_feature);
JVM_CHECK_CALL(ret);
jlong jnum_feature = num_feature;
jenv->SetLongArrayRegion(jout, 0, 1, &jnum_feature);
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: RabitInit
Expand Down
8 changes: 8 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -646,4 +646,18 @@ public void testSetAndGetAttrs() throws XGBoostError {
TestCase.assertEquals(attr.get("bb"), "BB");
TestCase.assertEquals(attr.get("cc"), "CC");
}

/**
* test get number of features from a booster
*
* @throws XGBoostError
*/
@Test
public void testGetNumFeature() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

Booster booster = trainBooster(trainMat, testMat);
TestCase.assertEquals(booster.getNumFeature(), 127);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,12 @@ class ScalaBoosterImplSuite extends FunSuite {
val nextBooster = XGBoost.train(trainMat, paramMap, round = 4, booster = prevBooster)
assert(prevBooster == nextBooster)
}

test("test getting number of features from a booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val booster = trainBooster(trainMat, testMat)

TestCase.assertEquals(booster.getNumFeature, 127)
}
}