Skip to content

Commit

Permalink
Pass the exemplar into the NB restore constructor rather than persist…
Browse files Browse the repository at this point in the history
…ing and restoring
  • Loading branch information
tveasey committed May 31, 2018
1 parent a7fa732 commit 90f22e2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 24 deletions.
3 changes: 2 additions & 1 deletion include/maths/CNaiveBayes.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ class MATHS_EXPORT CNaiveBayes {
explicit CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
double decayRate = 0.0,
TOptionalDouble minMaxLogLikelihoodToUseFeature = TOptionalDouble());
CNaiveBayes(const SDistributionRestoreParams& params,
CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
const SDistributionRestoreParams& params,
core::CStateRestoreTraverser& traverser);
CNaiveBayes(const CNaiveBayes& other);

Expand Down
20 changes: 3 additions & 17 deletions lib/maths/CNaiveBayes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ const std::string CLASS_MODEL_TAG{"c"};
const std::string MIN_MAX_LOG_LIKELIHOOD_TO_USE_FEATURE_TAG{"d"};
const std::string COUNT_TAG{"e"};
const std::string CONDITIONAL_DENSITY_FROM_PRIOR_TAG{"f"};
const std::string EXEMPLAR_FROM_PRIOR_TAG{"g"};
}

CNaiveBayesFeatureDensityFromPrior::CNaiveBayesFeatureDensityFromPrior(const CPrior& prior)
Expand Down Expand Up @@ -125,9 +124,10 @@ CNaiveBayes::CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
m_DecayRate{decayRate}, m_Exemplar{exemplar.clone()}, m_ClassConditionalDensities{2} {
}

CNaiveBayes::CNaiveBayes(const SDistributionRestoreParams& params,
CNaiveBayes::CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
const SDistributionRestoreParams& params,
core::CStateRestoreTraverser& traverser)
: m_DecayRate{params.s_DecayRate}, m_ClassConditionalDensities{2} {
: m_DecayRate{params.s_DecayRate}, m_Exemplar{exemplar.clone()}, m_ClassConditionalDensities{2} {
traverser.traverseSubLevel(boost::bind(&CNaiveBayes::acceptRestoreTraverser,
this, boost::cref(params), _1));
}
Expand All @@ -146,13 +146,6 @@ bool CNaiveBayes::acceptRestoreTraverser(const SDistributionRestoreParams& param
do {
const std::string& name{traverser.name()};
RESTORE_BUILT_IN(CLASS_LABEL_TAG, label)
RESTORE_SETUP_TEARDOWN(
EXEMPLAR_FROM_PRIOR_TAG, CNaiveBayesFeatureDensityFromPrior density,
traverser.traverseSubLevel(
boost::bind(&CNaiveBayesFeatureDensityFromPrior::acceptRestoreTraverser,
boost::ref(density), boost::cref(params), _1)),
m_Exemplar.reset(density.clone()))
// Add other implementations' restore code here.
RESTORE_SETUP_TEARDOWN(
CLASS_MODEL_TAG, CClass class_,
traverser.traverseSubLevel(boost::bind(&CClass::acceptRestoreTraverser,
Expand All @@ -170,13 +163,6 @@ void CNaiveBayes::acceptPersistInserter(core::CStatePersistInserter& inserter) c
using TSizeClassUMapCItr = TSizeClassUMap::const_iterator;
using TSizeClassUMapCItrVec = std::vector<TSizeClassUMapCItr>;

if (dynamic_cast<const CNaiveBayesFeatureDensityFromPrior*>(m_Exemplar.get())) {
inserter.insertLevel(EXEMPLAR_FROM_PRIOR_TAG,
boost::bind(&CNaiveBayesFeatureDensity::acceptPersistInserter,
m_Exemplar.get(), _1));
}
// Add other implementations' persist code here.

TSizeClassUMapCItrVec classes;
classes.reserve(m_ClassConditionalDensities.size());
for (auto i = m_ClassConditionalDensities.begin();
Expand Down
14 changes: 9 additions & 5 deletions lib/maths/CTrendComponent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ TOptionalDoubleDoublePr confidenceInterval(double prediction, double variance, d
return TOptionalDoubleDoublePr{};
}

CNaiveBayesFeatureDensityFromPrior naiveBayesExemplar(double decayRate) {
return CNaiveBayesFeatureDensityFromPrior{CNormalMeanPrecConjugate::nonInformativePrior(
maths_t::E_ContinuousData, TIME_SCALES[NUMBER_MODELS - 1] * decayRate)};
}

CNaiveBayes initialProbabilityOfChangeModel(double decayRate) {
decayRate *= TIME_SCALES[NUMBER_MODELS - 1];
return CNaiveBayes{CNaiveBayesFeatureDensityFromPrior{CNormalMeanPrecConjugate::nonInformativePrior(
maths_t::E_ContinuousData, decayRate)},
decayRate, -20.0};
return CNaiveBayes{naiveBayesExemplar(decayRate),
TIME_SCALES[NUMBER_MODELS - 1] * decayRate, -20.0};
}

CNormalMeanPrecConjugate initialMagnitudeOfChangeModel(double decayRate) {
Expand Down Expand Up @@ -157,7 +160,8 @@ bool CTrendComponent::acceptRestoreTraverser(const SDistributionRestoreParams& p
RESTORE(VALUE_MOMENTS_TAG, m_ValueMoments.fromDelimited(traverser.value()))
RESTORE_BUILT_IN(TIME_OF_LAST_LEVEL_CHANGE_TAG, m_TimeOfLastLevelChange)
RESTORE_NO_ERROR(PROBABILITY_OF_LEVEL_CHANGE_MODEL_TAG,
m_ProbabilityOfLevelChangeModel = CNaiveBayes(params, traverser))
m_ProbabilityOfLevelChangeModel = std::move(CNaiveBayes(
naiveBayesExemplar(m_DefaultDecayRate), params, traverser)))
RESTORE_NO_ERROR(MAGNITUDE_OF_LEVEL_CHANGE_MODEL_TAG,
m_MagnitudeOfLevelChangeModel =
CNormalMeanPrecConjugate(params, traverser))
Expand Down
3 changes: 2 additions & 1 deletion lib/maths/unittest/CNaiveBayesTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ void CNaiveBayesTest::testPersist() {
core::CRapidXmlStateRestoreTraverser traverser(parser);

maths::SDistributionRestoreParams params{maths_t::E_ContinuousData, 0.1, 0.0, 0.0, 0.0};
maths::CNaiveBayes restoredNb{params, traverser};
maths::CNaiveBayes restoredNb{maths::CNaiveBayesFeatureDensityFromPrior(normal),
params, traverser};

CPPUNIT_ASSERT_EQUAL(origNb.checksum(), restoredNb.checksum());

Expand Down

0 comments on commit 90f22e2

Please sign in to comment.