Multidimensional learning now works!
The actual communication protocol is not as efficient as
it could be, and this commit only allows for one learning
connection at a time.
tcstewar committed Sep 26, 2018
1 parent 0811e2a commit 13c5084
Showing 6 changed files with 97 additions and 75 deletions.
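
For reference, a rough sketch (not part of this commit) of the per-timestep message now written to the nengo_io_h2c channel, as implied by the simulator.py and loihi_interface.py changes below. The helper name pack_h2c_message and the spikes/errors arguments are illustrative only, and the leading spike-count word is assumed from the count[1] buffer in nengo_io.c.template:

def pack_h2c_message(spikes, errors):
    # spikes: list of 2-int spike entries for this timestep
    # errors: list of (coreid, error_vector) pairs, one per learning connection
    msg = [len(spikes)]                      # assumed leading spike count
    for spike in spikes:
        msg.extend(spike)                    # 2 ints per spike
    for coreid, x in errors:
        vals = [int(v) for v in x]           # values already scaled by 100 and clipped
        msg.extend([coreid, len(vals)] + vals)
    return msg
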
9 changes: 8 additions & 1 deletion nengo_loihi/loihi_interface.py
@@ -514,9 +514,15 @@ def create_io_snip(self):
# --- generate custom code
# Determine which cores have learning
n_errors = 0
total_error_len = 0
max_error_len = 0
for core in self.board.chips[0].cores: # TODO: don't assume 1 chip
if core.learning_coreid:
error_len = core.groups[0].n // 2
if error_len > max_error_len:
max_error_len = error_len
n_errors += 1
total_error_len += 2 + error_len

n_outputs = 1
probes = []
@@ -542,6 +548,7 @@ def create_io_snip(self):
code = template.render(
n_outputs=n_outputs,
n_errors=n_errors,
max_error_len=max_error_len,
cores=cores,
probes=probes,
)
@@ -568,7 +575,7 @@ def create_io_snip(self):
phase="preLearnMgmt",
)

size = self.snip_max_spikes_per_step * 2 + 1 + n_errors*2
size = self.snip_max_spikes_per_step * 2 + 1 + total_error_len
logger.debug("Creating nengo_io_h2c channel")
self.nengo_io_h2c = self.n2board.createChannel(b'nengo_io_h2c',
"int", size)
6 changes: 3 additions & 3 deletions nengo_loihi/simulator.py
@@ -497,7 +497,8 @@ def handle_host2chip_communications(self):  # noqa: C901
for sender, receiver in self.host2chip_senders.items():
if isinstance(receiver, splitter.PESModulatoryTarget):
for t, x in sender.queue:
x = int(100 * x) # >128 is an issue on chip
x = (100*x).astype(int)
x = np.clip(x, -100, 100)
probe = receiver.target
conn = self.model.probe_conns[probe]
dec_cx = self.model.objs[conn]['decoded']
@@ -510,7 +511,7 @@ def handle_host2chip_communications(self):  # noqa: C901

assert coreid is not None

errors.append([coreid, x])
errors.append([coreid, len(x)]+x.tolist())
del sender.queue[:]

else:
@@ -542,7 +543,6 @@ def handle_host2chip_communications(self):  # noqa: C901
assert spike[0] == 0
msg.extend(spike[1:3])
for error in errors:
assert len(error) == 2
msg.extend(error)
self.loihi.nengo_io_h2c.write(len(msg), msg)

17 changes: 12 additions & 5 deletions nengo_loihi/snips/nengo_io.c.template
@@ -4,6 +4,7 @@

#define N_OUTPUTS {{ n_outputs }}
#define N_ERRORS {{ n_errors }}
#define MAX_ERROR_LEN {{ max_error_len }}

int guard_io(runState *s) {
return 1;
@@ -18,7 +19,9 @@ void nengo_io(runState *s) {
int outChannel = getChannelID("nengo_io_c2h");
int32_t count[1];
int32_t spike[2];
int32_t error[2];
int32_t error_info[2];
int32_t error_data[MAX_ERROR_LEN];
int32_t error_index = 0;
int32_t output[N_OUTPUTS];

if (inChannel == -1 || outChannel == -1) {
Expand All @@ -42,10 +45,14 @@ void nengo_io(runState *s) {

// Communicate with learning snip
for (int i=0; i < N_ERRORS; i++) {
readChannel(inChannel, error, 2);
// printf("send error %d.%d\n", error[0], error[1]);
s->userData[0] = error[0];
s->userData[1] = error[1];
readChannel(inChannel, error_info, 2);
readChannel(inChannel, error_data, error_info[1]);
s->userData[error_index] = error_info[0];
s->userData[error_index + 1] = error_info[1];
for (int j=0; j < error_info[1]; j++) {
s->userData[error_index + 2 + j] = error_data[j];
}
error_index += 2 + error_info[1];
}

output[0] = s->time;
117 changes: 62 additions & 55 deletions nengo_loihi/snips/nengo_learn.c
@@ -2,68 +2,75 @@
#include <string.h>
#include "nengo_learn.h"

#define N_ERRORS 1

int guard_learn(runState *s) {
return 1;
}

void nengo_learn(runState *s) {
int core = s->userData[0];
int error = (signed char) s->userData[1];
int offset = 0;
for (int error_index=0; error_index < N_ERRORS; error_index++) {
int core = s->userData[offset];
int n_vals = s->userData[offset+1];
for (int i=0; i < n_vals; i++) {
int error = (signed char) s->userData[offset+2+i];

NeuronCore *neuron;
neuron = NEURON_PTR((CoreId) {.id=core});
NeuronCore *neuron;
neuron = NEURON_PTR((CoreId) {.id=core});

int cx_idx = 0;
int cx_idx = i;

if (error > 0) {
neuron->stdp_post_state[cx_idx] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 1
};
neuron->stdp_post_state[cx_idx+1] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 0
};
} else {
neuron->stdp_post_state[cx_idx] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 0
};
neuron->stdp_post_state[cx_idx+1] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 1
};
if (error > 0) {
neuron->stdp_post_state[cx_idx] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 1
};
neuron->stdp_post_state[cx_idx+n_vals] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 0
};
} else {
neuron->stdp_post_state[cx_idx] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 0
};
neuron->stdp_post_state[cx_idx+n_vals] = \
(PostTraceEntry) {
.Yspike0 = 0,
.Yspike1 = 0,
.Yspike2 = 0,
.Yepoch0 = abs(error),
.Yepoch1 = 0,
.Yepoch2 = 0,
.Tspike = 0,
.TraceProfile = 3,
.StdpProfile = 1
};
}
}
}
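
For clarity, a rough Python model (not part of the commit) of how nengo_learn.c now walks the userData buffer that nengo_io.c.template fills in: one [coreid, n_vals, err_0 .. err_{n_vals-1}] record per learning connection, with N_ERRORS still hard-coded to 1 for now. The function name unpack_user_data is hypothetical:

def unpack_user_data(user_data, n_errors=1):
    # Records are laid out back to back: coreid, n_vals, then n_vals signed
    # error values for that learning connection.
    offset = 0
    records = []
    for _ in range(n_errors):
        coreid = user_data[offset]
        n_vals = user_data[offset + 1]
        errs = [int(v) for v in user_data[offset + 2:offset + 2 + n_vals]]
        records.append((coreid, errs))
        offset += 2 + n_vals
    return records
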
21 changes: 11 additions & 10 deletions nengo_loihi/tests/test_learning.py
@@ -3,24 +3,25 @@
import pytest


@pytest.mark.hang
@pytest.mark.parametrize('N', [100, 300])
def test_pes_comm_channel(allclose, Simulator, seed, plt, N):
input_fn = lambda t: np.sin(t*2*np.pi)
@pytest.mark.parametrize('N', [400, 600])
@pytest.mark.parametrize('D', [1, 3])
def test_pes_comm_channel(allclose, Simulator, seed, plt, N, D):
scale = np.linspace(1, 0, D+1)[:-1]
input_fn = lambda t: np.sin(t*2*np.pi)*scale

with nengo.Network(seed=seed) as model:
stim = nengo.Node(input_fn)

a = nengo.Ensemble(N, 1)
a = nengo.Ensemble(N, D)

b = nengo.Node(None, size_in=1, size_out=1)
b = nengo.Node(None, size_in=D)

nengo.Connection(stim, a, synapse=None)
conn = nengo.Connection(
a, b, function=lambda x: 0, synapse=0.01,
a, b, function=lambda x: [0]*D, synapse=0.01,
learning_rule_type=nengo.PES(learning_rate=1e-3))

error = nengo.Node(None, size_in=1)
error = nengo.Node(None, size_in=D)
nengo.Connection(b, error)
nengo.Connection(stim, error, transform=-1)
nengo.Connection(error, conn.learning_rule)
@@ -40,7 +41,7 @@ def test_pes_comm_channel(allclose, Simulator, seed, plt, N):

# --- fit input_fn to output, determine magnitude
# The larger the magnitude is, the closer the output is to the input
x = input_fn(t)[t > 4]
x = np.array([input_fn(tt)[0] for tt in t[t>4]])
y = sim.data[p_b][t > 4][:, 0]
m = np.linspace(0, 1, 21)
errors = np.abs(y - m[:, None]*x).mean(axis=1)
@@ -54,4 +55,4 @@ def test_pes_comm_channel(allclose, Simulator, seed, plt, N):
assert allclose(
sim.data[p_a][t > 0.1], sim.data[p_stim][t > 0.1], atol=0.2, rtol=0.2)
assert errors.min() < 0.3, "Not able to fit correctly"
assert m_best > (0.3 if N < 150 else 0.6)
assert m_best > (0.3 if N/D < 150 else 0.6)
2 changes: 1 addition & 1 deletion nengo_loihi/tests/test_splitter.py
@@ -4,7 +4,7 @@


@pytest.mark.parametrize("pre_dims", [1, 3])
@pytest.mark.parametrize("post_dims", [1])
@pytest.mark.parametrize("post_dims", [1, 3])
@pytest.mark.parametrize("learn", [True, False])
@pytest.mark.parametrize("use_solver", [True, False])
def test_manual_decoders(
