diff --git a/src/treelearner/data_parallel_tree_learner.cpp b/src/treelearner/data_parallel_tree_learner.cpp index 549c7a96d59a..c9ff36da9f44 100644 --- a/src/treelearner/data_parallel_tree_learner.cpp +++ b/src/treelearner/data_parallel_tree_learner.cpp @@ -155,6 +155,22 @@ template void DataParallelTreeLearner::FindBestSplits(const Tree* tree) { TREELEARNER_T::ConstructHistograms( this->col_sampler_.is_feature_used_bytree(), true); + const int smaller_leaf_index = this->smaller_leaf_splits_->leaf_index(); + const data_size_t local_data_on_smaller_leaf = this->data_partition_->leaf_count(smaller_leaf_index); + if (local_data_on_smaller_leaf <= 0) { + // clear histogram buffer before synchronizing + // otherwise histogram contents from the previous iteration will be sent + #pragma omp parallel for schedule(static) + for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { + if (this->col_sampler_.is_feature_used_bytree()[feature_index] == false) + continue; + const BinMapper* feature_bin_mapper = this->train_data_->FeatureBinMapper(feature_index); + const int offset = static_cast(feature_bin_mapper->GetMostFreqBin() == 0); + const int num_bin = feature_bin_mapper->num_bin(); + hist_t* hist_ptr = this->smaller_leaf_histogram_array_[feature_index].RawData(); + std::memset(reinterpret_cast(hist_ptr), 0, (num_bin - offset) * kHistEntrySize); + } + } // construct local histograms #pragma omp parallel for schedule(static) for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { diff --git a/src/treelearner/voting_parallel_tree_learner.cpp b/src/treelearner/voting_parallel_tree_learner.cpp index ab4adeae62d9..aacd5caa4412 100644 --- a/src/treelearner/voting_parallel_tree_learner.cpp +++ b/src/treelearner/voting_parallel_tree_learner.cpp @@ -259,6 +259,48 @@ void VotingParallelTreeLearner::FindBestSplits(const Tree* tree) } TREELEARNER_T::ConstructHistograms(is_feature_used, use_subtract); + const int smaller_leaf_index = this->smaller_leaf_splits_->leaf_index(); + const data_size_t local_data_on_smaller_leaf = this->data_partition_->leaf_count(smaller_leaf_index); + if (local_data_on_smaller_leaf <= 0) { + // clear histogram buffer before synchronizing + // otherwise histogram contents from the previous iteration will be sent + OMP_INIT_EX(); + #pragma omp parallel for schedule(static) + for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { + OMP_LOOP_EX_BEGIN(); + if (!is_feature_used[feature_index]) { continue; } + const BinMapper* feature_bin_mapper = this->train_data_->FeatureBinMapper(feature_index); + const int num_bin = feature_bin_mapper->num_bin(); + const int offset = static_cast(feature_bin_mapper->GetMostFreqBin() == 0); + hist_t* hist_ptr = this->smaller_leaf_histogram_array_[feature_index].RawData(); + std::memset(reinterpret_cast(hist_ptr), 0, (num_bin - offset) * kHistEntrySize); + OMP_LOOP_EX_END(); + } + OMP_THROW_EX(); + } + + if (this->larger_leaf_splits_ != nullptr) { + const int larger_leaf_index = this->larger_leaf_splits_->leaf_index(); + if (larger_leaf_index >= 0) { + const data_size_t local_data_on_larger_leaf = this->data_partition_->leaf_count(larger_leaf_index); + if (local_data_on_larger_leaf <= 0) { + OMP_INIT_EX(); + #pragma omp parallel for schedule(static) + for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) { + OMP_LOOP_EX_BEGIN(); + if (!is_feature_used[feature_index]) { continue; } + const BinMapper* feature_bin_mapper = this->train_data_->FeatureBinMapper(feature_index); + const int num_bin = feature_bin_mapper->num_bin(); + const int offset = static_cast(feature_bin_mapper->GetMostFreqBin() == 0); + hist_t* hist_ptr = this->larger_leaf_histogram_array_[feature_index].RawData(); + std::memset(reinterpret_cast(hist_ptr), 0, (num_bin - offset) * kHistEntrySize); + OMP_LOOP_EX_END(); + } + OMP_THROW_EX(); + } + } + } + std::vector smaller_bestsplit_per_features(this->num_features_); std::vector larger_bestsplit_per_features(this->num_features_); double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_.get());