Skip to content

Commit

Permalink
Sherwood lib:
Browse files Browse the repository at this point in the history
parallel forest trainer: bug fixing
parallel forest trainer: enabling with PPL
  • Loading branch information
ereator committed Aug 23, 2016
1 parent fff0936 commit 4fde247
Show file tree
Hide file tree
Showing 10 changed files with 633 additions and 697 deletions.
3 changes: 3 additions & 0 deletions DGM/DGM.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
<ClInclude Include="PriorNode.h" />
<ClInclude Include="Random.h" />
<ClInclude Include="RForest.h" />
<ClInclude Include="sherwood\Forest.h" />
<ClInclude Include="sherwood\ForestTrainer.h" />
<ClInclude Include="sherwood\ParallelForestTrainer.h" />
<ClInclude Include="sherwood\utilities\DataPointCollection.h" />
<ClInclude Include="sherwood\utilities\FeatureResponseFunctions.h" />
<ClInclude Include="sherwood\utilities\StatisticsAggregators.h" />
Expand Down
9 changes: 9 additions & 0 deletions DGM/DGM.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@
<ClInclude Include="..\include\parallel.h">
<Filter>Include</Filter>
</ClInclude>
<ClInclude Include="sherwood\ForestTrainer.h">
<Filter>External\Sherwood Utilities</Filter>
</ClInclude>
<ClInclude Include="sherwood\ParallelForestTrainer.h">
<Filter>External\Sherwood Utilities</Filter>
</ClInclude>
<ClInclude Include="sherwood\Forest.h">
<Filter>External\Sherwood Utilities</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="sherwood\utilities\FeatureResponseFunctions.cpp">
Expand Down
30 changes: 21 additions & 9 deletions DGM/TrainNodeMsRF.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
#include "TrainNodeMsRF.h"
#include "macroses.h"

//#include "sherwood\ParallelForestTrainer.h" // for parallel computing
#ifdef USE_SHERWOOD

#include "sherwood\Sherwood.h"

#ifdef USE_PPL
#include "sherwood\ParallelForestTrainer.h" // for parallel computing
#endif

#include "sherwood\utilities\FeatureResponseFunctions.h"
#include "sherwood\utilities\StatisticsAggregators.h"
Expand All @@ -27,11 +33,12 @@ CTrainNodeMsRF::CTrainNodeMsRF(byte nStates, word nFeatures, int maxSamples) : C
void CTrainNodeMsRF::init(TrainNodeMsRFParams params)
{
// Some default parameters
m_params.MaxDecisionLevels = params.max_decision_levels - 1;
m_params.NumberOfCandidateFeatures = params.num_of_candidate_features;
m_params.NumberOfCandidateThresholdsPerFeature = params.num_of_candidate_thresholds_per_feature;
m_params.NumberOfTrees = params.num_ot_trees;
m_params.Verbose = params.verbose;
m_pParams = std::auto_ptr<sw::TrainingParameters>(new sw::TrainingParameters());
m_pParams->MaxDecisionLevels = params.max_decision_levels - 1;
m_pParams->NumberOfCandidateFeatures = params.num_of_candidate_features;
m_pParams->NumberOfCandidateThresholdsPerFeature = params.num_of_candidate_thresholds_per_feature;
m_pParams->NumberOfTrees = params.num_ot_trees;
m_pParams->Verbose = params.verbose;

m_pData = std::auto_ptr<sw::DataPointCollection>(new sw::DataPointCollection());
m_pData->m_dimension = m_nFeatures;
Expand Down Expand Up @@ -100,8 +107,12 @@ void CTrainNodeMsRF::train(void)

sw::Random random;
sw::ClassificationTrainingContext classificationContext(m_nStates, m_nFeatures);
m_pForest = sw::ForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, m_params, classificationContext, *m_pData);
//m_pForest = sw::ParallelForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, m_params, classificationContext, *m_pData);
#ifdef USE_PPL
// Use this function with caution - it is not verified!
m_pForest = sw::ParallelForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, *m_pParams, classificationContext, *m_pData);
#else
m_pForest = sw::ForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, *m_pParams, classificationContext, *m_pData);
#endif
}

void CTrainNodeMsRF::calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Expand All @@ -128,4 +139,5 @@ void CTrainNodeMsRF::calculateNodePotentials(const Mat &featureVector, Mat &pote
for (byte s = 0; s < m_nStates; s++)
potential.at<float>(s, 0) = (1.0f - mudiness) * h.GetProbability(s);
}
}
}
#endif
12 changes: 8 additions & 4 deletions DGM/TrainNodeMsRF.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
#pragma once

#include "TrainNode.h"
#include "sherwood\Sherwood.h"


namespace sw = MicrosoftResearch::Cambridge::Sherwood;
#ifdef USE_SHERWOOD

namespace MicrosoftResearch { namespace Cambridge { namespace Sherwood {
class LinearFeatureResponse;
class HistogramAggregator;
class DataPointCollection;
template<class F, class S> class Forest;
struct TrainingParameters;
}}}

namespace sw = MicrosoftResearch::Cambridge::Sherwood;

namespace DirectGraphicalModels
{
///@brief Microsoft Research Random Forest parameters
Expand Down Expand Up @@ -43,6 +45,7 @@ namespace DirectGraphicalModels
* @ingroup moduleTrainNode
* @brief Microsoft Sherwood Random Forest training class
* @details This class is based on the <a href="http://research.microsoft.com/en-us/downloads/52d5b9c3-a638-42a1-94a5-d549e2251728/">Sherwood C++ code library for decision forests</a> v.1.0.0
* > In order to use the Sherwood library, DGM must be built with the \b USE_SHERWOOD flag
* @author Sergey G. Kosov, [email protected]
*/
class CTrainNodeMsRF : public CTrainNode
Expand Down Expand Up @@ -100,7 +103,8 @@ namespace DirectGraphicalModels


private:
sw::TrainingParameters m_params;
std::auto_ptr<sw::TrainingParameters> m_pParams;
size_t m_maxSamples;
};
}
#endif
Loading

0 comments on commit 4fde247

Please sign in to comment.