Skip to content

Commit

Permalink
Merge pull request #223 from jhnnsnk/corrcoef_fct
Browse files Browse the repository at this point in the history
Correlation function
  • Loading branch information
jhnnsnk authored Aug 15, 2024
2 parents d1a04e4 + 6a777e9 commit ed1baa4
Show file tree
Hide file tree
Showing 9 changed files with 497 additions and 21 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies:
- sphinxcontrib-bibtex
- sphinx-tabs
- sphinx-gallery
- joblib
- pip:
- nnmt
- parameters
Expand Down
4 changes: 2 additions & 2 deletions mesocircuit/mesocircuit_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, name_exp='base', custom_params=None, data_dir=None,
if not os.path.isdir(self.data_dir_exp):
print(f'Creating directory: {self.data_dir_exp}')
os.makedirs(self.data_dir_exp)

print(f'Data directory: {self.data_dir_exp}')

if not load:
Expand Down Expand Up @@ -595,7 +595,7 @@ def _write_jobscripts(self, paramset, path):
"LD_PRELOAD skipped because jemalloc is not in PATH.")

if name in ['lfp_simulation', 'lfp_postprocess', 'lfp_plotting']:
run_cmd = f'srun --mpi=pmi2'
run_cmd = f'srun'
else:
run_cmd = f'srun --cpus-per-task={dic["local_num_threads"]} --threads-per-core=1 --cpu-bind=rank'

Expand Down
2 changes: 1 addition & 1 deletion mesocircuit/parameterization/base_analysis_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
# if 'auto': the population size of the smallest population is taken.
# if the given number is higher than the smallest population size, the
# latter is also assumed.
'ccs_num_neurons': 200,
'ccs_num_neurons': 512,
# time interval for computing correlation coefficients (in ms).
# a good choice is equal to the refractory time.
# it can also be an iterable list of time intervals, e.g., [2., 50., 200.]
Expand Down
12 changes: 12 additions & 0 deletions mesocircuit/plotting/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def raster(circuit, all_sptrains, all_pos_sorting_arrays):
left = 0.17
right = 0.92

# restrict maximum time interval
diff_time_interval = time_interval[1] - time_interval[0]
max_time_interval = 10000 # maximum interval in ms
if diff_time_interval > max_time_interval:
time_interval[1] = time_interval[0] + max_time_interval

print(f'Plotting spike raster for interval: {time_interval} ms')

# automatically compute a sample step for this figure
Expand Down Expand Up @@ -256,6 +262,12 @@ def instantaneous_firing_rates(circuit, all_sptrains_bintime):
left = 0.17
right = 0.92

# restrict maximum time interval
diff_time_interval = time_interval[1] - time_interval[0]
max_time_interval = 10000 # maximum interval in ms
if diff_time_interval > max_time_interval:
time_interval[1] = time_interval[0] + max_time_interval

print(
f'Plotting instantaneous firing rates for interval: {time_interval} ms')

Expand Down
95 changes: 94 additions & 1 deletion mesocircuit/plotting/ms_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,14 @@ def reference_vs_upscaled(output_dir, ref_circuit, ups_circuit,
titles = ['reference model, ' + r'1 mm$^2$',
'upscaled model, ' + r'1 mm$^2$ sampled']

interval = ref_circuit.ana_dict['ccs_time_interval']

for i, prefix in enumerate(['ref', 'ups']):
all_CCs = {}
all_CCs_distances = d[prefix + '_all_CCs_distances']
for X in all_CCs_distances:
if isinstance(all_CCs_distances[X], h5py._hl.group.Group):
all_CCs[X] = all_CCs_distances[X]['ccs']
all_CCs[X] = all_CCs_distances[X][f'ccs_{interval}ms']
else:
all_CCs[X] = np.array([])

Expand All @@ -232,6 +234,97 @@ def reference_vs_upscaled(output_dir, ref_circuit, ups_circuit,
axes[4].set_title(titles[i], pad=15)

plt.savefig(os.path.join(output_dir, 'rev_vs_ups_statistics.pdf'))

return


def correlation(output_dir, circuit):
"""
Figure of correlation structure.
Parameters
----------
output_dir
Output directory.
circuit
Mesocircuit instance.
"""
# load data
d = {}
for all_datatype in ['all_CCs_distances', 'all_cross_correlation_functions']:
fn = os.path.join(
circuit.data_dir_circuit, 'processed_data', all_datatype + '.h5')
data = h5py.File(fn, 'r')
d.update({all_datatype: data})

# extract all_CCs from all_CCs_distances
ccs_time_intervals = np.array(
circuit.ana_dict['ccs_time_interval']).reshape(-1)
all_CCs = {}
for i, interval in enumerate(ccs_time_intervals):
all_CCs[i] = {}
for X in d['all_CCs_distances']:
if X != 'TC':
all_CCs[i][X] = d['all_CCs_distances'][X][f'ccs_{interval}ms']

# use the same bin width but different interval for CC distributions
# here and in statistics_overview
distr_max_cc = 0.04
distr_num_bins = int(circuit.plot_dict['distr_num_bins']
* distr_max_cc
/ circuit.plot_dict['distr_max_cc'])

# bins used in distribution in [0,1]
bins_unscaled = (np.arange(0, distr_num_bins + 1) / distr_num_bins)
bins = 2. * (bins_unscaled - 0.5) * distr_max_cc

fig = plt.figure(figsize=(circuit.plot_dict['fig_width_2col'], 3))
gs = gridspec.GridSpec(1, 2)
gs.update(left=0.04, right=0.96, bottom=0.09, top=0.93, wspace=0.3)

# distributions of correlation coefficients for different time lags
ax = plot.plot_population_panels_2cols(
gs[0, 0],
plotfunc=plot.plotfunc_distributions,
populations=circuit.net_dict['populations'][:-1],
layer_labels=circuit.plot_dict['layer_labels'],
data2d=all_CCs,
pop_colors=circuit.plot_dict['pop_colors'],
xlabel='$CC$',
ylabel='p (a.u.)',
bins=bins,
MaxNLocatorNBins=2)

plot.add_label(ax, 'A')

# legend
num = len(circuit.ana_dict['ccs_time_interval'])
legend_labels = np.array(
circuit.ana_dict['ccs_time_interval']).astype(int) # int

colors = [plot.adjust_lightness(
circuit.plot_dict['pop_colors'][0], 1-j/(num-1)) for j in np.arange(num)]

lines = [matplotlib.lines.Line2D([0], [0], color=c) for c in colors]

ax.legend(lines, legend_labels, title=r'$\Delta t_{CC}$ (ms)',
loc='center', bbox_to_anchor=(0.9, 0.75),
frameon=False,
fontsize=matplotlib.rcParams['font.size'] * 0.8)

#####

# spike train cross-correlation functions
ax = plot.plot_cross_correlation_functions(
gs[0, 1],
layer_labels=circuit.plot_dict['layer_labels'],
all_cross_correlation_functions=d['all_cross_correlation_functions'],
pop_colors=circuit.plot_dict['pop_colors'])

plot.add_label(ax, 'B')

plt.savefig(os.path.join(output_dir, 'correlation.pdf'))

return


Expand Down
164 changes: 160 additions & 4 deletions mesocircuit/plotting/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.colors as mc
import colorsys
from mpi4py import MPI
import os
import warnings
Expand Down Expand Up @@ -457,6 +459,148 @@ def plot_spatial_snapshots(
return ax


def plot_population_panels_2cols(
gs,
plotfunc,
populations,
layer_labels,
data2d,
pop_colors,
xlabel='',
ylabel='',
wspace=0.3,
**kwargs):
"""
Generic function to plot 2 columns of panels for an even number of populations.
Multiple curves per population are possible.
"""
ncols = int(np.floor(np.sqrt(len(populations))))
nrows = len(populations) // ncols
gsf = gridspec.GridSpecFromSubplotSpec(
nrows, ncols, subplot_spec=gs, wspace=wspace)

for i, X in enumerate(populations):
# select subplot
ax = plt.subplot(gsf[i])
for loc in ['top', 'right']:
ax.spines[loc].set_color('none')

# iterate over 2 dimensional data
num = len(data2d)
for j in np.arange(num):
colors = [adjust_lightness(c, 1-j/(num-1)) for c in pop_colors]
plotfunc(ax, X, i, data=data2d[j],
pop_colors=colors, **kwargs)

layer = layer_labels[int(i / 2.)]
if i == 0:
ax.set_title('E')
ax.set_ylabel(ylabel + '\n' + layer)
ax_label = ax
if i % ncols == 0 and i != 0:
ax.set_ylabel(layer)

if i == 1:
ax.set_title('I')

if i % ncols > 0:
ax.set_yticklabels([])

if i >= len(populations) - 2:
ax.set_xlabel(xlabel)
else:
ax.set_xticklabels([])

return ax_label


def plot_cross_correlation_functions(
gs,
layer_labels,
all_cross_correlation_functions,
pop_colors,
lag_max_plot=None,
scale_exp_plot=5,
cc_max_plot=5):
"""
"""
spcorrs = all_cross_correlation_functions

# average cross-correlation functions
spcorrs_mean = {}
for X, Y in zip(['L23E', 'L23E', 'L23I',
'L4E', 'L4E', 'L4I',
'L5E', 'L5E', 'L5I',
'L6E', 'L6E', 'L6I'],
['L23E', 'L23I', 'L23I',
'L4E', 'L4I', 'L4I',
'L5E', 'L5I', 'L5I',
'L6E', 'L6I', 'L6I']):

spcorrs_mean[f'{X}:{Y}'] = spcorrs[f'{X}:{Y}'][()].mean(axis=0)

# which time lags to plot
lag_times = np.array(spcorrs['lag_times'])
if not lag_max_plot:
lag_max = lag_times[-1]
else:
lag_max = lag_max_plot
inds = (lag_times >= -lag_max) & (lag_times <= lag_max)

gsf = gridspec.GridSpecFromSubplotSpec(
2, 2, subplot_spec=gs, hspace=0.5, wspace=0.5)

for i, L in enumerate(['L23', 'L4', 'L5', 'L6']):
ax = plt.subplot(gsf[i])

for loc in ['top', 'right']:
ax.spines[loc].set_color('none')

for j, key in enumerate([f'{L}E:{L}E', f'{L}E:{L}I', f'{L}I:{L}I']):
XY = key.split(':')
if XY[0][-1] == 'E' and XY[1][-1] == 'E':
color = pop_colors[::2][i]
elif XY[0][-1] == 'I' and XY[1][-1] == 'I':
color = pop_colors[1::2][i]
else:
color = 'k'

if L == 'L23':
label = 'L2/3' + XY[0][-1] + ':' + 'L2/3' + XY[1][-1]
else:
label = key

ax.plot(lag_times[inds],
spcorrs_mean[key][inds] * 10**(scale_exp_plot),
color=color,
label=f'{XY[0][-1]}:{XY[1][-1]}')

ax.set_ylim(-cc_max_plot, cc_max_plot)

ax.set_title(layer_labels[i])
ax.axhline(y=0, color="grey", ls=':')
ax.axvline(x=0, color="grey", ls=':')
ax.legend(frameon=False,
loc='center', bbox_to_anchor=(1., 0.8),
fontsize=matplotlib.rcParams['font.size'] * 0.8)

if i == 0:
ax_label = ax
if i < 2:
ax.set_xticklabels([])
if i >= 2:
ax.set_xlabel('time lag (ms)')

ylabel = '$CC^s$'
if scale_exp_plot != 1:
ylabel += r' ($10^{' + f'{-scale_exp_plot}' + r'}$)'

if i % 2 == 0:
ax.set_ylabel(ylabel)

return ax_label


def plot_crosscorrelation_funcs_thalamic_pulses(
gs,
all_CCfuncs_thalamic_pulses,
Expand Down Expand Up @@ -601,7 +745,7 @@ def linfunc(t, r0, v):
if i == 0:
ax.set_title('E')
ax.set_ylabel('distance (mm)\n' + layer)
ax_return = ax
ax_label = ax
if i % ncols == 0 and i != 0:
ax.set_ylabel(layer)

Expand Down Expand Up @@ -638,7 +782,7 @@ def linfunc(t, r0, v):
cb = fig.colorbar(
im, cax=cax, orientation='vertical')
cb.set_label(r'$CC^\nu$', labelpad=0.1)
return ax_return
return ax_label


def plot_theory_overview(
Expand Down Expand Up @@ -899,7 +1043,7 @@ def plot_population_panels(
yticklabels=True,
**kwargs):
"""
Generic function to plot four vertically arranged panels, one for each
Generic function to plot vertically arranged panels, one for each
population.
"""
num_pops = len(populations)
Expand Down Expand Up @@ -1218,7 +1362,7 @@ def plotfunc_CCs_distance(
return

distances = data[X]['distances_mm'][:max_num_pairs]
ccs = data[X][key_ccs][:max_num_pairs]
ccs = data[X][key_ccs + 'ms'][:max_num_pairs]

# loop for reducing zorder-bias
blocksize = int(len(distances) / nblocks)
Expand Down Expand Up @@ -1292,6 +1436,18 @@ def plotfunc_theory_power_spectra(ax, X, i, data, pop_colors):
return


def adjust_lightness(color, amount=0.5):
"""
Function from: https://stackoverflow.com/questions/37765197/darken-or-lighten-a-color-in-matplotlib
"""
try:
c = mc.cnames[color]
except:
c = color
c = colorsys.rgb_to_hls(*mc.to_rgb(c))
return colorsys.hls_to_rgb(c[0], max(0, min(1, amount * c[1])), c[2])


def colorbar(
ax,
im,
Expand Down
Loading

0 comments on commit ed1baa4

Please sign in to comment.