diff --git a/nengo_loihi/loihi_cx.py b/nengo_loihi/loihi_cx.py index 920fae10b..8da878e87 100644 --- a/nengo_loihi/loihi_cx.py +++ b/nengo_loihi/loihi_cx.py @@ -365,6 +365,8 @@ def discretize(target, value): # discretize tracing mag into integer and fractional components mag_int, mag_frac = tracing_mag_int_frac(synapse.tracing_mag) if mag_int > 127: + warnings.warn("Trace increment exceeds upper limit " + "(learning rate may be too large)") mag_int = 127 mag_frac = 127 synapse.tracing_mag = mag_int + mag_frac / 128. diff --git a/nengo_loihi/tests/test_learning.py b/nengo_loihi/tests/test_learning.py index 09360144f..39745e1ca 100644 --- a/nengo_loihi/tests/test_learning.py +++ b/nengo_loihi/tests/test_learning.py @@ -1,16 +1,29 @@ import nengo +from nengo.exceptions import ValidationError, SimulationError +from nengo.utils.numpy import rms import numpy as np import pytest -from nengo.exceptions import ValidationError +import nengo_loihi.builder -@pytest.mark.parametrize('n_per_dim', [120, 200]) -@pytest.mark.parametrize('dims', [1, 3]) -def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims): - scale = np.linspace(1, 0, dims + 1)[:-1] - input_fn = lambda t: np.sin(t * 2 * np.pi) * scale - tau = 0.01 +def pes_network( + n_per_dim, + dims, + seed, + learning_rule_type=nengo.PES(learning_rate=1e-3), + input_scale=None, + error_scale=1., + learn_synapse=0.005, + probe_synapse=0.02, +): + if input_scale is None: + input_scale = np.linspace(1, 0, dims + 1)[:-1] + assert input_scale.size == dims + + input_fn = lambda t: np.sin(t * 2 * np.pi) * input_scale + + probes = {} with nengo.Network(seed=seed) as model: stim = nengo.Node(input_fn) @@ -21,15 +34,24 @@ def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims): conn = nengo.Connection( pre, post, function=lambda x: np.zeros(dims), - synapse=tau, - learning_rule_type=nengo.PES(learning_rate=1e-3)) + synapse=learn_synapse, + learning_rule_type=learning_rule_type) - nengo.Connection(post, conn.learning_rule) - nengo.Connection(stim, conn.learning_rule, transform=-1) + nengo.Connection(post, conn.learning_rule, transform=error_scale) + nengo.Connection(stim, conn.learning_rule, transform=-error_scale) + + probes['stim'] = nengo.Probe(stim, synapse=probe_synapse) + probes['pre'] = nengo.Probe(pre, synapse=probe_synapse) + probes['post'] = nengo.Probe(post, synapse=probe_synapse) - p_stim = nengo.Probe(stim, synapse=0.02) - p_pre = nengo.Probe(pre, synapse=0.02) - p_post = nengo.Probe(post, synapse=0.02) + return model, probes + + +@pytest.mark.parametrize('n_per_dim', [120, 200]) +@pytest.mark.parametrize('dims', [1, 3]) +def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims): + tau = 0.01 + model, probes = pes_network(n_per_dim, dims, seed, learn_synapse=tau) simtime = 5.0 with nengo.Simulator(model) as nengo_sim: @@ -38,32 +60,125 @@ def test_pes_comm_channel(allclose, plt, seed, Simulator, n_per_dim, dims): with Simulator(model) as loihi_sim: loihi_sim.run(simtime) + with Simulator(model, target='simreal') as real_sim: + real_sim.run(simtime) + t = nengo_sim.trange() pre_tmask = t > 0.1 post_tmask = t > simtime - 1.0 inter_tau = loihi_sim.model.inter_tau - y = nengo_sim.data[p_stim] + y = nengo_sim.data[probes['stim']] y_dpre = nengo.Lowpass(inter_tau).filt(y) y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre) - y_nengo = nengo_sim.data[p_post] - y_loihi = loihi_sim.data[p_post] + y_nengo = nengo_sim.data[probes['post']] + y_loihi = loihi_sim.data[probes['post']] + y_real = real_sim.data[probes['post']] plt.subplot(211) plt.plot(t, y_dpost, 'k', label='target') plt.plot(t, y_nengo, 'b', label='nengo') plt.plot(t, y_loihi, 'g', label='loihi') + plt.plot(t, y_real, 'r:', label='real') + plt.legend() plt.subplot(212) plt.plot(t[post_tmask], y_loihi[post_tmask] - y_dpost[post_tmask], 'k') plt.plot(t[post_tmask], y_loihi[post_tmask] - y_nengo[post_tmask], 'b') - assert allclose(loihi_sim.data[p_pre][pre_tmask], y_dpre[pre_tmask], + x_loihi = loihi_sim.data[probes['pre']] + assert allclose(x_loihi[pre_tmask], y_dpre[pre_tmask], atol=0.1, rtol=0.05) + assert allclose(y_loihi[post_tmask], y_dpost[post_tmask], atol=0.1, rtol=0.05) assert allclose(y_loihi, y_nengo, atol=0.2, rtol=0.2) + assert allclose(y_real[post_tmask], y_dpost[post_tmask], + atol=0.1, rtol=0.05) + assert allclose(y_real, y_nengo, atol=0.2, rtol=0.2) + + +def test_pes_overflow(allclose, plt, seed, Simulator): + dims = 3 + n_per_dim = 120 + tau = 0.01 + model, probes = pes_network(n_per_dim, dims, seed, learn_synapse=tau, + input_scale=np.linspace(1, 0.7, dims)) + + simtime = 3.0 + loihi_model = nengo_loihi.builder.Model() + # set learning_wgt_exp low to create overflow in weight values + loihi_model.pes_wgt_exp = -1 + + with Simulator(model, model=loihi_model) as loihi_sim: + loihi_sim.run(simtime) + + t = loihi_sim.trange() + post_tmask = t > simtime - 1.0 + + inter_tau = loihi_sim.model.inter_tau + y = loihi_sim.data[probes['stim']] + y_dpre = nengo.Lowpass(inter_tau).filt(y) + y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre) + y_loihi = loihi_sim.data[probes['post']] + + plt.plot(t, y_dpost, 'k', label='target') + plt.plot(t, y_loihi, 'g', label='loihi') + + # --- fit output to scaled version of target output + z_ref0 = y_dpost[post_tmask][:, 0] + z_loihi = y_loihi[post_tmask] + scale = np.linspace(0, 1, 50) + E = np.abs(z_loihi - scale[:, None, None]*z_ref0[:, None]) + errors = E.mean(axis=1) # average over time (errors is: scales x dims) + for j in range(dims): + errors_j = errors[:, j] + i = np.argmin(errors_j) + assert errors_j[i] < 0.1, ("Learning output for dim %d did not match " + "any scaled version of the target output" + % j) + assert scale[i] > 0.4, "Learning output for dim %d is too small" % j + assert scale[i] < 0.7, ("Learning output for dim %d is too large " + "(weights or traces not clipping as expected)" + % j) + + +def test_pes_error_clip(allclose, plt, seed, Simulator): + dims = 2 + n_per_dim = 120 + tau = 0.01 + error_scale = 5. # scale up error signal so it clips + model, probes = pes_network( + n_per_dim, dims, seed, learn_synapse=tau, + learning_rule_type=nengo.PES(learning_rate=1e-3 / error_scale), + input_scale=np.array([1., -1.]), + error_scale=error_scale) + + simtime = 3.0 + with pytest.warns(UserWarning, match=r'PES error.*clipping'): + with Simulator(model) as loihi_sim: + loihi_sim.run(simtime) + + t = loihi_sim.trange() + post_tmask = t > simtime - 1.0 + + inter_tau = loihi_sim.model.inter_tau + y = loihi_sim.data[probes['stim']] + y_dpre = nengo.Lowpass(inter_tau).filt(y) + y_dpost = nengo.Lowpass(tau).combine(nengo.Lowpass(inter_tau)).filt(y_dpre) + y_loihi = loihi_sim.data[probes['post']] + + plt.plot(t, y_dpost, 'k', label='target') + plt.plot(t, y_loihi, 'g', label='loihi') + + # --- assert that we've learned something, but not everything + error = (rms(y_loihi[post_tmask] - y_dpost[post_tmask]) + / rms(y_dpost[post_tmask])) + assert error < 0.5 + assert error > 0.05 + # ^ error on emulator vs chip is quite different, hence large tolerances + @pytest.mark.parametrize('init_function', [None, lambda x: 0]) def test_multiple_pes(init_function, allclose, plt, seed, Simulator): @@ -114,3 +229,31 @@ def test_pes_pre_synapse_type_error(Simulator): with pytest.raises(ValidationError): with Simulator(model): pass + + +def test_pes_trace_increment_clip_warning(seed, Simulator): + dims = 2 + n_per_dim = 120 + model, _ = pes_network( + n_per_dim, dims, seed, + learning_rule_type=nengo.PES(learning_rate=1e-1)) + + with pytest.warns(UserWarning, match="Trace increment exceeds upper"): + with Simulator(model): + pass + + +def test_drop_trace_spikes(Simulator, seed): + with nengo.Network(seed=seed) as net: + a = nengo.Ensemble(10, 1, gain=nengo.dists.Choice([1]), + bias=nengo.dists.Choice([2000]), + neuron_type=nengo.SpikingRectifiedLinear()) + b = nengo.Node(size_in=1) + + conn = nengo.Connection(a, b, learning_rule_type=nengo.PES(1)) + + nengo.Connection(b, conn.learning_rule) + + with Simulator(net, target="sim") as sim: + with pytest.raises(SimulationError): + sim.run(1.0) diff --git a/nengo_loihi/tests/test_loihi_api.py b/nengo_loihi/tests/test_loihi_api.py index 86065cc05..712f16126 100644 --- a/nengo_loihi/tests/test_loihi_api.py +++ b/nengo_loihi/tests/test_loihi_api.py @@ -2,8 +2,8 @@ import numpy as np import pytest -from nengo_loihi.loihi_api import overflow_signed -from nengo_loihi.loihi_api import decay_int, decay_magnitude +from nengo_loihi.loihi_api import ( + overflow_signed, decay_int, decay_magnitude, SynapseFmt) @pytest.mark.parametrize("b", (8, 16, 17, 23)) @@ -87,3 +87,17 @@ def empirical_decay_magnitude(decay, x0): plt.plot(relative_diff.clip(0, None)) assert np.all(relative_diff < 1e-6) + + +@pytest.mark.parametrize("lossy_shift", (True, False)) +def test_lossy_shift(lossy_shift, rng): + wgt_bits = 6 + w = rng.uniform(-100, 100, size=(10, 10)) + fmt = SynapseFmt(wgtBits=wgt_bits, wgtExp=0, fanoutType=0) + + w2 = fmt.discretize_weights(w, lossy_shift=lossy_shift) + + clipped = np.round(w / 4).clip(-2 ** wgt_bits, 2 ** wgt_bits).astype( + np.int32) + + assert np.allclose(w2, np.left_shift(clipped, 8))