Skip to content

Commit

Permalink
Reduce the perf issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
thorstenhater committed Jan 10, 2024
1 parent 6740990 commit bd7fb76
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 54 deletions.
1 change: 0 additions & 1 deletion lif.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def probes(self, gid):
sim = A.simulation(rec)
sim.record(A.spike_recording.all)
sim.progress_banner()
sim.set_binning_policy(A.binning.regular, dt)
hdl = sim.sample((0, 0), A.regular_schedule(dt)) # gid, off

t0 = pc()
Expand Down
95 changes: 45 additions & 50 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import numpy as np
import pandas as pd
import numpy.random as rd
from scipy.stats import truncnorm
import arbor as A
from arbor import units as U
from time import perf_counter as pc
from logging import warning
import seaborn as sns
import matplotlib.pyplot as plt


dt = 0.05 # ms
T = 100 # ms
dt = 0.05 * U.ms
T = 100 * U.ms
VERBOSE = False


Expand Down Expand Up @@ -44,16 +45,21 @@ def make_hh(gid):
soma = "(tag 1)"
decor = (
A.decor()
.set_property(Vm=-65)
.set_property(Vm=-65 * U.mV)
.paint(soma, A.density("hh"))
.place(center, A.threshold_detector(-50), "source")
.place(center, A.threshold_detector(-50 * U.mV), "source")
.place(center, A.synapse("expsyn", {"tau": 0.5, "e": 0}), "synapse")
)
return A.cable_cell(tree, decor)


def make_spike_source(gid=0, *, tstart=0, tend=15, f=0.15): # ms, ms, kHz
return A.spike_source_cell("source", A.poisson_schedule(tstart=tstart, freq=f, tstop=tend, seed=gid))
return A.spike_source_cell(
"source",
A.poisson_schedule(
tstart=tstart * U.ms, freq=f * U.kHz, tstop=tend * U.ms, seed=gid
),
)


make_l23e = make_hh
Expand Down Expand Up @@ -103,7 +109,15 @@ def make_spike_source(gid=0, *, tstart=0, tend=15, f=0.15): # ms, ms, kHz
] + ["th"]


class recipe(A.recipe):
# NOTE: We're using the truncated normal distribution here to avoid
# negative delays
def delay(mu, sigma, n):
return truncnorm(
(dt.value - mu) / sigma, (T.value - mu) / sigma, loc=mu, scale=sigma
).rvs(n)


class ucircuit(A.recipe):
def __init__(
self,
*,
Expand Down Expand Up @@ -149,17 +163,7 @@ def __init__(
0.0512,
],
[0.0364, 0.001, 0.0034, 0.0005, 0.0277, 0.008, 0.0658, 0.1443, 0.0196],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]
)
# Scale weights for HH ./. LIF
Expand All @@ -184,7 +188,7 @@ def __init__(
self.weight_background = 585.39
# Thalamic inputs
self.f_thalamic = 15e-3
self.weight_thalamic = 0 #585.39
self.weight_thalamic = 0 # 585.39
self.delay_thalamic = 1.5
# Record synapse counts for reporting. We'd expect p_s_t*n_s*n_t on
# average for source and target populations.
Expand All @@ -196,29 +200,21 @@ def __init__(
# used.
self.connections = defaultdict(lambda: np.zeros_like(POPS))

def make_connection(self, src, tgt):
def make_connection_parameters(self, src, tgt, n):
# NOTE: The mean weight of the connection from L4E to L23E is doubled
if src == ITH:
w = self.weight_thalamic * self.weight_scale
d = self.delay_thalamic
w = np.ones(n) * self.weight_thalamic * self.weight_scale
d = np.ones(n) * self.delay_thalamic
elif src == I4E and tgt == I23E:
w = rd.normal(2 * self.mean_weight_exc, self.stddev_weight_exc)
d = rd.normal(self.mean_delay_exc, self.stddev_delay_exc)
w = rd.normal(2 * self.mean_weight_exc, self.stddev_weight_exc, n)
d = delay(self.mean_delay_exc, self.stddev_delay_exc, n)
elif src % 2 == 0: # NOTE: all the excitatory ones are even.
w = rd.normal(self.mean_weight_exc, self.stddev_weight_exc)
d = rd.normal(self.mean_delay_exc, self.stddev_delay_exc)
w = rd.normal(self.mean_weight_exc, self.stddev_weight_exc, n)
d = delay(self.mean_delay_exc, self.stddev_delay_exc, n)
else:
w = rd.normal(self.mean_weight_inh, self.stddev_weight_inh)
d = rd.normal(self.mean_delay_inh, self.stddev_delay_inh)
# NOTE: There's a bug on clang (at least on MacOS) that results in
# broken simulations if d < dt, so fix it here. Usually Arbor should do
# this on its own.
if d < dt:
d = dt
warning(
f"Connection {src} -> {tgt} has delay less than dt={dt}, using dt instead."
)
return w, d
w = rd.normal(self.mean_weight_inh, self.stddev_weight_inh, n)
d = delay(self.mean_delay_inh, self.stddev_delay_inh, n)
return w, d * U.ms

def gid_to_pop(self, gid):
# return the first IDX where our GID is less than POP[IDX+1]
Expand Down Expand Up @@ -254,21 +250,21 @@ def connections_on(self, tgt):
tgt_pop = self.gid_to_pop(tgt)
# Scan all Population types
for src_pop in POPS:
n = 0
p = self.connection_probability[tgt_pop][src_pop]
n_src = self.size[src_pop]
# Generate list of connection srcs
srcs = np.argwhere(rd.random(n_src) < p)
ws, ds = self.make_connection_parameters(src_pop, tgt_pop, srcs.size)
# Now reify all those into connection objects
# NOTE: We are simply skipping self connections here, but maybe
# we need to re-draw those?
for src in srcs:
if src == tgt:
continue
w, d = self.make_connection(src_pop, tgt_pop)
res.append(A.connection((src, "source"), "synapse", w, d))
n += 1
self.connections[src_pop][tgt_pop] += n
old = len(res)
res += [
A.connection((src, "source"), "synapse", w, d)
for (src, w, d) in zip(srcs, ws, ds)
if src != tgt
]
self.connections[src_pop][tgt_pop] += len(res) - old
return res

def event_generators(self, gid):
Expand All @@ -282,26 +278,25 @@ def event_generators(self, gid):
A.event_generator(
"synapse",
self.weight_background * self.weight_scale,
A.poisson_schedule(tstart=0.0, freq=f, seed=gid),
A.poisson_schedule(tstart=0.0 * U.ms, freq=f * U.kHz, seed=gid),
)
]


rec = recipe(
rec = ucircuit(
l23=(20683, 5834), # exc, inh
l4=(21915, 5479),
l5=(4850, 1065),
l6=(14395, 2948),
nth=902,
scale=0.1,
scale=0.01,
w_scale=5e-6,
)

ctx = A.context(threads=8)
sim = A.simulation(rec, ctx)
sim.record(A.spike_recording.all)
sim.progress_banner()
sim.set_binning_policy(A.binning.regular, dt)

banner()
print(f"Set up the simulation, total cells N={rec.N}")
Expand Down Expand Up @@ -354,6 +349,6 @@ def event_generators(self, gid):
ax.set_xlabel("Time $(t/ms)$")
ax.set_ylabel("GID")
ax.set_ylim(0, rec.N)
ax.set_xlim(0, T)
ax.set_xlim(0, T.value)
fg.savefig("main-spikes.pdf")
fg.savefig("main-spikes.png")
3 changes: 0 additions & 3 deletions single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@


def make_hh():
# TODO figure out HH parameters
# TODO figure out cell geometry
tree = A.segment_tree()
tree.append(A.mnpos, A.mpoint(-3, 0, 0, 3), A.mpoint(3, 0, 0, 3), tag=1)
center = "(location 0 0.5)"
Expand Down Expand Up @@ -71,7 +69,6 @@ def probes(self, gid):
sim = A.simulation(rec)
sim.record(A.spike_recording.all)
sim.progress_banner()
sim.set_binning_policy(A.binning.regular, dt)
hdl = sim.sample((0, 0), A.regular_schedule(dt)) # gid, off

t0 = pc()
Expand Down

0 comments on commit bd7fb76

Please sign in to comment.