Binary spike encoding #158

Draft · wants to merge 5 commits into base: main

32 changes: 30 additions & 2 deletions nengo_loihi/builder.py
@@ -16,6 +16,7 @@
import nengo.utils.numpy as npext

from nengo_loihi import conv
from nengo_loihi.encoder import BinaryEncoder
from nengo_loihi.loihi_cx import (
ChipReceiveNeurons,
ChipReceiveNode,
@@ -364,7 +365,7 @@ def build_ensemble(model, ens):
bias=bias)


def build_interencoders(model, ens):
def build_onoff_interencoders(model, ens):
"""Build encoders accepting on/off interneuron input."""
group = model.objs[ens.neurons]['in']
scaled_encoders = model.params[ens].scaled_encoders
@@ -376,6 +377,19 @@ def build_interencoders(model, ens):
group.add_synapses(synapses, name='inter_encoders')


def build_binary_interencoders(model, ens):
"""Build encoders accepting binary-coded input."""
group = model.objs[ens.neurons]['in']
scaled_encoders = model.params[ens].scaled_encoders

weights = BinaryEncoder().make_weights(
scaled_encoders * model.inter_scale)

synapses = CxSynapses(weights.shape[0], label="inter_encoders")
synapses.set_full_weights(weights)
group.add_synapses(synapses, name='inter_encoders')


@Builder.register(nengo.neurons.NeuronType)
def build_neurons(model, neurontype, neurons, group):
# If we haven't registered a builder for a specific type, then it cannot
@@ -549,6 +563,8 @@ def build_connection(model, conn):
assert transform.ndim == 2, "transform shape not handled yet"
assert transform.shape[1] == conn.pre.size_out

# TODO: the weights aren't used anywhere if post is an Ensemble
# in vector space? handled by build_interencoders instead.
assert transform.shape[1] == conn.pre.size_out
if isinstance(conn.pre_obj, ChipReceiveNeurons):
weights = transform / model.dt
@@ -725,9 +741,21 @@ def build_connection(model, conn):

if conn.learning_rule_type is not None:
raise NotImplementedError()
elif isinstance(conn.post_obj, Ensemble) and isinstance(conn.pre_obj, Node):
# TODO: this shouldn't be a special case
if 'inter_encoders' not in post_cx.named_synapses:
build_binary_interencoders(model, conn.post_obj)

mid_ax = CxAxons(mid_cx.n, label="encoders")
mid_ax.target = post_cx.named_synapses['inter_encoders']
mid_ax.set_axon_map(mid_axon_inds)
mid_cx.add_axons(mid_ax)
model.objs[conn]['mid_axons'] = mid_ax

post_cx.configure_filter(post_tau, dt=model.dt)
elif isinstance(conn.post_obj, Ensemble):
if 'inter_encoders' not in post_cx.named_synapses:
build_interencoders(model, conn.post_obj)
build_onoff_interencoders(model, conn.post_obj)

mid_ax = CxAxons(mid_cx.n, label="encoders")
mid_ax.target = post_cx.named_synapses['inter_encoders']
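Reviewer note: a quick way to see the synapse shape that `build_binary_interencoders` sets up. This is a minimal standalone sketch, assuming hypothetical ensemble sizes (`n_neurons` and `d` are made up, and the `model.inter_scale` factor is left out):

```python
import numpy as np

from nengo_loihi.encoder import BinaryEncoder

n_neurons, d, n_bits = 100, 3, 8

# Stand-in for model.params[ens].scaled_encoders.
scaled_encoders = np.random.uniform(-1, 1, size=(n_neurons, d))

weights = BinaryEncoder(n_bits).make_weights(scaled_encoders)

# One input axon per (sign, bit, dimension) triple, one column per neuron,
# so the CxSynapses block above expects 2 * n_bits * d = 48 input axons.
assert weights.shape == (2 * n_bits * d, n_neurons)
```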
39 changes: 39 additions & 0 deletions nengo_loihi/encoder.py
@@ -0,0 +1,39 @@
import numpy as np


class BinaryEncoder(object):
"""Node function for encoding a (-1, 1) vector in binary."""

def __init__(self, n_bits=8):
self.n_bits = n_bits

def get_size_out(self, d):
return 2 * self.n_bits * d

def __call__(self, dummy_time, x):
spiked = np.zeros((2, self.n_bits, len(x)), dtype=bool)
for i, x_i in enumerate(x):
sign_bit = 0 if x_i >= 0 else 1
f = np.abs(x_i)
v = 0.5 # to represent [0, 1)
for j in range(self.n_bits):
if f >= v:
f -= v
spiked[sign_bit, j, i] = True
v /= 2.
return spiked.flatten()

def make_weights(self, encoders):
weights = np.zeros(
(2, self.n_bits, encoders.shape[1], encoders.shape[0]))

for i, sign in enumerate([+1, -1]):
for bit in range(self.n_bits):
weights[i, bit, :, :] = sign / 2.**(1 + bit) * encoders.T

flat_weights = weights.reshape(-1, weights.shape[-1])
assert flat_weights.shape == (
self.get_size_out(encoders.shape[1]),
encoders.shape[0])

return flat_weights
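To make the encoding concrete, here is a standalone round-trip through `BinaryEncoder`: each value is split into a sign bit plus an `n_bits`-long truncated binary expansion (most significant bit first), and `make_weights` maps those bits back to the original value. A sketch only; the input values and the identity "encoders" are chosen purely for illustration:

```python
import numpy as np

from nengo_loihi.encoder import BinaryEncoder

encoder = BinaryEncoder(n_bits=8)
x = np.array([0.625, -0.3])  # values assumed to lie in (-1, 1)

# Encode: one boolean "spike" per (sign, bit, dimension) triple.
spikes = encoder(0.0, x)
assert spikes.shape == (encoder.get_size_out(len(x)),)  # 2 * 8 * 2 == 32

# Decode through identity encoders to map the bits back to vector space.
weights = encoder.make_weights(np.eye(len(x)))
decoded = spikes.astype(float) @ weights

# 0.625 == 1/2 + 1/8 is exactly representable; -0.3 is truncated to -0.296875.
assert np.allclose(decoded, [0.625, -0.296875])
assert np.all(np.abs(decoded - x) <= 2.0 ** -8)
```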
50 changes: 26 additions & 24 deletions nengo_loihi/splitter.py
@@ -7,6 +7,7 @@
import numpy as np

from nengo_loihi.conv import Conv2D
from nengo_loihi.encoder import BinaryEncoder
from nengo_loihi.loihi_cx import (
ChipReceiveNode, ChipReceiveNeurons, HostSendNode, HostReceiveNode,
PESModulatoryTarget)
@@ -279,24 +280,23 @@ def split_host_neurons_to_chip(networks, conn):

def split_host_to_chip(networks, conn):
dim = conn.size_out
encoder = BinaryEncoder()
size_enc = encoder.get_size_out(dim)

logger.debug("Creating ChipReceiveNode for %s", conn)
receive = ChipReceiveNode(
dim * 2, size_out=dim, add_to_container=False)
size_enc, size_out=dim, add_to_container=False)
networks.add(receive, "chip")
receive2post = nengo.Connection(receive, conn.post,
synapse=networks.inter_tau,
synapse=None,
add_to_container=False)
networks.add(receive2post, "chip")

logger.debug("Creating NIF ensemble for %s", conn)
ens = nengo.Ensemble(
2 * dim, dim,
neuron_type=NIF(tau_ref=0.0),
encoders=np.vstack([np.eye(dim), -np.eye(dim)]),
max_rates=np.ones(dim * 2) * networks.max_rate,
intercepts=np.ones(dim * 2) * -1,
add_to_container=False)
networks.add(ens, "host")
logger.debug("Creating spike-encoder for %s", conn)
spikes = nengo.Node(size_in=dim,
output=encoder,
add_to_container=False)
networks.add(spikes, "host")

if isinstance(conn.transform, Conv2D):
raise BuildError(
@@ -313,22 +313,24 @@ def split_host_to_chip(networks, conn):
rng=np.random.RandomState(seed=seed))
if isinstance(conn.post_obj, nengo.Ensemble):
transform = transform / conn.post_obj.radius
pre2ens = nengo.Connection(conn.pre, ens,
function=conn.function,
solver=conn.solver,
eval_points=conn.eval_points,
scale_eval_points=conn.scale_eval_points,
synapse=conn.synapse,
transform=transform,
add_to_container=False)
networks.add(pre2ens, "host")

pre2encoder = nengo.Connection(
conn.pre, spikes,
function=conn.function,
solver=conn.solver,
eval_points=conn.eval_points,
scale_eval_points=conn.scale_eval_points,
synapse=conn.synapse,
transform=transform,
add_to_container=False)
networks.add(pre2encoder, "host")

logger.debug("Creating HostSendNode for %s", conn)
send = HostSendNode(dim * 2, add_to_container=False)
send = HostSendNode(size_enc, add_to_container=False)
networks.add(send, "host")
ensneurons2send = nengo.Connection(
ens.neurons, send, synapse=None, add_to_container=False)
networks.add(ensneurons2send, "host")
spikes2send = nengo.Connection(
spikes, send, synapse=None, add_to_container=False)
networks.add(spikes2send, "host")
networks.remove(conn)

networks.host2chip_senders[send] = receive
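To put numbers on the splitter change: the old path sent spikes from a 2 * dim on/off NIF ensemble through `HostSendNode(dim * 2)`, while the new path sends the encoder node's output through `HostSendNode(size_enc)` with `size_enc = 2 * n_bits * dim`. A small sketch (`dim = 4` is made up for illustration):

```python
from nengo_loihi.encoder import BinaryEncoder

dim = 4                    # hypothetical conn.size_out
encoder = BinaryEncoder()  # default n_bits = 8, as in split_host_to_chip

size_enc = encoder.get_size_out(dim)

# Old host->chip width: 2 * dim = 8 on/off spike channels.
# New host->chip width: one channel per (sign, bit, dimension) = 64 channels.
assert size_enc == 2 * encoder.n_bits * dim == 64
```

The trade-off, as I read it: n_bits times more host-to-chip channels, in exchange for sending each value as an exact (up to quantization) bit pattern every timestep rather than through the on/off interneurons' spike rates. Note also that receive2post now uses synapse=None instead of networks.inter_tau.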
32 changes: 32 additions & 0 deletions nengo_loihi/tests/test_encoder.py
@@ -0,0 +1,32 @@
import pytest
import nengo
import numpy as np

from nengo_loihi.encoder import BinaryEncoder


@pytest.mark.parametrize("n_bits", [1, 2, 4, 8, 16, 32])
def test_binary_encoder(n_bits):
encoder = BinaryEncoder(n_bits)
weights = encoder.make_weights(np.ones((1, 1)))

with nengo.Network() as model:
u = nengo.Node(output=lambda t: np.sin(2*np.pi*t))
spikes = nengo.Node(size_in=1, output=encoder)
y = nengo.Node(size_in=1)

nengo.Connection(u, spikes, synapse=None)
nengo.Connection(spikes, y, transform=weights.T, synapse=None)

p_y = nengo.Probe(y, synapse=None)
p_ideal = nengo.Probe(u, synapse=None)

with nengo.Simulator(model) as sim:
sim.run(1.0)

# ensure that the input does not differ from its
# binary-encoded + weighted version by more than 2^(-n_bits)
# (i.e., all of the error is from quantization)
abstol = 2.**(-n_bits)
error = sim.data[p_y] - sim.data[p_ideal]
assert np.all(np.abs(error) <= abstol)
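The bound asserted here follows from the greedy bit extraction in `BinaryEncoder.__call__`: each step subtracts the largest remaining power of two, so after `n_bits` steps the residual of |x| <= 1 is at most 2**(-n_bits). The same check can be run without a Simulator; a standalone sketch (the input sweep is made up for illustration):

```python
import numpy as np

from nengo_loihi.encoder import BinaryEncoder

n_bits = 8
encoder = BinaryEncoder(n_bits)
weights = encoder.make_weights(np.ones((1, 1)))  # scalar encoder, as in the test

# Sweep [-1, 1] and decode each binary spike pattern directly.
xs = np.linspace(-1, 1, 1001)
decoded = np.array([(encoder(0.0, [x]).astype(float) @ weights)[0] for x in xs])

# All of the error comes from truncating the binary expansion after n_bits bits.
assert np.all(np.abs(decoded - xs) <= 2.0 ** -n_bits)
```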