diff --git a/nengo_loihi/tests/test_learning.py b/nengo_loihi/tests/test_learning.py index 88cf0afc7..046a8e549 100644 --- a/nengo_loihi/tests/test_learning.py +++ b/nengo_loihi/tests/test_learning.py @@ -56,3 +56,37 @@ def test_pes_comm_channel(allclose, Simulator, seed, plt, N, D): sim.data[p_a][t > 0.1], sim.data[p_stim][t > 0.1], atol=0.2, rtol=0.2) assert errors.min() < 0.3, "Not able to fit correctly" assert m_best > (0.3 if N/D < 150 else 0.6) + +def test_multiple_pes(allclose, Simulator, seed): + n_errors = 5 + targets = np.linspace(-1, 1, n_errors) + with nengo.Network(seed=seed) as model: + a = nengo.networks.EnsembleArray(200, n_ensembles=n_errors) + errors = nengo.Node(None, size_in=n_errors) + output = nengo.Node(None, size_in=n_errors) + + + target = nengo.Node(targets) + nengo.Connection(target, errors, transform=-1) + nengo.Connection(output, errors) + + for i in range(n_errors): + c = nengo.Connection( + a.ea_ensembles[i], + output[i], + function=lambda x: 0, + learning_rule_type=nengo.PES(learning_rate=1e-3), + ) + nengo.Connection(errors[i], c.learning_rule) + + p = nengo.Probe(output, synapse=0.1) + with Simulator(model, precompute=False) as sim: + sim.run(1.0) + + # TODO: these values should converge to "targets", but they don't + # don't seem to right now. I'm not sure why, but adjusting the + # learning rate and the run time does affect it. + print(sim.data[p][-100:]) + + +