Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[breaking] [jvm-packages] Remove rabit check point. #9599

Merged
merged 7 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demo/guide-python/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def check(as_pickle):
# Use callback class from xgboost.callback
# Feel free to subclass/customize it to suit your need.
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, name="model"
directory=tmpdir, interval=rounds, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
Expand All @@ -118,7 +118,7 @@ def check(as_pickle):
# This version of checkpoint saves everything including parameters and
# model. See: doc/tutorials/saving_model.rst
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
directory=tmpdir, interval=rounds, as_pickle=True, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
Expand Down
18 changes: 0 additions & 18 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1308,24 +1308,6 @@ XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, bst_ulong *out_len,
XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
const void *buf, bst_ulong len);

/*!
* \brief Initialize the booster from rabit checkpoint.
* This is used in distributed training API.
* \param handle handle
* \param version The output version of the model.
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version);

/*!
* \brief Save the current checkpoint to rabit.
* \param handle handle
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle);


/*!
* \brief Save XGBoost's internal configuration into a JSON document. Currently the
* support is experimental, function signature may change in the future without
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,57 +32,53 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
}

private def createNewModels():
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
(String, XGBoostClassificationModel, XGBoostClassificationModel) = {
val tmpPath = createTmpFolder("test").toAbsolutePath.toString
val (model4, model8) = {
val (model2, model4) = {
val training = buildDataFrame(Classification.train)
val paramMap = produceParamMap(tmpPath, 2)
(new XGBoostClassifier(paramMap ++ Seq("num_round" -> 2)).fit(training),
new XGBoostClassifier(paramMap ++ Seq("num_round" -> 4)).fit(training))
}
(tmpPath, model4, model8)
(tmpPath, model2, model4)
}

test("test update/load models") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))

manager.updateCheckpoint(model4._booster.booster)
manager.updateCheckpoint(model2._booster.booster)
var files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "4.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 4)
assert(files.head.getPath.getName == "1.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 2)

manager.updateCheckpoint(model8._booster)
manager.updateCheckpoint(model4._booster)
files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
assert(manager.loadCheckpointAsScalaBooster().getVersion == 8)
assert(files.head.getPath.getName == "3.model")
assert(manager.loadCheckpointAsScalaBooster().getNumBoostedRound == 4)
}

test("test cleanUpHigherVersions") {
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()

val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
manager.updateCheckpoint(model8._booster)
manager.cleanUpHigherVersions(8)
assert(new File(s"$tmpPath/8.model").exists())
manager.updateCheckpoint(model4._booster)
manager.cleanUpHigherVersions(3)
assert(new File(s"$tmpPath/3.model").exists())

manager.cleanUpHigherVersions(4)
assert(!new File(s"$tmpPath/8.model").exists())
manager.cleanUpHigherVersions(2)
assert(!new File(s"$tmpPath/3.model").exists())
}

test("test checkpoint rounds") {
import scala.collection.JavaConverters._
val (tmpPath, model4, model8) = createNewModels()
val (tmpPath, model2, model4) = createNewModels()
val manager = new ExternalCheckpointManager(tmpPath, FileSystem.get(sc.hadoopConfiguration))
assertResult(Seq(7))(
manager.getCheckpointRounds(0, 7).asScala)
assertResult(Seq(2, 4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
manager.updateCheckpoint(model4._booster)
assertResult(Seq(4, 6, 7))(
manager.getCheckpointRounds(2, 7).asScala)
assertResult(Seq(2))(manager.getCheckpointRounds(0, 0, 3).asScala)
assertResult(Seq(0, 2, 4, 6))(manager.getCheckpointRounds(0, 2, 7).asScala)
assertResult(Seq(0, 2, 4, 6, 7))(manager.getCheckpointRounds(0, 2, 8).asScala)
}


Expand All @@ -109,8 +105,8 @@ class ExternalCheckpointManagerSuite extends AnyFunSuite with TmpFolderPerSuite
// Check only one model is kept after training
val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
assert(files.length == 1)
assert(files.head.getPath.getName == "8.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
assert(files.head.getPath.getName == "4.model")
val tmpModel = SXGBoost.loadModel(s"$tmpPath/4.model")
// Train next model based on prev model
val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
assert(error(tmpModel) >= error(prevModel._booster))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -787,35 +787,6 @@ private Map<String, Double> getFeatureImportanceFromModel(
return importanceMap;
}

/**
* Save the model as byte array representation.
* Write these bytes to a file will give compatible format with other xgboost bindings.
*
* If java natively support HDFS file API, use toByteArray and write the ByteArray
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
modelInfos));
return modelInfos[0];
}

public int getVersion() {
return this.version;
}

public void setVersion(int version) {
this.version = version;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems still need to remove the definition of version

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what would it do to binary serialization.

}

/**
* Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
Expand All @@ -841,29 +812,6 @@ public byte[] toByteArray(String format) throws XGBoostError {
return bytes[0];
}

/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
version = out[0];
return version;
}

/**
* Save the booster model into thread-local rabit checkpoint and increment the version.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
version += 1;
}

/**
* Get number of model features.
* @return the number of features.
Expand All @@ -874,6 +822,11 @@ public long getNumFeature() throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumFeature(this.handle, numFeature));
return numFeature[0];
}
public int getNumBoostedRound() throws XGBoostError {
int[] numRound = new int[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetNumBoostedRound(this.handle, numRound));
return numRound[0];
}

/**
* Internal initialization function.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
Copyright (c) 2014-2023 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;

import java.io.IOException;
Expand All @@ -15,7 +30,7 @@ public class ExternalCheckpointManager {

private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private Path checkpointPath;
private Path checkpointPath; // directory for checkpoints
private FileSystem fs;

public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
Expand All @@ -35,6 +50,7 @@ private List<Integer> getExistingVersions() throws IOException {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
// Get integer versions from a list of checkpoint files.
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
Expand All @@ -44,19 +60,23 @@ private List<Integer> getExistingVersions() throws IOException {
}
}

private Integer latest(List<Integer> versions) {
return versions.stream()
.max(Comparator.comparing(Integer::valueOf)).get();
}

public void cleanPath() throws IOException {
fs.delete(checkpointPath, true);
}

public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
List<Integer> versions = getExistingVersions();
if (versions.size() > 0) {
int latestVersion = versions.stream().max(Comparator.comparing(Integer::valueOf)).get();
int latestVersion = this.latest(versions);
String checkpointPath = getPath(latestVersion);
InputStream in = fs.open(new Path(checkpointPath));
logger.info("loaded checkpoint from " + checkpointPath);
Booster booster = XGBoost.loadModel(in);
booster.setVersion(latestVersion);
return booster;
} else {
return null;
Expand All @@ -65,13 +85,16 @@ public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {

public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XGBoostError {
List<String> prevModelPaths = getExistingVersions().stream()
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
.map(this::getPath).collect(Collectors.toList());
// checkpointing is done after update, so n_rounds - 1 is the current iteration
// accounting for training continuation.
Integer iter = boosterToCheckpoint.getNumBoostedRound() - 1;
String eventualPath = getPath(iter);
String tempPath = eventualPath + "-" + UUID.randomUUID();
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
logger.info("saving checkpoint with version " + iter);
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
Expand All @@ -83,35 +106,34 @@ public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XG
}

public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
getExistingVersions().stream().filter(v -> v > currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
logger.error("failed to clean checkpoint from other training instance", e);
}
});
}

public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
// Get a list of iterations that need checkpointing.
public List<Integer> getCheckpointRounds(
int firstRound, int checkpointInterval, int numOfRounds)
throws IOException {
int end = firstRound + numOfRounds; // exclusive
int lastRound = end - 1;
if (end - 1 < 0) {
throw new IllegalArgumentException("Inavlid `numOfRounds`.");
}

List<Integer> arr = new ArrayList<>();
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
int firstCheckpointRound = prevRounds.stream()
.max(Comparator.comparing(Integer::valueOf)).get() + checkpointInterval;
List<Integer> arr = new ArrayList<>();
for (int i = firstCheckpointRound; i <= numOfRounds; i += checkpointInterval) {
for (int i = firstRound; i < end; i += checkpointInterval) {
arr.add(i);
}
arr.add(numOfRounds);
return arr;
} else if (checkpointInterval <= 0) {
List<Integer> l = new ArrayList<Integer>();
l.add(numOfRounds);
return l;
} else {
throw new IllegalArgumentException("parameters \"checkpoint_path\" should also be set.");
}

if (!arr.contains(lastRound)) {
arr.add(lastRound);
}
return arr;
}
}
Loading
Loading