diff --git a/nengo_loihi/tests/test_connection.py b/nengo_loihi/tests/test_connection.py index 3316ecf04..428da9001 100644 --- a/nengo_loihi/tests/test_connection.py +++ b/nengo_loihi/tests/test_connection.py @@ -397,26 +397,27 @@ def test_input_node(allclose, Simulator, val, type): ) def test_ens2node(allclose, Simulator, seed, plt, pre_d, post_d, func): simtime = 0.5 + data = [] + + def conn_fn(x): + return -x + + def output_fn(t, x): + data.append(x.copy()) + with nengo.Network(seed=seed) as model: stim = nengo.Node(lambda t: [np.sin(t * 2 * np.pi / simtime)] * pre_d) a = nengo.Ensemble(100, pre_d) - nengo.Connection(stim, a) - data = [] - output = nengo.Node(lambda t, x: data.append(x), size_in=post_d, size_out=0) + output = nengo.Node(output_fn, size_in=post_d, size_out=0) transform = np.identity(max(pre_d, post_d)) transform = transform[:post_d, :pre_d] - if func: - - def conn_func(x): - return -x - - else: - conn_func = None - nengo.Connection(a, output, transform=transform, function=conn_func) + nengo.Connection( + a, output, transform=transform, function=conn_fn if func else None + ) p_stim = nengo.Probe(stim)