diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 85897412f9a6..16817bf5ad1c 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -402,6 +402,7 @@ XGB_EXTERN_C typedef int XGBCallbackDataIterNext( // NOLINT(*) * \param data_handle The handle to the data. * \param callback The callback to get the data. * \param cache_info Additional information about cache file, can be null. + * \param missing Which value to represent missing value. * \param out The created DMatrix * \return 0 when success, -1 when failure happens. */ @@ -409,6 +410,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter( DataIterHandle data_handle, XGBCallbackDataIterNext* callback, const char* cache_info, + float missing, DMatrixHandle *out); /** diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index 2e7540bd2b30..f3fca1b4d28b 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2023 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,21 +29,27 @@ public class DMatrix { protected long handle = 0; /** - * sparse matrix type (CSR or CSC) + * Create DMatrix from iterator. + * + * @param iter The data iterator of mini batch to provide the data. + * @param cacheInfo Cache path information, used for external memory setting, can be null. + * @throws XGBoostError */ - public static enum SparseType { - CSR, - CSC; + public DMatrix(Iterator iter, String cacheInfo) throws XGBoostError { + this(iter, cacheInfo, Float.NaN); } /** * Create DMatrix from iterator. * - * @param iter The data iterator of mini batch to provide the data. + * @param iter The data iterator of mini batch to provide the data. * @param cacheInfo Cache path information, used for external memory setting, can be null. + * @param missing the missing value * @throws XGBoostError */ - public DMatrix(Iterator iter, String cacheInfo) throws XGBoostError { + public DMatrix(Iterator iter, + String cacheInfo, + float missing) throws XGBoostError { if (iter == null) { throw new NullPointerException("iter: null"); } @@ -51,7 +57,8 @@ public DMatrix(Iterator iter, String cacheInfo) throws XGBoostErro int batchSize = 32 << 10; Iterator batchIter = new DataBatch.BatchIterator(iter, batchSize); long[] out = new long[1]; - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter( + batchIter, cacheInfo, missing, out)); handle = out[0]; } @@ -72,10 +79,11 @@ public DMatrix(String dataPath) throws XGBoostError { /** * Create DMatrix from Sparse matrix in CSR/CSC format. + * * @param headers The row index of the matrix. * @param indices The indices of presenting entries. - * @param data The data content. - * @param st Type of sparsity. + * @param data The data content. + * @param st Type of sparsity. * @throws XGBoostError */ @Deprecated @@ -86,12 +94,13 @@ public DMatrix(long[] headers, int[] indices, float[] data, /** * Create DMatrix from Sparse matrix in CSR/CSC format. - * @param headers The row index of the matrix. - * @param indices The indices of presenting entries. - * @param data The data content. - * @param st Type of sparsity. - * @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as - * row number + * + * @param headers The row index of the matrix. + * @param indices The indices of presenting entries. + * @param data The data content. + * @param st Type of sparsity. + * @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as + * row number * @throws XGBoostError */ public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, @@ -121,7 +130,6 @@ public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType s * @param nrow number of rows * @param ncol number of columns * @throws XGBoostError native error - * * @deprecated Please specify the missing value explicitly using * {@link DMatrix(float[], int, int, float)} */ @@ -144,9 +152,10 @@ public DMatrix(BigDenseMatrix matrix) throws XGBoostError { /** * create DMatrix from dense matrix - * @param data data values - * @param nrow number of rows - * @param ncol number of columns + * + * @param data data values + * @param nrow number of rows + * @param ncol number of columns * @param missing the specified value to represent the missing value */ public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError { @@ -157,13 +166,14 @@ public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostEr /** * create DMatrix from dense matrix - * @param matrix instance of BigDenseMatrix + * + * @param matrix instance of BigDenseMatrix * @param missing the specified value to represent the missing value */ public DMatrix(BigDenseMatrix matrix, float missing) throws XGBoostError { long[] out = new long[1]; XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMatRef(matrix.address, matrix.nrow, - matrix.ncol, missing, out)); + matrix.ncol, missing, out)); handle = out[0]; } @@ -176,10 +186,11 @@ protected DMatrix(long handle) { /** * Create the normal DMatrix from column array interface - * @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface + * + * @param columnBatch the XGBoost ColumnBatch to provide the array interface * of feature columns - * @param missing missing value - * @param nthread threads number + * @param missing missing value + * @param nthread threads number * @throws XGBoostError */ public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError { @@ -194,36 +205,30 @@ public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoo } /** - * Set label of DMatrix from cuda array interface - * - * @param column the XGBoost Column to provide the cuda array interface - * of label column - * @throws XGBoostError native error + * flatten a mat to array */ - public void setLabel(Column column) throws XGBoostError { - setXGBDMatrixInfo("label", column.getArrayInterfaceJson()); - } + private static float[] flatten(float[][] mat) { + int size = 0; + for (float[] array : mat) size += array.length; + float[] result = new float[size]; + int pos = 0; + for (float[] ar : mat) { + System.arraycopy(ar, 0, result, pos, ar.length); + pos += ar.length; + } - /** - * Set weight of DMatrix from cuda array interface - * - * @param column the XGBoost Column to provide the cuda array interface - * of weight column - * @throws XGBoostError native error - */ - public void setWeight(Column column) throws XGBoostError { - setXGBDMatrixInfo("weight", column.getArrayInterfaceJson()); + return result; } /** - * Set base margin of DMatrix from cuda array interface + * Set query id of DMatrix from array interface * - * @param column the XGBoost Column to provide the cuda array interface - * of base margin column + * @param column the XGBoost Column to provide the array interface + * of query id column * @throws XGBoostError native error */ - public void setBaseMargin(Column column) throws XGBoostError { - setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson()); + public void setQueryId(Column column) throws XGBoostError { + setXGBDMatrixInfo("qid", column.getArrayInterfaceJson()); } private void setXGBDMatrixInfo(String type, String json) throws XGBoostError { @@ -257,17 +262,9 @@ private String[] getXGBDMatrixFeatureInfo(String type) throws XGBoostError { return outValue[0]; } - /** - * Set feature names - * @param values feature names to be set - * @throws XGBoostError - */ - public void setFeatureNames(String[] values) throws XGBoostError { - setXGBDMatrixFeatureInfo("feature_name", values); - } - /** * Get feature names + * * @return an array of feature names to be returned * @throws XGBoostError */ @@ -276,16 +273,18 @@ public String[] getFeatureNames() throws XGBoostError { } /** - * Set feature types - * @param values feature types to be set + * Set feature names + * + * @param values feature names to be set * @throws XGBoostError */ - public void setFeatureTypes(String[] values) throws XGBoostError { - setXGBDMatrixFeatureInfo("feature_type", values); + public void setFeatureNames(String[] values) throws XGBoostError { + setXGBDMatrixFeatureInfo("feature_name", values); } /** * Get feature types + * * @return an array of feature types to be returned * @throws XGBoostError */ @@ -294,46 +293,23 @@ public String[] getFeatureTypes() throws XGBoostError { } /** - * set label of dmatrix + * Set feature types * - * @param labels labels - * @throws XGBoostError native error + * @param values feature types to be set + * @throws XGBoostError */ - public void setLabel(float[] labels) throws XGBoostError { - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels)); + public void setFeatureTypes(String[] values) throws XGBoostError { + setXGBDMatrixFeatureInfo("feature_type", values); } /** - * set weight of each instance + * Get group sizes of DMatrix * - * @param weights weights + * @return group size as array * @throws XGBoostError native error */ - public void setWeight(float[] weights) throws XGBoostError { - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights)); - } - - /** - * Set base margin (initial prediction). - * - * The margin must have the same number of elements as the number of - * rows in this matrix. - */ - public void setBaseMargin(float[] baseMargin) throws XGBoostError { - if (baseMargin.length != rowNum()) { - throw new IllegalArgumentException(String.format( - "base margin must have exactly %s elements, got %s", - rowNum(), baseMargin.length)); - } - - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin)); - } - - /** - * Set base margin (initial prediction). - */ - public void setBaseMargin(float[][] baseMargin) throws XGBoostError { - setBaseMargin(flatten(baseMargin)); + public int[] getGroup() throws XGBoostError { + return getIntInfo("group_ptr"); } /** @@ -347,13 +323,13 @@ public void setGroup(int[] group) throws XGBoostError { } /** - * Get group sizes of DMatrix + * Set query ids (used for ranking) * + * @param qid the query ids * @throws XGBoostError native error - * @return group size as array */ - public int[] getGroup() throws XGBoostError { - return getIntInfo("group_ptr"); + public void setQueryId(int[] qid) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetUIntInfo(handle, "qid", qid)); } private float[] getFloatInfo(String field) throws XGBoostError { @@ -378,6 +354,27 @@ public float[] getLabel() throws XGBoostError { return getFloatInfo("label"); } + /** + * Set label of DMatrix from array interface + * + * @param column the XGBoost Column to provide the array interface + * of label column + * @throws XGBoostError native error + */ + public void setLabel(Column column) throws XGBoostError { + setXGBDMatrixInfo("label", column.getArrayInterfaceJson()); + } + + /** + * set label of dmatrix + * + * @param labels labels + * @throws XGBoostError native error + */ + public void setLabel(float[] labels) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels)); + } + /** * get weight of the DMatrix * @@ -388,6 +385,27 @@ public float[] getWeight() throws XGBoostError { return getFloatInfo("weight"); } + /** + * Set weight of DMatrix from array interface + * + * @param column the XGBoost Column to provide the array interface + * of weight column + * @throws XGBoostError native error + */ + public void setWeight(Column column) throws XGBoostError { + setXGBDMatrixInfo("weight", column.getArrayInterfaceJson()); + } + + /** + * set weight of each instance + * + * @param weights weights + * @throws XGBoostError native error + */ + public void setWeight(float[] weights) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights)); + } + /** * Get base margin of the DMatrix. */ @@ -395,6 +413,40 @@ public float[] getBaseMargin() throws XGBoostError { return getFloatInfo("base_margin"); } + /** + * Set base margin of DMatrix from array interface + * + * @param column the XGBoost Column to provide the array interface + * of base margin column + * @throws XGBoostError native error + */ + public void setBaseMargin(Column column) throws XGBoostError { + setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson()); + } + + /** + * Set base margin (initial prediction). + *

+ * The margin must have the same number of elements as the number of + * rows in this matrix. + */ + public void setBaseMargin(float[] baseMargin) throws XGBoostError { + if (baseMargin.length != rowNum()) { + throw new IllegalArgumentException(String.format( + "base margin must have exactly %s elements, got %s", + rowNum(), baseMargin.length)); + } + + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin)); + } + + /** + * Set base margin (initial prediction). + */ + public void setBaseMargin(float[][] baseMargin) throws XGBoostError { + setBaseMargin(flatten(baseMargin)); + } + /** * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. * @@ -448,22 +500,6 @@ public long getHandle() { return handle; } - /** - * flatten a mat to array - */ - private static float[] flatten(float[][] mat) { - int size = 0; - for (float[] array : mat) size += array.length; - float[] result = new float[size]; - int pos = 0; - for (float[] ar : mat) { - System.arraycopy(ar, 0, result, pos, ar.length); - pos += ar.length; - } - - return result; - } - @Override protected void finalize() { dispose(); @@ -475,4 +511,12 @@ public synchronized void dispose() { handle = 0; } } + + /** + * sparse matrix type (CSR or CSC) + */ + public enum SparseType { + CSR, + CSC + } } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index b410d2be1d02..00413636e0f0 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -54,7 +54,7 @@ static void checkCall(int ret) throws XGBoostError { public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); final static native int XGDMatrixCreateFromDataIter(java.util.Iterator iter, - String cache_info, long[] out); + String cache_info, float missing, long[] out); public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, int shapeParam, diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala index 714adf726292..294107f082fa 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2023 by Contributors + Copyright (c) 2014-2024 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ package ml.dmlc.xgboost4j.scala import _root_.scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.LabeledPoint -import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DataBatch, XGBoostError, DMatrix => JDMatrix} +import ml.dmlc.xgboost4j.java.{Column, ColumnBatch, DMatrix => JDMatrix, XGBoostError} class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** @@ -33,14 +33,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { } /** - * init DMatrix from Iterator of LabeledPoint - * - * @param dataIter An iterator of LabeledPoint - * @param cacheInfo Cache path information, used for external memory setting, null by default. - * @throws XGBoostError native error - */ - def this(dataIter: Iterator[LabeledPoint], cacheInfo: String = null) { - this(new JDMatrix(dataIter.asJava, cacheInfo)) + * init DMatrix from Iterator of LabeledPoint + * + * @param dataIter An iterator of LabeledPoint + * @param cacheInfo Cache path information, used for external memory setting, null by default. + * @param missing Which value will be treated as the missing value + * @throws XGBoostError native error + */ + def this(dataIter: Iterator[LabeledPoint], + cacheInfo: String = null, + missing: Float = Float.NaN) { + this(new JDMatrix(dataIter.asJava, cacheInfo, missing)) } /** @@ -60,12 +63,12 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * create DMatrix from sparse matrix * - * @param headers index to headers (rowHeaders for CSR or colHeaders for CSC) - * @param indices Indices (colIndexs for CSR or rowIndexs for CSC) - * @param data non zero values (sequence by row for CSR or by col for CSC) - * @param st sparse matrix type (CSR or CSC) + * @param headers index to headers (rowHeaders for CSR or colHeaders for CSC) + * @param indices Indices (colIndexs for CSR or rowIndexs for CSC) + * @param data non zero values (sequence by row for CSR or by col for CSC) + * @param st sparse matrix type (CSR or CSC) * @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as - * row number + * row number */ @throws(classOf[XGBoostError]) def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType, @@ -76,14 +79,14 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * create DMatrix from sparse matrix * - * @param headers index to headers (rowHeaders for CSR or colHeaders for CSC) - * @param indices Indices (colIndexs for CSR or rowIndexs for CSC) - * @param data non zero values (sequence by row for CSR or by col for CSC) - * @param st sparse matrix type (CSR or CSC) + * @param headers index to headers (rowHeaders for CSR or colHeaders for CSC) + * @param indices Indices (colIndexs for CSR or rowIndexs for CSC) + * @param data non zero values (sequence by row for CSR or by col for CSC) + * @param st sparse matrix type (CSR or CSC) * @param shapeParam when st is CSR, it specifies the column number, otherwise it is taken as - * row number - * @param missing missing value - * @param nthread The number of threads used for constructing DMatrix + * row number + * @param missing missing value + * @param nthread The number of threads used for constructing DMatrix */ @throws(classOf[XGBoostError]) def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType, @@ -93,10 +96,11 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * Create the normal DMatrix from column array interface + * * @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface * of feature columns - * @param missing missing value - * @param nthread The number of threads used for constructing DMatrix + * @param missing missing value + * @param nthread The number of threads used for constructing DMatrix */ @throws(classOf[XGBoostError]) def this(columnBatch: ColumnBatch, missing: Float, nthread: Int) { @@ -119,9 +123,9 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * create DMatrix from dense matrix * - * @param data data values - * @param nrow number of rows - * @param ncol number of columns + * @param data data values + * @param nrow number of rows + * @param ncol number of columns * @param missing the specified value to represent the missing value */ @throws(classOf[XGBoostError]) @@ -181,6 +185,16 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { jDMatrix.setGroup(group) } + /** + * Set query ids (used for ranking) + * + * @param qid query ids + */ + @throws(classOf[XGBoostError]) + def setQueryId(qid: Array[Int]): Unit = { + jDMatrix.setQueryId(qid) + } + /** * Set label of DMatrix from cuda array interface */ @@ -205,8 +219,17 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { jDMatrix.setBaseMargin(column) } + /** + * set query id of dmatrix from column array interface + */ + @throws(classOf[XGBoostError]) + def setQueryId(column: Column): Unit = { + jDMatrix.setQueryId(column) + } + /** * set feature names + * * @param values feature names * @throws ml.dmlc.xgboost4j.java.XGBoostError */ @@ -217,6 +240,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * set feature types + * * @param values feature types * @throws ml.dmlc.xgboost4j.java.XGBoostError */ @@ -265,6 +289,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * get feature names + * * @throws ml.dmlc.xgboost4j.java.XGBoostError * @return */ @@ -275,6 +300,7 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * get feature types + * * @throws ml.dmlc.xgboost4j.java.XGBoostError * @return */ diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index cfab645ed6bf..d8f169157e3a 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -214,7 +214,7 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter - (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) { + (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jfloat jmissing, jlongArray jout) { DMatrixHandle result; std::unique_ptr> cache_info; if (jcache_info != nullptr) { @@ -222,8 +222,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro jenv->ReleaseStringUTFChars(jcache_info, ptr); }}; } + auto missing = static_cast(jmissing); int ret = - XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), &result); + XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), + missing,&result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); return ret; diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index c8e48cfc9de9..f8657b5a61a1 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -26,10 +26,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixCreateFromDataIter - * Signature: (Ljava/util/Iterator;Ljava/lang/String;[J)I + * Signature: (Ljava/util/Iterator;Ljava/lang/String;F[J)I */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter - (JNIEnv *, jclass, jobject, jstring, jlongArray); + (JNIEnv *, jclass, jobject, jstring, jfloat, jlongArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index b6ffe84e30e9..0bc6f7b73f17 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -15,15 +15,18 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.*; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.Random; import junit.framework.TestCase; -import ml.dmlc.xgboost4j.java.util.BigDenseMatrix; import ml.dmlc.xgboost4j.LabeledPoint; +import ml.dmlc.xgboost4j.java.util.BigDenseMatrix; import org.junit.Test; import static org.junit.Assert.assertArrayEquals; @@ -36,6 +39,32 @@ */ public class DMatrixTest { + + @Test + public void testCreateFromDataIteratorWithMissingValue() throws XGBoostError { + //create DMatrix from DataIterator + java.util.List blist = new java.util.LinkedList<>(); + blist.add(new LabeledPoint(0.1f, 4, null, new float[]{1, 0, 0, 0})); + blist.add(new LabeledPoint(0.1f, 4, null, new float[]{Float.NaN, 13, 14, 15})); + blist.add(new LabeledPoint(0.1f, 4, null, new float[]{21, 23, 0, 25})); + + // Default missing value: Float.NaN + DMatrix dmat = new DMatrix(blist.iterator(), null); + assert dmat.nonMissingNum() == 11; + + // missing value 0 + dmat = new DMatrix(blist.iterator(), null, 0.0f); + assert dmat.nonMissingNum() == 12 - 4 - 1; + + // missing value 21 + dmat = new DMatrix(blist.iterator(), null, 21.0f); + assert dmat.nonMissingNum() == 12 - 1 - 1; + + // missing value 101010101010 + dmat = new DMatrix(blist.iterator(), null, 101010101010.0f); + assert dmat.nonMissingNum() == 12 - 1; + } + @Test public void testCreateFromDataIterator() throws XGBoostError { //create DMatrix from DataIterator @@ -45,7 +74,7 @@ public void testCreateFromDataIterator() throws XGBoostError { java.util.List blist = new java.util.LinkedList(); for (int i = 0; i < nrep; ++i) { LabeledPoint p = new LabeledPoint( - 0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5}); + 0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5}); blist.add(p); labelall.add(p.label()); } @@ -290,7 +319,7 @@ public void testCreateFromDenseMatrixRef() throws XGBoostError { } finally { if (dmat0 != null) { dmat0.dispose(); - } else if (data0 != null){ + } else if (data0 != null) { data0.dispose(); } } @@ -309,9 +338,9 @@ public void testTrainWithDenseMatrixRef() throws XGBoostError { // (3,1) -> 2 // (2,3) -> 3 float[][] data = new float[][]{ - new float[]{4f, 5f}, - new float[]{3f, 1f}, - new float[]{2f, 3f} + new float[]{4f, 5f}, + new float[]{3f, 1f}, + new float[]{2f, 3f} }; data0 = new BigDenseMatrix(3, 2); for (int i = 0; i < data0.nrow; i++) @@ -428,4 +457,40 @@ public void testSetAndGetFeatureInfo() throws XGBoostError { String[] retFeatureTypes = dmat.getFeatureTypes(); assertArrayEquals(featureTypes, retFeatureTypes); } + + @Test + public void testSetAndGetQueryId() throws XGBoostError { + //create DMatrix from 10*5 dense matrix + int nrow = 10; + int ncol = 5; + float[] data0 = new float[nrow * ncol]; + //put random nums + Random random = new Random(); + for (int i = 0; i < nrow * ncol; i++) { + data0[i] = random.nextFloat(); + } + + //create label + float[] label0 = new float[nrow]; + for (int i = 0; i < nrow; i++) { + label0[i] = random.nextFloat(); + } + + //create two groups + int[] qid = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + int[] qidExpected = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + + DMatrix dmat0 = new DMatrix(data0, nrow, ncol, -0.1f); + dmat0.setLabel(label0); + dmat0.setQueryId(qid); + //check + TestCase.assertTrue(Arrays.equals(qidExpected, dmat0.getGroup())); + + //create two groups + int[] qid1 = new int[]{10, 10, 10, 20, 60, 60, 80, 80, 90, 100}; + int[] qidExpected1 = new int[]{0, 3, 4, 6, 8, 9, 10}; + dmat0.setQueryId(qid1); + TestCase.assertTrue(Arrays.equals(qidExpected1, dmat0.getGroup())); + + } } diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3559660dd1a5..7371188650bd 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -253,7 +253,9 @@ XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) { XGB_DLL int XGDMatrixCreateFromDataIter( void *data_handle, // a Java iterator XGBCallbackDataIterNext *callback, // C++ callback defined in xgboost4j.cpp - const char *cache_info, DMatrixHandle *out) { + const char *cache_info, + float missing, + DMatrixHandle *out) { API_BEGIN(); std::string scache; @@ -264,10 +266,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter( data_handle, callback); xgboost_CHECK_C_ARG_PTR(out); *out = new std::shared_ptr { - DMatrix::Create( - &adapter, std::numeric_limits::quiet_NaN(), - 1, scache - ) + DMatrix::Create(&adapter, missing, 1, scache) }; API_END(); }