Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
NUP-2506: Add missing state fields to serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lscheinkman committed Apr 9, 2018
1 parent 6d483ff commit 203493e
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 15 deletions.
227 changes: 215 additions & 12 deletions src/nupic/algorithms/Cells4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ Cells4::~Cells4() {
delete[] _tmpInputBuffer;
}

//--------------------------------------------------------------------------------
/**
* Simple helper function for allocating our numerous state variables
*/
template <typename It> void allocateState(It *&state, const UInt numElmts) {
state = new It[numElmts];
memset(state, 0, numElmts * sizeof(It));
}

//--------------------------------------------------------------------------------
// Utility routines used in this file to print list of active columns and cell
// indices
Expand Down Expand Up @@ -1920,6 +1929,53 @@ void Cells4::write(Cells4Proto::Builder &proto) const {
auto learnPredictedStateT1Proto = proto.initLearnPredictedStateT1();
_learnPredictedStateT1.write(learnPredictedStateT1Proto);

if (_ownsMemory) {
auto infActiveStateT = proto.initInfActiveStateT();
_infActiveStateT.write(infActiveStateT);
auto infActiveStateT1 = proto.initInfActiveStateT1();
_infActiveStateT1.write(infActiveStateT1);

auto infPredictedStateT = proto.initInfPredictedStateT();
_infPredictedStateT.write(infPredictedStateT);
auto infPredictedStateT1 = proto.initInfPredictedStateT1();
_infPredictedStateT1.write(infPredictedStateT1);

auto cellConfidenceT = proto.initCellConfidenceT(_nCells);
for (UInt i = 0; i < _nCells; ++i) {
cellConfidenceT.set(i, _cellConfidenceT[i]);
}
auto cellConfidenceT1 = proto.initCellConfidenceT1(_nCells);
for (UInt i = 0; i < _nCells; ++i) {
cellConfidenceT1.set(i, _cellConfidenceT1[i]);
}
auto colConfidenceT = proto.initColConfidenceT(_nColumns);
for (UInt i = 0; i < _nColumns; ++i) {
colConfidenceT.set(i, _colConfidenceT[i]);
}
auto colConfidenceT1 = proto.initColConfidenceT1(_nColumns);
for (UInt i = 0; i < _nColumns; ++i) {
colConfidenceT1.set(i, _colConfidenceT1[i]);
}
}

auto prevInfPatterns = proto.initPrevInfPatterns(_prevInfPatterns.size());
for (UInt i = 0; i < _prevInfPatterns.size(); ++i) {
const auto &pattern = _prevInfPatterns[i];
auto row = prevInfPatterns.init(i, pattern.size());
for (UInt j = 0; j < pattern.size(); j++) {
row.set(j, pattern[j]);
}
}
auto prevLrnPatterns = proto.initPrevLrnPatterns(_prevLrnPatterns.size());
for (UInt i = 0; i < _prevLrnPatterns.size(); ++i) {
const auto &pattern = _prevLrnPatterns[i];
auto row = prevLrnPatterns.init(i, pattern.size());
for (UInt j = 0; j < pattern.size(); j++) {
row.set(j, pattern[j]);
}
}

NTA_CHECK(_nCells == _cells.size());
auto cellListProto = proto.initCells(_nCells);
for (UInt i = 0; i < _nCells; ++i) {
auto cellProto = cellListProto[i];
Expand Down Expand Up @@ -1984,6 +2040,59 @@ void Cells4::read(Cells4Proto::Reader &proto) {
_cells[i].read(cellProto);
}

if (proto.getOwnsMemory()) {
_infActiveStateT.initialize(_nCells);
_infActiveStateT1.initialize(_nCells);
_infPredictedStateT.initialize(_nCells);
_infPredictedStateT1.initialize(_nCells);
auto infActiveStateT = proto.getInfActiveStateT();
_infActiveStateT.read(infActiveStateT);
auto infActiveStateT1 = proto.getInfActiveStateT1();
_infActiveStateT1.read(infActiveStateT1);
auto infPredictedStateT = proto.getInfPredictedStateT();
_infPredictedStateT.read(infPredictedStateT);
auto infPredictedStateT1 = proto.getInfPredictedStateT1();
_infPredictedStateT1.read(infPredictedStateT1);

allocateState(_cellConfidenceT, _nCells);
allocateState(_cellConfidenceT1, _nCells);
allocateState(_colConfidenceT, _nColumns);
allocateState(_colConfidenceT1, _nColumns);
auto cellConfidenceT = proto.getCellConfidenceT();
for (UInt i = 0; i < cellConfidenceT.size(); i++) {
_cellConfidenceT[i] = cellConfidenceT[i];
}
auto cellConfidenceT1 = proto.getCellConfidenceT1();
for (UInt i = 0; i < cellConfidenceT1.size(); i++) {
_cellConfidenceT1[i] = cellConfidenceT1[i];
}
auto colConfidenceT = proto.getColConfidenceT();
for (UInt i = 0; i < colConfidenceT.size(); i++) {
_colConfidenceT[i] = colConfidenceT[i];
}
auto colConfidenceT1 = proto.getColConfidenceT1();
for (UInt i = 0; i < colConfidenceT1.size(); i++) {
_colConfidenceT1[i] = colConfidenceT1[i];
}
}

_prevInfPatterns.clear();
auto prevInfPatterns = proto.getPrevInfPatterns();
for (UInt i = 0; i < prevInfPatterns.size(); i++) {
_prevInfPatterns.emplace_back(prevInfPatterns[i].size());
for (UInt j = 0; j < prevInfPatterns[i].size(); j++) {
_prevInfPatterns[i][j] = prevInfPatterns[i][j];
}
}

_prevLrnPatterns.clear();
auto prevLrnPatterns = proto.getPrevLrnPatterns();
for (UInt i = 0; i < prevLrnPatterns.size(); i++) {
_prevLrnPatterns.emplace_back(prevLrnPatterns[i].size());
for (UInt j = 0; j < prevLrnPatterns[i].size(); j++) {
_prevLrnPatterns[i][j] = prevLrnPatterns[i][j];
}
}
auto segmentUpdatesListProto = proto.getSegmentUpdates();
_segmentUpdates.clear();
_segmentUpdates.resize(segmentUpdatesListProto.size());
Expand All @@ -2006,7 +2115,7 @@ void Cells4::save(std::ostream &outStream) const {
if (_checkSynapseConsistency || (_nCells * _maxSegmentsPerCell < 100000)) {
NTA_CHECK(invariants(true));
}

outStream.precision(std::numeric_limits<double>::digits10);
outStream << version() << " " << _ownsMemory << " " << _rng << " "
<< _nColumns << " " << _nCellsPerCol << " " << _activationThreshold
<< " " << _minThreshold << " " << _newSynapseCount << " "
Expand Down Expand Up @@ -2036,10 +2145,37 @@ void Cells4::save(std::ostream &outStream) const {
NTA_CHECK(_nCells == _cells.size());
for (UInt i = 0; i != _nCells; ++i) {
_cells[i].save(outStream);
}

if (_ownsMemory) {
outStream << _infActiveStateT << " " << _infActiveStateT1 << " "
<< _infPredictedStateT << " " << _infPredictedStateT1 << " "
<< std::endl;
for (UInt i = 0; i != _nCells; ++i) {
outStream << _cellConfidenceT[i] << " " << _cellConfidenceT1[i] << " ";
}
outStream << std::endl;

for (UInt i = 0; i != _nColumns; ++i) {
outStream << _colConfidenceT[i] << " " << _colConfidenceT1[i] << " ";
}
outStream << std::endl;
}

outStream << " out ";
outStream << _prevLrnPatterns.size();
for (auto &elem : _prevLrnPatterns) {
outStream << std::endl << elem.size() << " ";
std::copy(elem.begin(), elem.end(),
std::ostream_iterator<UInt>(outStream, " "));
}
outStream << std::endl << _prevInfPatterns.size();
for (auto &elem : _prevInfPatterns) {
outStream << std::endl << elem.size() << " ";
std::copy(elem.begin(), elem.end(),
std::ostream_iterator<UInt>(outStream, " "));
}

outStream << std::endl << "out" << std::endl;
}

//----------------------------------------------------------------------------
Expand Down Expand Up @@ -2121,8 +2257,8 @@ void Cells4::load(std::istream &inStream) {
_learnPredictedStateT1.load(inStream);
}

UInt n;
if (v >= 2) {
UInt n;
_segmentUpdates.clear();
inStream >> n;
for (UInt i = 0; i < n; ++i) {
Expand All @@ -2135,6 +2271,43 @@ void Cells4::load(std::istream &inStream) {
_cells[i].load(inStream);
}

if (_ownsMemory) {
_infActiveStateT.load(inStream);
_infActiveStateT1.load(inStream);
_infPredictedStateT.load(inStream);
_infPredictedStateT1.load(inStream);

allocateState(_cellConfidenceT, _nCells);
allocateState(_cellConfidenceT1, _nCells);
allocateState(_colConfidenceT, _nColumns);
allocateState(_colConfidenceT1, _nColumns);
for (UInt i = 0; i != _nCells; ++i) {
inStream >> _cellConfidenceT[i] >> _cellConfidenceT1[i];
}
for (UInt i = 0; i != _nColumns; ++i) {
inStream >> _colConfidenceT[i] >> _colConfidenceT1[i];
}
}
inStream >> n;
std::vector<UInt> pattern;
size_t size;
_prevLrnPatterns.clear();
for (UInt i = 0; i < n; i++) {
pattern.clear();
inStream >> size;
std::copy_n(std::istream_iterator<UInt>(inStream), size,
std::back_inserter(pattern));
_prevLrnPatterns.push_back(pattern);
}
inStream >> n;
_prevInfPatterns.clear();
for (UInt i = 0; i < n; i++) {
pattern.clear();
inStream >> size;
std::copy_n(std::istream_iterator<UInt>(inStream), size,
std::back_inserter(pattern));
_prevInfPatterns.push_back(pattern);
}
std::string marker;
inStream >> marker;
NTA_CHECK(marker == "out");
Expand Down Expand Up @@ -2282,15 +2455,6 @@ void Cells4::updateSegment(
_segmentUpdates.push_back(update);
}

//--------------------------------------------------------------------------------
/**
* Simple helper function for allocating our numerous state variables
*/
template <typename It> void allocateState(It *&state, const UInt numElmts) {
state = new It[numElmts];
memset(state, 0, numElmts * sizeof(It));
}

void Cells4::setCellSegmentOrder(bool matchPythonOrder) {
Cell::setSegmentOrder(matchPythonOrder);
}
Expand Down Expand Up @@ -2850,6 +3014,45 @@ bool Cells4::operator==(const Cells4 &other) const {
if (_learnPredictedStateT1 != other._learnPredictedStateT1) {
return false;
}
if (_infActiveStateT != other._infActiveStateT) {
return false;
}
if (_infActiveStateT1 != other._infActiveStateT1) {
return false;
}
if (_infPredictedStateT != other._infPredictedStateT) {
return false;
}
if (_infPredictedStateT1 != other._infPredictedStateT1) {
return false;
}
if (_prevInfPatterns != other._prevInfPatterns) {
return false;
}
if (_prevLrnPatterns != other._prevLrnPatterns) {
return false;
}
if (_nCells > 0) {
if (memcmp(_cellConfidenceT, other._cellConfidenceT,
_nCells * sizeof(Real)) != 0) {
return false;
}
if (memcmp(_cellConfidenceT1, other._cellConfidenceT1,
_nCells * sizeof(Real)) != 0) {
return false;
}
}
if (_nColumns > 0) {
if (memcmp(_colConfidenceT, other._colConfidenceT,
_nColumns * sizeof(Real)) != 0) {
return false;
}
if (memcmp(_colConfidenceT1, other._colConfidenceT1,
_nColumns * sizeof(Real)) != 0) {
return false;
}
}

return true;
}

Expand Down
1 change: 1 addition & 0 deletions src/nupic/algorithms/Segment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ class CStateIndexed : public CState {
UInt nCellsOn;
inStream >> nCellsOn;
UInt v;
_cellsOn.clear();
for (UInt i = 0; i < nCellsOn; ++i) {
inStream >> v;
_cellsOn.push_back(v);
Expand Down
15 changes: 14 additions & 1 deletion src/nupic/proto/Cells4.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using import "/nupic/proto/RandomProto.capnp".RandomProto;
using import "/nupic/proto/Segment.capnp".CStateProto;
using import "/nupic/proto/SegmentUpdate.capnp".SegmentUpdateProto;

# Next ID: 39
# Next ID: 49
struct Cells4Proto {
version @0 :UInt16;
ownsMemory @1 :Bool;
Expand Down Expand Up @@ -50,6 +50,19 @@ struct Cells4Proto {
learnPredictedStateT @35 :CStateProto;
learnPredictedStateT1 @36 :CStateProto;

infActiveStateT @39 :CStateProto;
infActiveStateT1 @40 :CStateProto;
infPredictedStateT @41 :CStateProto;
infPredictedStateT1 @42 :CStateProto;

cellConfidenceT @43 :List(Float32);
cellConfidenceT1 @44 :List(Float32);
colConfidenceT @45 :List(Float32);
colConfidenceT1 @46 :List(Float32);

prevInfPatterns @47 :List(List(UInt32));
prevLrnPatterns @48 :List(List(UInt32));

cells @37 :List(CellProto);
segmentUpdates @38 :List(SegmentUpdateProto);
}
Loading

0 comments on commit 203493e

Please sign in to comment.