From 860459cf4b318f497eeb59dbf2df805d45fdbf3a Mon Sep 17 00:00:00 2001 From: Luiz Scheinkman Date: Tue, 27 Mar 2018 23:20:20 -0700 Subject: [PATCH] NUP-2506: Add operator '==' to classes used in tests --- bindings/py/tests/algorithms/cells4_test.py | 78 +++++++++++++ bindings/py/tests/nupic_random_test.py | 7 ++ src/nupic/algorithms/Cell.cpp | 6 + src/nupic/algorithms/Cell.hpp | 4 + src/nupic/algorithms/Cells4.cpp | 52 +++++++++ src/nupic/algorithms/Cells4.hpp | 6 + src/nupic/algorithms/InSynapse.hpp | 7 ++ src/nupic/algorithms/Segment.cpp | 14 +++ src/nupic/algorithms/Segment.hpp | 29 +++++ src/nupic/algorithms/SegmentUpdate.cpp | 11 ++ src/nupic/algorithms/SegmentUpdate.hpp | 5 + src/nupic/engine/Link.cpp | 13 ++- src/nupic/engine/Link.hpp | 3 + src/nupic/engine/Network.cpp | 19 ++++ src/nupic/engine/Network.hpp | 5 + src/nupic/engine/Region.cpp | 120 ++++++++++++++++++++ src/nupic/engine/Region.hpp | 5 + src/nupic/utils/Random.cpp | 15 +++ src/nupic/utils/Random.hpp | 5 + src/test/unit/algorithms/Cells4Test.cpp | 58 ++++++++++ src/test/unit/algorithms/SegmentTest.cpp | 20 ++++ src/test/unit/engine/NetworkTest.cpp | 37 ++++++ src/test/unit/utils/RandomTest.cpp | 17 +++ 23 files changed, 535 insertions(+), 1 deletion(-) diff --git a/bindings/py/tests/algorithms/cells4_test.py b/bindings/py/tests/algorithms/cells4_test.py index d62d76194f..f843990606 100755 --- a/bindings/py/tests/algorithms/cells4_test.py +++ b/bindings/py/tests/algorithms/cells4_test.py @@ -32,7 +32,53 @@ _RGEN = Random(43) +def createCells4(nCols=8, + nCellsPerCol=4, + activationThreshold=1, + minThreshold=1, + newSynapseCount=2, + segUpdateValidDuration=2, + permInitial=0.5, + permConnected=0.8, + permMax=1.0, + permDec=0.1, + permInc=0.2, + globalDecay=0.05, + doPooling=True, + pamLength=2, + maxAge=3, + seed=42, + initFromCpp=True, + checkSynapseConsistency=False): + cells = Cells4(nCols, + nCellsPerCol, + activationThreshold, + minThreshold, + newSynapseCount, + segUpdateValidDuration, + permInitial, + permConnected, + permMax, + permDec, + permInc, + globalDecay, + doPooling, + seed, + initFromCpp, + checkSynapseConsistency) + + + cells.setPamLength(pamLength) + cells.setMaxAge(maxAge) + cells.setMaxInfBacktrack(4) + + for i in xrange(nCols): + for j in xrange(nCellsPerCol): + cells.addNewSegment(i, j, True if j % 2 == 0 else False, + [((i + 1) % nCols, (j + 1) % nCellsPerCol)]) + + return cells class Cells4Test(unittest.TestCase): @@ -205,3 +251,35 @@ def testLearn(self): cells.compute(x, True, False) self._testPersistence(cells) + + def testEquals(self): + nCols = 10 + c1 = createCells4(nCols) + c2 = createCells4(nCols) + self.assertEquals(c1, c2) + + # learn + data = [numpy.random.choice(nCols, nCols/3, False) for _ in xrange(10)] + for idx in data: + x = numpy.zeros(nCols, dtype="float32") + x[idx] = 1.0 + c1.compute(x, True, True) + c2.compute(x, True, True) + self.assertEquals(c1, c2) + + self.assertEquals(c1, c2) + + c1.rebuildOutSynapses() + c2.rebuildOutSynapses() + self.assertEquals(c1, c2) + + # inference + data = [numpy.random.choice(nCols, nCols/3, False) for _ in xrange(100)] + for idx in data: + x = numpy.zeros(nCols, dtype="float32") + x[idx] = 1.0 + c1.compute(x, True, False) + c2.compute(x, True, False) + self.assertEquals(c1, c2) + + self.assertEquals(c1, c2) diff --git a/bindings/py/tests/nupic_random_test.py b/bindings/py/tests/nupic_random_test.py index 06ab8305ca..2152d66a43 100755 --- a/bindings/py/tests/nupic_random_test.py +++ b/bindings/py/tests/nupic_random_test.py @@ -246,6 +246,13 @@ def testShuffleBadDtype(self): self.assertRaises(ValueError, r.shuffle, arr) + def testEquals(self): + r1 = Random(42) + v1 = r1.getReal64() + r2 = Random(42) + v2 = r2.getReal64() + self.assertEquals(v1, v2) + self.assertEquals(r1, r2) if __name__ == "__main__": unittest.main() diff --git a/src/nupic/algorithms/Cell.cpp b/src/nupic/algorithms/Cell.cpp index 24e890d68e..d9220a08aa 100644 --- a/src/nupic/algorithms/Cell.cpp +++ b/src/nupic/algorithms/Cell.cpp @@ -156,3 +156,9 @@ void Cell::load(std::istream &inStream) { _freeSegments.push_back(i); } } +bool Cell::operator==(const Cell &other) const { + if (_freeSegments != other._freeSegments) { + return false; + } + return _segments == other._segments; +} diff --git a/src/nupic/algorithms/Cell.hpp b/src/nupic/algorithms/Cell.hpp index d17175c25d..b1b5abe32c 100644 --- a/src/nupic/algorithms/Cell.hpp +++ b/src/nupic/algorithms/Cell.hpp @@ -108,6 +108,10 @@ class Cell : Serializable { return _segments[segIdx]; } + //---------------------------------------------------------------------- + bool operator==(const Cell &other) const; + inline bool operator!=(const Cell &other) const { return !operator==(other); } + //-------------------------------------------------------------------------------- Segment &getSegment(UInt segIdx) { NTA_ASSERT(segIdx < _segments.size()); diff --git a/src/nupic/algorithms/Cells4.cpp b/src/nupic/algorithms/Cells4.cpp index 128ce3c9ab..2327d7d794 100644 --- a/src/nupic/algorithms/Cells4.cpp +++ b/src/nupic/algorithms/Cells4.cpp @@ -2801,6 +2801,58 @@ std::ostream &operator<<(std::ostream &outStream, const Cells4 &cells) { return outStream; } +bool Cells4::operator==(const Cells4 &other) const { + + if (_activationThreshold != other._activationThreshold || + _avgInputDensity != other._avgInputDensity || + _avgLearnedSeqLength != other._avgLearnedSeqLength || + _checkSynapseConsistency != other._checkSynapseConsistency || + _doPooling != other._doPooling || _globalDecay != other._globalDecay || + _initSegFreq != other._initSegFreq || + _learnedSeqLength != other._learnedSeqLength || + _maxAge != other._maxAge || _maxInfBacktrack != other._maxInfBacktrack || + _maxLrnBacktrack != other._maxLrnBacktrack || + _maxSegmentsPerCell != other._maxSegmentsPerCell || + _maxSeqLength != other._maxSeqLength || + _maxSynapsesPerSegment != other._maxSynapsesPerSegment || + _minThreshold != other._minThreshold || _nCells != other._nCells || + _nCellsPerCol != other._nCellsPerCol || _nColumns != other._nColumns || + _newSynapseCount != other._newSynapseCount || + _nIterations != other._nIterations || + _nLrnIterations != other._nLrnIterations || + _ownsMemory != other._ownsMemory || _pamCounter != other._pamCounter || + _pamLength != other._pamLength || + _permConnected != other._permConnected || _permDec != other._permDec || + _permInc != other._permInc || _permInitial != other._permInitial || + _permMax != other._permMax || _resetCalled != other._resetCalled || + _segUpdateValidDuration != other._segUpdateValidDuration || + _verbosity != other._verbosity || _version != other._version) { + return false; + } + if (_rng != other._rng) { + return false; + } + if (_cells != other._cells) { + return false; + } + if (_segmentUpdates != other._segmentUpdates) { + return false; + } + if (_learnActiveStateT != other._learnActiveStateT) { + return false; + } + if (_learnActiveStateT1 != other._learnActiveStateT1) { + return false; + } + if (_learnPredictedStateT != other._learnPredictedStateT) { + return false; + } + if (_learnPredictedStateT1 != other._learnPredictedStateT1) { + return false; + } + return true; +} + //---------------------------------------------------------------------- /** * Compute cell and segment activities using forward propagation diff --git a/src/nupic/algorithms/Cells4.hpp b/src/nupic/algorithms/Cells4.hpp index fa9f78f168..0d3e051a96 100644 --- a/src/nupic/algorithms/Cells4.hpp +++ b/src/nupic/algorithms/Cells4.hpp @@ -397,6 +397,12 @@ class Cells4 : public Serializable { //---------------------------------------------------------------------- ~Cells4(); + //---------------------------------------------------------------------- + bool operator==(const Cells4 &other) const; + inline bool operator!=(const Cells4 &other) const { + return !operator==(other); + } + //---------------------------------------------------------------------- UInt version() const { return _version; } diff --git a/src/nupic/algorithms/InSynapse.hpp b/src/nupic/algorithms/InSynapse.hpp index dfd18f5e5f..10808b6a3a 100644 --- a/src/nupic/algorithms/InSynapse.hpp +++ b/src/nupic/algorithms/InSynapse.hpp @@ -63,6 +63,13 @@ class InSynapse { return *this; } + inline bool operator==(const InSynapse &other) const { + return _srcCellIdx == other._srcCellIdx && _permanence == other._permanence; + } + inline bool operator!=(const InSynapse &other) const { + return !operator==(other); + } + inline UInt srcCellIdx() const { return _srcCellIdx; } const inline Real &permanence() const { return _permanence; } inline Real &permanence() { return _permanence; } diff --git a/src/nupic/algorithms/Segment.cpp b/src/nupic/algorithms/Segment.cpp index d907acde6f..2cf6b86e81 100644 --- a/src/nupic/algorithms/Segment.cpp +++ b/src/nupic/algorithms/Segment.cpp @@ -81,6 +81,20 @@ Segment &Segment::operator=(const Segment &o) { return *this; } +//-------------------------------------------------------------------------------- +bool Segment::operator==(const Segment &other) const { + if (_totalActivations != other._totalActivations || + _positiveActivations != other._positiveActivations || + _lastActiveIteration != other._lastActiveIteration || + _lastPosDutyCycle != other._lastPosDutyCycle || + _lastPosDutyCycleIteration != other._lastPosDutyCycleIteration || + _seqSegFlag != other._seqSegFlag || _frequency != other._frequency || + _nConnected != other._nConnected) { + return false; + } + return _synapses == other._synapses; +} + //-------------------------------------------------------------------------------- Segment::Segment(const Segment &o) : _totalActivations(o._totalActivations), diff --git a/src/nupic/algorithms/Segment.hpp b/src/nupic/algorithms/Segment.hpp index 8f76410b55..df3386c8fc 100644 --- a/src/nupic/algorithms/Segment.hpp +++ b/src/nupic/algorithms/Segment.hpp @@ -109,6 +109,19 @@ class CState : Serializable { memcpy(_pData, o._pData, _nCells); return *this; } + bool operator==(const CState &other) const { + if (_version != other._version || _nCells != other._nCells || + _fMemoryAllocatedByPython != other._fMemoryAllocatedByPython) { + return false; + } + if (_pData != nullptr && other._pData != nullptr) { + return ::memcmp(_pData, other._pData, _nCells) == 0; + } + return _pData == other._pData; + } + inline bool operator!=(const CState &other) const { + return !operator==(other); + } bool initialize(const UInt nCells) { if (_nCells != 0) // if already initialized return false; // don't do it again @@ -217,6 +230,19 @@ class CStateIndexed : public CState { _isSorted = o._isSorted; return *this; } + bool operator==(const CStateIndexed &other) const { + if (_version != other._version || _countOn != other._countOn || + _isSorted != other._isSorted) { + return false; + } + if (_cellsOn != other._cellsOn) { + return false; + } + return CState::operator==(other); + } + inline bool operator!=(const CStateIndexed &other) const { + return !operator==(other); + } std::vector cellsOn(bool fSorted = false) { // It's better for the caller to ask us to sort, rather than // to sort himself, since we can optimize out the sort when we @@ -342,6 +368,9 @@ class Segment : Serializable { Real _lastPosDutyCycle; UInt _lastPosDutyCycleIteration; + bool operator==(const Segment &o) const; + inline bool operator!=(const Segment &o) const { return !operator==(o); } + private: bool _seqSegFlag; // sequence segment flag Real _frequency; // frequency [UNUSED IN LATEST IMPLEMENTATION] diff --git a/src/nupic/algorithms/SegmentUpdate.cpp b/src/nupic/algorithms/SegmentUpdate.cpp index 7fd8e69bf3..025f3bbb88 100644 --- a/src/nupic/algorithms/SegmentUpdate.cpp +++ b/src/nupic/algorithms/SegmentUpdate.cpp @@ -71,3 +71,14 @@ bool SegmentUpdate::invariants(Cells4 *cells) const { return ok; } + +bool SegmentUpdate::operator==(const SegmentUpdate &o) const { + + if (_cellIdx != o._cellIdx || _segIdx != o._segIdx || + _sequenceSegment != o._sequenceSegment || _timeStamp != o._timeStamp || + _phase1Flag != o._phase1Flag || + _weaklyPredicting != o._weaklyPredicting) { + return false; + } + return _synapses == o._synapses; +} diff --git a/src/nupic/algorithms/SegmentUpdate.hpp b/src/nupic/algorithms/SegmentUpdate.hpp index f08f2160e5..d02de07d15 100644 --- a/src/nupic/algorithms/SegmentUpdate.hpp +++ b/src/nupic/algorithms/SegmentUpdate.hpp @@ -85,6 +85,11 @@ class SegmentUpdate : Serializable { NTA_ASSERT(invariants()); return *this; } + //--------------------------------------------------------------------- + bool operator==(const SegmentUpdate &other) const; + inline bool operator!=(const SegmentUpdate &other) const { + return !operator==(other); + } //--------------------------------------------------------------------- bool isSequenceSegment() const { return _sequenceSegment; } diff --git a/src/nupic/engine/Link.cpp b/src/nupic/engine/Link.cpp index 3fb0fc600e..3d94a0407f 100644 --- a/src/nupic/engine/Link.cpp +++ b/src/nupic/engine/Link.cpp @@ -490,7 +490,18 @@ void Link::read(LinkProto::Reader &proto) { } } } - +bool Link::operator==(const Link &o) const { + if (initialized_ != o.initialized_ || + propagationDelay_ != o.propagationDelay_ || linkType_ != o.linkType_ || + linkParams_ != o.linkParams_ || destOffset_ != o.destOffset_ || + srcRegionName_ != o.srcRegionName_ || + destRegionName_ != o.destRegionName_ || + srcOutputName_ != o.srcOutputName_ || + destInputName_ != o.destInputName_) { + return false; + } + return true; +} namespace nupic { std::ostream &operator<<(std::ostream &f, const Link &link) { f << "\n"; diff --git a/src/nupic/engine/Link.hpp b/src/nupic/engine/Link.hpp index d59144c9da..e0f471c565 100644 --- a/src/nupic/engine/Link.hpp +++ b/src/nupic/engine/Link.hpp @@ -401,6 +401,9 @@ class Link : public Serializable { using Serializable::read; void read(LinkProto::Reader &proto); + bool operator==(const Link &other) const; + inline bool operator!=(const Link &other) const { return !operator==(other); } + private: // common initialization for the two constructors. void commonConstructorInit_(const std::string &linkType, diff --git a/src/nupic/engine/Network.cpp b/src/nupic/engine/Network.cpp index e3df60eab4..34a4206e74 100644 --- a/src/nupic/engine/Network.cpp +++ b/src/nupic/engine/Network.cpp @@ -1061,4 +1061,23 @@ void Network::unregisterCPPRegion(const std::string name) { Region::unregisterCPPRegion(name); } +bool Network::operator==(const Network &o) const { + + if (initialized_ != o.initialized_ || iteration_ != o.iteration_ || + minEnabledPhase_ != o.minEnabledPhase_ || + maxEnabledPhase_ != o.maxEnabledPhase_ || + regions_.getCount() != o.regions_.getCount()) { + return false; + } + + for (size_t i = 0; i < regions_.getCount(); i++) { + Region *r1 = regions_.getByIndex(i).second; + Region *r2 = o.regions_.getByIndex(i).second; + if (*r1 != *r2) { + return false; + } + } + return true; +} + } // namespace nupic diff --git a/src/nupic/engine/Network.hpp b/src/nupic/engine/Network.hpp index d99b2a7c0a..d1bdc85c6b 100644 --- a/src/nupic/engine/Network.hpp +++ b/src/nupic/engine/Network.hpp @@ -409,6 +409,11 @@ class Network : public Serializable { */ static void unregisterCPPRegion(const std::string name); + bool operator==(const Network &other) const; + inline bool operator!=(const Network &other) const { + return !operator==(other); + } + private: // Both constructors use this common initialization method void commonInit(); diff --git a/src/nupic/engine/Region.cpp b/src/nupic/engine/Region.cpp index 7ce50c1a88..3694614e57 100644 --- a/src/nupic/engine/Region.cpp +++ b/src/nupic/engine/Region.cpp @@ -434,4 +434,124 @@ const Timer &Region::getComputeTimer() const { return computeTimer_; } const Timer &Region::getExecuteTimer() const { return executeTimer_; } +bool Region::operator==(const Region &o) const { + + if (name_ != o.name_ || type_ != o.type_ || dims_ != o.dims_ || + phases_ != o.phases_ || dimensionInfo_ != o.dimensionInfo_ || + initialized_ != o.initialized_ || outputs_.size() != o.outputs_.size() || + inputs_.size() != o.inputs_.size()) { + return false; + } + if (spec_ != nullptr && o.spec_ != nullptr) { + // Compare specs + if (spec_->singleNodeOnly != o.spec_->singleNodeOnly || + spec_->description != o.spec_->description) { + return false; + } + + // Parameters + for (size_t i = 0; i < spec_->parameters.getCount(); ++i) { + const std::pair &p1 = + spec_->parameters.getByIndex(i); + const std::pair &p2 = + o.spec_->parameters.getByIndex(i); + if (p1.first != p2.first || p1.second.count != p2.second.count || + p1.second.description != p2.second.description || + p1.second.constraints != p2.second.constraints || + p1.second.defaultValue != p2.second.defaultValue || + p1.second.dataType != p2.second.dataType || + p1.second.accessMode != p2.second.accessMode) { + return false; + } + } + // Outputs + for (size_t i = 0; i < spec_->outputs.getCount(); ++i) { + const std::pair &p1 = + spec_->outputs.getByIndex(i); + const std::pair &p2 = + o.spec_->outputs.getByIndex(i); + if (p1.first != p2.first || p1.second.count != p2.second.count || + p1.second.regionLevel != p2.second.regionLevel || + p1.second.isDefaultOutput != p2.second.isDefaultOutput || + p1.second.sparse != p2.second.sparse || + p1.second.description != p2.second.description || + p1.second.dataType != p2.second.dataType) { + return false; + } + } + + // Outputs + for (size_t i = 0; i < spec_->inputs.getCount(); ++i) { + const std::pair &p1 = spec_->inputs.getByIndex(i); + const std::pair &p2 = + o.spec_->inputs.getByIndex(i); + if (p1.first != p2.first || p1.second.count != p2.second.count || + p1.second.regionLevel != p2.second.regionLevel || + p1.second.isDefaultInput != p2.second.isDefaultInput || + p1.second.sparse != p2.second.sparse || + p1.second.requireSplitterMap != p2.second.requireSplitterMap || + p1.second.required != p2.second.required || + p1.second.description != p2.second.description || + p1.second.dataType != p2.second.dataType) { + return false; + } + } + // Commands + for (size_t i = 0; i < spec_->commands.getCount(); ++i) { + const std::pair &p1 = + spec_->commands.getByIndex(i); + const std::pair &p2 = + o.spec_->commands.getByIndex(i); + if (p1.first != p2.first || + p1.second.description != p2.second.description) { + return false; + } + } + } else if (spec_ != o.spec_) { + // One of them is not null + return false; + } + + // Compare Regions's Input + static auto compareInput = [](decltype(*inputs_.begin()) a, decltype(a) b) { + if (a.first != b.first || + a.second->isRegionLevel() != b.second->isRegionLevel() || + a.second->isSparse() != b.second->isSparse()) { + return false; + } + auto links1 = a.second->getLinks(); + auto links2 = b.second->getLinks(); + if (links1.size() != links2.size()) { + return false; + } + for (size_t i = 0; i < links1.size(); i++) { + if (*(links1[i]) != *(links2[i])) { + return false; + } + } + return true; + }; + if (!std::equal(inputs_.begin(), inputs_.end(), o.inputs_.begin(), + compareInput)) { + return false; + } + // Compare Regions's Output + static auto compareOutput = [](decltype(*outputs_.begin()) a, decltype(a) b) { + if (a.first != b.first || + a.second->isRegionLevel() != b.second->isRegionLevel() || + a.second->isSparse() != b.second->isSparse() || + a.second->getNodeOutputElementCount() != + b.second->getNodeOutputElementCount()) { + return false; + } + return true; + }; + if (!std::equal(outputs_.begin(), outputs_.end(), o.outputs_.begin(), + compareOutput)) { + return false; + } + + return true; +} + } // namespace nupic diff --git a/src/nupic/engine/Region.hpp b/src/nupic/engine/Region.hpp index 59161b47cb..496d31755a 100644 --- a/src/nupic/engine/Region.hpp +++ b/src/nupic/engine/Region.hpp @@ -609,6 +609,11 @@ class Region : public Serializable { */ const Timer &getExecuteTimer() const; + bool operator==(const Region &other) const; + inline bool operator!=(const Region &other) const { + return !operator==(other); + } + /** * @} */ diff --git a/src/nupic/utils/Random.cpp b/src/nupic/utils/Random.cpp index d8ed11ed76..d8fe2e80d1 100644 --- a/src/nupic/utils/Random.cpp +++ b/src/nupic/utils/Random.cpp @@ -68,6 +68,10 @@ class RandomImpl { void write(RandomImplProto::Builder &proto) const; void read(RandomImplProto::Reader &proto); UInt32 getUInt32(); + bool operator==(const RandomImpl &o) const; + inline bool operator!=(const RandomImpl &other) const { + return !operator==(other); + } // Note: copy constructor and operator= are needed // The default is ok. private: @@ -125,6 +129,10 @@ Random &Random::operator=(const Random &other) { return *this; } +bool Random::operator==(const Random &o) const { + return seed_ == o.seed_ && (*impl_) == (*o.impl_); +} + Random::~Random() { delete impl_; } Random::Random(UInt64 seed) { @@ -378,6 +386,13 @@ std::istream &operator>>(std::istream &inStream, RandomImpl &r) { return inStream; } +bool RandomImpl::operator==(const RandomImpl &o) const { + if (rptr_ != o.rptr_ || fptr_ != o.fptr_) { + return false; + } + return ::memcmp(state_, o.state_, sizeof(state_)) == 0; +} + // helper function for seeding RNGs across the plugin barrier // Unless there is a logic error, should not be called if // the Random singleton has not been initialized. diff --git a/src/nupic/utils/Random.hpp b/src/nupic/utils/Random.hpp index 7b932567cc..d7333f8f6f 100644 --- a/src/nupic/utils/Random.hpp +++ b/src/nupic/utils/Random.hpp @@ -170,6 +170,11 @@ class Random : public Serializable { static const UInt32 MAX32; static const UInt64 MAX64; + bool operator==(const Random &other) const; + inline bool operator!=(const Random &other) const { + return !operator==(other); + } + // called by the plugin framework so that plugins // get the "global" seeder static void initSeeder(const RandomSeedFuncPtr r); diff --git a/src/test/unit/algorithms/Cells4Test.cpp b/src/test/unit/algorithms/Cells4Test.cpp index efb5ce0916..011fbf4535 100644 --- a/src/test/unit/algorithms/Cells4Test.cpp +++ b/src/test/unit/algorithms/Cells4Test.cpp @@ -483,3 +483,61 @@ TEST(Cells4Test, ASSERT_EQ(expectedActiveSrcCellIdxs, activeSrcCellIdxs); ASSERT_EQ(expectedActiveSynapseIdxs, activeSynapseIdxs); } + +/** + * Test operator '==' + */ +TEST(Cells4Test, testEqualsOperator) { + Cells4 cells1(10, 2, 1, 1, 1, 1, 0.5, 0.8, 1, 0.1, 0.1, 0, false, 42, true, + false); + Cells4 cells2(10, 2, 1, 1, 1, 1, 0.5, 0.8, 1, 0.1, 0.1, 0, false, 42, true, + false); + ASSERT_TRUE(cells1 == cells2); + std::vector input1(10, 0.0); + input1[1] = 1.0; + input1[4] = 1.0; + input1[5] = 1.0; + input1[9] = 1.0; + std::vector input2(10, 0.0); + input2[0] = 1.0; + input2[2] = 1.0; + input2[5] = 1.0; + input2[6] = 1.0; + std::vector input3(10, 0.0); + input3[1] = 1.0; + input3[3] = 1.0; + input3[6] = 1.0; + input3[7] = 1.0; + std::vector input4(10, 0.0); + input4[2] = 1.0; + input4[4] = 1.0; + input4[7] = 1.0; + input4[8] = 1.0; + std::vector output(10 * 2); + for (UInt i = 0; i < 10; ++i) { + cells1.compute(&input1.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 != cells2); + cells2.compute(&input1.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 == cells2); + + cells1.compute(&input2.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 != cells2); + cells2.compute(&input2.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 == cells2); + + cells1.compute(&input3.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 != cells2); + cells2.compute(&input3.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 == cells2); + + cells1.compute(&input4.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 != cells2); + cells2.compute(&input4.front(), &output.front(), true, true); + ASSERT_TRUE(cells1 == cells2); + + cells1.reset(); + ASSERT_TRUE(cells1 != cells2); + cells2.reset(); + ASSERT_TRUE(cells1 == cells2); + } +} \ No newline at end of file diff --git a/src/test/unit/algorithms/SegmentTest.cpp b/src/test/unit/algorithms/SegmentTest.cpp index 9ee67df7fa..759407cf15 100644 --- a/src/test/unit/algorithms/SegmentTest.cpp +++ b/src/test/unit/algorithms/SegmentTest.cpp @@ -138,3 +138,23 @@ TEST(SegmentTest, freeNSynapsesStableSort) { sort(removed.begin(), removed.end()); ASSERT_EQ(removed, removed_expected); } + +/** + * Test operator '==' + */ +TEST(SegmentTest, testEqualsOperator) { + Segment segment1; + Segment segment2; + + vector inactiveSegmentIndices; + vector activeSegmentIndices; + vector activeSynapseIndices; + vector inactiveSynapseIndices; + + setUpSegment(segment1, inactiveSegmentIndices, activeSegmentIndices, + activeSynapseIndices, inactiveSynapseIndices); + ASSERT_TRUE(segment1 != segment2); + setUpSegment(segment2, inactiveSegmentIndices, activeSegmentIndices, + activeSynapseIndices, inactiveSynapseIndices); + ASSERT_TRUE(segment1 == segment2); +} \ No newline at end of file diff --git a/src/test/unit/engine/NetworkTest.cpp b/src/test/unit/engine/NetworkTest.cpp index e7156273ed..7dd75764f8 100644 --- a/src/test/unit/engine/NetworkTest.cpp +++ b/src/test/unit/engine/NetworkTest.cpp @@ -541,3 +541,40 @@ TEST(NetworkTest, Callback) { EXPECT_STREQ("level2", mydata[4].c_str()); EXPECT_STREQ("level3", mydata[5].c_str()); } + +/** + * Test operator '==' + */ +TEST(NetworkTest, testEqualsOperator) { + Network n1; + Network n2; + ASSERT_TRUE(n1 == n2); + Dimensions d; + d.push_back(4); + d.push_back(4); + + auto l1 = n1.addRegion("level1", "TestNode", ""); + ASSERT_TRUE(n1 != n2); + auto l2 = n2.addRegion("level1", "TestNode", ""); + ASSERT_TRUE(n1 == n2); + + l1->setDimensions(d); + ASSERT_TRUE(n1 != n2); + l2->setDimensions(d); + ASSERT_TRUE(n1 == n2); + + n1.addRegion("level2", "TestNode", ""); + ASSERT_TRUE(n1 != n2); + n2.addRegion("level2", "TestNode", ""); + ASSERT_TRUE(n1 == n2); + + n1.link("level1", "level2", "TestFanIn2", ""); + ASSERT_TRUE(n1 != n2); + n2.link("level1", "level2", "TestFanIn2", ""); + ASSERT_TRUE(n1 == n2); + + n1.run(1); + ASSERT_TRUE(n1 != n2); + n2.run(1); + ASSERT_TRUE(n1 == n2); +} diff --git a/src/test/unit/utils/RandomTest.cpp b/src/test/unit/utils/RandomTest.cpp index ab589064b0..5179b7066b 100644 --- a/src/test/unit/utils/RandomTest.cpp +++ b/src/test/unit/utils/RandomTest.cpp @@ -453,3 +453,20 @@ TEST(RandomTest, CapnpSerialization) { // clean up remove(outputPath); } + +/** + * Test operator '==' + */ +TEST(RandomTest, testEqualsOperator) { + Random r1(42), r2(42), r3(3); + ASSERT_TRUE(r1 == r2); + ASSERT_TRUE(r1 != r3); + ASSERT_TRUE(r2 != r3); + + UInt32 v1, v2; + v1 = r1.getUInt32(); + ASSERT_TRUE(r1 != r2); + v2 = r2.getUInt32(); + ASSERT_TRUE(r1 == r2); + ASSERT_EQ(v1, v2); +} \ No newline at end of file