Skip to content

Commit

Permalink
Compare with get_dump() instead of save_raw() (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
krfricke authored Aug 5, 2021
1 parent 3914506 commit 532c6ff
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions xgboost_ray/tests/test_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def testCheckpointContinuationValidity(self):
lc1.load_model(last_checkpoint_1)
self.assertEqual(last_checkpoint_1, last_checkpoint_other_rank_1)
self.assertEqual(last_checkpoint_1, lc1.save_raw())
self.assertEqual(bst_1.save_raw(), lc1.save_raw())
self.assertEqual(bst_1.get_dump(), lc1.get_dump())

# Start new training run, starting from existing model
res_2 = {}
Expand Down Expand Up @@ -361,15 +361,15 @@ def testCheckpointContinuationValidity(self):
# Sanity check
self.assertEqual(first_checkpoint_2, first_checkpoint_other_actor_2)
self.assertEqual(last_checkpoint_2, last_checkpoint_other_actor_2)
self.assertEqual(bst_2.save_raw(), lcp_bst.save_raw())
self.assertEqual(bst_2.get_dump(), lcp_bst.get_dump())

# Training should not have proceeded for the first checkpoint,
# so trees should be equal
self.assertEqual(last_checkpoint_1, fcp_bst.save_raw())
self.assertEqual(lc1.get_dump(), fcp_bst.get_dump())

# Training should have proceeded for the last checkpoint,
# so trees should not be equal
self.assertNotEqual(fcp_bst.save_raw(), lcp_bst.save_raw())
self.assertNotEqual(fcp_bst.get_dump(), lcp_bst.get_dump())

def testSameResultWithAndWithoutError(self):
"""Get the same model with and without errors during training."""
Expand Down

0 comments on commit 532c6ff

Please sign in to comment.