Skip to content

Commit

Permalink
Merge pull request #1351 from scottpurdy/nup-2356
Browse files Browse the repository at this point in the history
Add Cells4 read/write SWIG wrapping and fix bug in Cells4 deserialization
  • Loading branch information
scottpurdy authored Jun 16, 2017
2 parents 2fa43ef + 35f0fed commit a82356f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 29 deletions.
29 changes: 15 additions & 14 deletions src/nupic/algorithms/Cells4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2000,25 +2000,26 @@ void Cells4::write(Cells4Proto::Builder& proto) const
void Cells4::read(Cells4Proto::Reader& proto)
{
NTA_CHECK(proto.getVersion() == 2);
_ownsMemory = proto.getOwnsMemory();

initialize(proto.getNColumns(),
proto.getNCellsPerCol(),
proto.getActivationThreshold(),
proto.getMinThreshold(),
proto.getNewSynapseCount(),
proto.getSegUpdateValidDuration(),
proto.getPermInitial(),
proto.getPermConnected(),
proto.getPermMax(),
proto.getPermDec(),
proto.getPermInc(),
proto.getGlobalDecay(),
proto.getDoPooling(),
proto.getOwnsMemory());
auto randomProto = proto.getRng();
_rng.read(randomProto);
_nColumns = proto.getNColumns();
_nCellsPerCol = proto.getNCellsPerCol();
_activationThreshold = proto.getActivationThreshold();
_minThreshold = proto.getMinThreshold();
_newSynapseCount = proto.getNewSynapseCount();
_nIterations = proto.getNIterations();
_nLrnIterations = proto.getNLrnIterations();
_segUpdateValidDuration = proto.getSegUpdateValidDuration();
_initSegFreq = proto.getInitSegFreq();
_permInitial = proto.getPermInitial();
_permConnected = proto.getPermConnected();
_permMax = proto.getPermMax();
_permDec = proto.getPermDec();
_permInc = proto.getPermInc();
_globalDecay = proto.getGlobalDecay();
_doPooling = proto.getDoPooling();
_pamLength = proto.getPamLength();
_maxInfBacktrack = proto.getMaxInfBacktrack();
_maxLrnBacktrack = proto.getMaxLrnBacktrack();
Expand Down
6 changes: 3 additions & 3 deletions src/nupic/algorithms/Cells4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,9 @@ namespace nupic {
UInt getMaxSeqLength() const { return _maxSeqLength;}
Real getAvgLearnedSeqLength() const { return _avgLearnedSeqLength;}
UInt getNLrnIterations() const { return _nLrnIterations;}
Int getmaxSegmentsPerCell() const { return _maxSegmentsPerCell;}
Int getMaxSynapsesPerCell() const { return _maxSynapsesPerSegment;}
bool getCheckSynapseConsistency() { return _checkSynapseConsistency;}
Int getMaxSegmentsPerCell() const { return _maxSegmentsPerCell;}
Int getMaxSynapsesPerSegment() const { return _maxSynapsesPerSegment;}
bool getCheckSynapseConsistency() const { return _checkSynapseConsistency;}


//----------------------------------------------------------------------
Expand Down
31 changes: 29 additions & 2 deletions src/nupic/bindings/algorithms.i
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ try:
except ImportError:
capnp = None
else:
from nupic.proto.SpatialPoolerProto_capnp import SpatialPoolerProto
from nupic.proto.Cells4_capnp import Cells4Proto
from nupic.proto.ClaClassifier_capnp import ClaClassifierProto
from nupic.proto.SdrClassifier_capnp import SdrClassifierProto
from nupic.proto.ConnectionsProto_capnp import ConnectionsProto
from nupic.proto.SdrClassifier_capnp import SdrClassifierProto
from nupic.proto.SpatialPoolerProto_capnp import SpatialPoolerProto
from nupic.proto.TemporalMemoryProto_capnp import TemporalMemoryProto


Expand Down Expand Up @@ -97,6 +98,7 @@ _ALGORITHMS = _algorithms
#include <nupic/algorithms/OutSynapse.hpp>
#include <nupic/algorithms/SegmentUpdate.hpp>

#include <nupic/proto/Cells4.capnp.h>
#include <nupic/proto/ConnectionsProto.capnp.h>
#include <nupic/proto/SpatialPoolerProto.capnp.h>
#include <nupic/proto/TemporalMemoryProto.capnp.h>
Expand Down Expand Up @@ -804,8 +806,33 @@ void forceRetentionOfImageSensorLiteLibrary(void) {
def __setstate__(self, inString):
self.this = _ALGORITHMS.new_Cells4()
self.loadFromString(inString)

@classmethod
def read(cls, proto):
instance = cls()
instance._initFromCapnpPyBytes(proto.as_builder().to_bytes()) # copy * 2
return instance

def write(self, pyBuilder):
"""Serialize the Cells4 instance using capnp.
:param: Destination Cells4Proto message builder
"""
reader = Cells4Proto.from_bytes(self._writeAsCapnpPyBytes()) # copy
pyBuilder.from_dict(reader.to_dict()) # copy

%}

inline PyObject* _writeAsCapnpPyBytes() const
{
return nupic::PyCapnpHelper::writeAsPyBytes(*self);
}

inline void _initFromCapnpPyBytes(PyObject* pyBytes)
{
nupic::PyCapnpHelper::initFromPyBytes(*self, pyBytes);
}

void loadFromString(const std::string& inString)
{
std::istringstream inStream(inString);
Expand Down
77 changes: 67 additions & 10 deletions src/test/unit/algorithms/Cells4Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,74 @@ std::vector<UInt> _getOrderedSynapseIndexesForSrcCells(const Segment& segment,
}


/**
* Simple comparison function that does the easy checks. It can be expanded to
* cover more of the attributes of Cells4 in the future.
*/
bool checkCells4Attributes(const Cells4& c1, const Cells4& c2)
{
if (c1.nSegments() != c2.nSegments() ||
c1.nCells() != c2.nCells() ||
c1.nColumns() != c2.nColumns() ||
c1.nCellsPerCol() != c2.nCellsPerCol() ||
c1.getMinThreshold() != c2.getMinThreshold() ||
c1.getPermConnected() != c2.getPermConnected() ||
c1.getVerbosity() != c2.getVerbosity() ||
c1.getMaxAge() != c2.getMaxAge() ||
c1.getPamLength() != c2.getPamLength() ||
c1.getMaxInfBacktrack() != c2.getMaxInfBacktrack() ||
c1.getMaxLrnBacktrack() != c2.getMaxLrnBacktrack() ||

c1.getPamCounter() != c2.getPamCounter() ||
c1.getMaxSeqLength() != c2.getMaxSeqLength() ||
c1.getAvgLearnedSeqLength() != c2.getAvgLearnedSeqLength() ||
c1.getNLrnIterations() != c2.getNLrnIterations() ||

c1.getMaxSegmentsPerCell() != c2.getMaxSegmentsPerCell() ||
c1.getMaxSynapsesPerSegment() != c2.getMaxSynapsesPerSegment() ||
c1.getCheckSynapseConsistency() != c2.getCheckSynapseConsistency())
{
return false;
}
return true;
}


TEST(Cells4Test, capnpSerialization)
{
Cells4 cells(
10, 2, 1, 1, 1, 1, 0.5, 0.8, 1, 0.1, 0.1, 0, false, -1, true, false);
std::vector<Real> input(10, 0.0);
input[1] = 1.0;
input[4] = 1.0;
input[5] = 1.0;
input[9] = 1.0;
std::vector<Real> input1(10, 0.0);
input1[1] = 1.0;
input1[4] = 1.0;
input1[5] = 1.0;
input1[9] = 1.0;
std::vector<Real> input2(10, 0.0);
input2[0] = 1.0;
input2[2] = 1.0;
input2[5] = 1.0;
input2[6] = 1.0;
std::vector<Real> input3(10, 0.0);
input3[1] = 1.0;
input3[3] = 1.0;
input3[6] = 1.0;
input3[7] = 1.0;
std::vector<Real> input4(10, 0.0);
input4[2] = 1.0;
input4[4] = 1.0;
input4[7] = 1.0;
input4[8] = 1.0;
std::vector<Real> output(10*2);
cells.compute(&input.front(), &output.front(), true, true);
for (UInt i = 0; i < 10; ++i)
{
cells.compute(&input1.front(), &output.front(), true, true);
cells.compute(&input2.front(), &output.front(), true, true);
cells.compute(&input3.front(), &output.front(), true, true);
cells.compute(&input4.front(), &output.front(), true, true);
cells.reset();
}

Cells4 secondCells(
10, 2, 1, 1, 1, 1, 0.5, 0.8, 1, 0.1, 0.1, 0, false, -1, true, false);
Cells4 secondCells;
{
capnp::MallocMessageBuilder message1;
Cells4Proto::Builder cells4Builder = message1.initRoot<Cells4Proto>();
Expand All @@ -110,13 +164,16 @@ TEST(Cells4Test, capnpSerialization)
secondCells.read(cells4Reader);
}

NTA_CHECK(checkCells4Attributes(cells, secondCells));

std::vector<Real> secondOutput(10*2);
cells.compute(&input.front(), &output.front(), true, true);
secondCells.compute(&input.front(), &secondOutput.front(), true, true);
cells.compute(&input1.front(), &output.front(), true, true);
secondCells.compute(&input1.front(), &secondOutput.front(), true, true);
for (UInt i = 0; i < 10; ++i)
{
ASSERT_EQ(output[i], secondOutput[i]) << "Outputs differ at index " << i;
}
NTA_CHECK(checkCells4Attributes(cells, secondCells));
}


Expand Down

0 comments on commit a82356f

Please sign in to comment.