Skip to content

Commit

Permalink
Tutorial update.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ziaeemehr committed Dec 5, 2023
1 parent f85ae87 commit 995f6ae
Show file tree
Hide file tree
Showing 12 changed files with 270 additions and 11 deletions.
Binary file added docs/_images/sweep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/_sources/tutorial.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ Examples

.. automodule:: examples.00_intro


.. automodule:: examples.01_sweep

13 changes: 11 additions & 2 deletions docs/genindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,22 @@ <h2 id="E">E</h2>
</li>
<li><a href="modules.html#vbjax.neural_mass.BOLDTheta.epsilon">epsilon (vbjax.neural_mass.BOLDTheta attribute)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="modules.html#vbjax.neural_mass.MPRTheta.eta">eta (vbjax.neural_mass.MPRTheta attribute)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li>
examples.00_intro

<ul>
<li><a href="tutorial.html#module-examples.00_intro">module</a>
</li>
</ul></li>
<li>
examples.01_sweep

<ul>
<li><a href="tutorial.html#module-examples.01_sweep">module</a>
</li>
</ul></li>
</ul></td>
Expand Down Expand Up @@ -277,6 +284,8 @@ <h2 id="M">M</h2>

<ul>
<li><a href="tutorial.html#module-examples.00_intro">examples.00_intro</a>
</li>
<li><a href="tutorial.html#module-examples.01_sweep">examples.01_sweep</a>
</li>
<li><a href="modules.html#module-vbjax.connectome">vbjax.connectome</a>
</li>
Expand Down
Binary file added docs/images/sweep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions docs/modules.html
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<script src="_static/sphinx_highlight.js?v=dc90522c"></script>
<link rel="index" title="Index" href="genindex.html" />
<link rel="search" title="Search" href="search.html" />
<link rel="prev" title="Tutorial" href="tutorial.html" />
<link rel="prev" title="Examples" href="tutorial.html" />
</head><body>
<div class="related" role="navigation" aria-label="related navigation">
<h3>Navigation</h3>
Expand All @@ -26,7 +26,7 @@ <h3>Navigation</h3>
<a href="py-modindex.html" title="Python Module Index"
>modules</a> |</li>
<li class="right" >
<a href="tutorial.html" title="Tutorial"
<a href="tutorial.html" title="Examples"
accesskey="P">previous</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">vbjax v0.0.10 documentation</a> &#187;</li>
<li class="nav-item nav-item-this"><a href="">neural_mass</a></li>
Expand Down Expand Up @@ -923,7 +923,7 @@ <h3><a href="index.html">Table of Contents</a></h3>
<div>
<h4>Previous topic</h4>
<p class="topless"><a href="tutorial.html"
title="previous chapter">Tutorial</a></p>
title="previous chapter">Examples</a></p>
</div>
<div role="note" aria-label="source link">
<h3>This Page</h3>
Expand Down Expand Up @@ -956,7 +956,7 @@ <h3>Navigation</h3>
<a href="py-modindex.html" title="Python Module Index"
>modules</a> |</li>
<li class="right" >
<a href="tutorial.html" title="Tutorial"
<a href="tutorial.html" title="Examples"
>previous</a> |</li>
<li class="nav-item nav-item-0"><a href="index.html">vbjax v0.0.10 documentation</a> &#187;</li>
<li class="nav-item nav-item-this"><a href="">neural_mass</a></li>
Expand Down
Binary file modified docs/objects.inv
Binary file not shown.
5 changes: 5 additions & 0 deletions docs/py-modindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ <h1>Python Module Index</h1>
<td>&#160;&#160;&#160;
<a href="tutorial.html#module-examples.00_intro"><code class="xref">examples.00_intro</code></a></td><td>
<em></em></td></tr>
<tr class="cg-1">
<td></td>
<td>&#160;&#160;&#160;
<a href="tutorial.html#module-examples.01_sweep"><code class="xref">examples.01_sweep</code></a></td><td>
<em></em></td></tr>
<tr class="pcap"><td></td><td>&#160;</td><td></td></tr>
<tr class="cap" id="cap-v"><td></td><td>
<strong>v</strong></td><td></td></tr>
Expand Down
2 changes: 1 addition & 1 deletion docs/searchindex.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ Examples

.. automodule:: examples.00_intro


.. automodule:: examples.01_sweep

100 changes: 99 additions & 1 deletion docs/tutorial.html
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,105 @@ <h3>Navigation</h3>
</pre></div>
</div>
<figure class="align-default">
<a class="reference internal image-reference" href="_images/example1.jpg"><img alt="_images/example1.jpg" src="_images/example1.jpg" style="width: 320.0px; height: 240.0px;" /></a>
<a class="reference internal image-reference" href="_images/example1.jpg"><img alt="_images/example1.jpg" src="_images/example1.jpg" style="width: 480.0px; height: 360.0px;" /></a>
</figure>
<p id="module-examples.01_sweep"><strong>Example 2</strong>: Consider a network of coupled Montbrio model nodes and sweep
over the parameters of the model.</p>
<p>Starting with a few imports</p>
<div class="literal-block-wrapper docutils container" id="id2">
<div class="code-block-caption"><span class="caption-text">../../examples/01_sweep.py</span><a class="headerlink" href="#id2" title="Link to this code"></a></div>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">vbjax</span> <span class="k">as</span> <span class="nn">vb</span>
<span class="kn">import</span> <span class="nn">pylab</span> <span class="k">as</span> <span class="nn">pl</span>
<span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s1">&#39;images&#39;</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</pre></div>
</div>
</div>
<p>This example shows how to use the <cite>vbjax</cite> library to simulate a network of
Montbrio model nodes. The network is defined by the function <cite>network</cite> which
takes as arguments the state of the network and the parameters of the model.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">network</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span>
<span class="n">r</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">x</span>
<span class="n">k</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">mpr_p</span> <span class="o">=</span> <span class="n">p</span>
<span class="n">c</span> <span class="o">=</span> <span class="n">k</span><span class="o">*</span><span class="n">r</span><span class="o">.</span><span class="n">sum</span><span class="p">(),</span> <span class="n">k</span><span class="o">*</span><span class="n">v</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="k">return</span> <span class="n">vb</span><span class="o">.</span><span class="n">mpr_dfun</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">mpr_p</span><span class="p">)</span>
</pre></div>
</div>
<p>The function noise is used to generate the noise term of the stochastic</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">noise</span><span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span>
<span class="n">_</span><span class="p">,</span> <span class="n">sigma</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">p</span>
<span class="k">return</span> <span class="n">sigma</span>
</pre></div>
</div>
<p>The function run is used to simulate the network for a given set of parameters.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="n">pars</span><span class="p">,</span> <span class="n">mpr_p</span><span class="o">=</span><span class="n">vb</span><span class="o">.</span><span class="n">mpr_default_theta</span><span class="p">):</span>
<span class="n">k</span><span class="p">,</span> <span class="n">sig</span><span class="p">,</span> <span class="n">eta</span> <span class="o">=</span> <span class="n">pars</span> <span class="c1"># explored pars</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">k</span><span class="p">,</span> <span class="n">sig</span><span class="p">,</span> <span class="n">mpr_p</span><span class="o">.</span><span class="n">_replace</span><span class="p">(</span><span class="n">eta</span><span class="o">=</span><span class="n">eta</span><span class="p">)</span> <span class="c1"># set mpr</span>
<span class="n">xs</span> <span class="o">=</span> <span class="n">loop</span><span class="p">(</span><span class="n">rv0</span><span class="p">,</span> <span class="n">zs</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span> <span class="c1"># run sim</span>
<span class="n">std</span> <span class="o">=</span> <span class="n">xs</span><span class="p">[</span><span class="mi">400</span><span class="p">:,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">std</span><span class="p">()</span> <span class="c1"># eval metric</span>
<span class="k">return</span> <span class="n">std</span> <span class="c1"># done</span>
</pre></div>
</div>
<p>Then prepare the simulation</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">n_nodes</span> <span class="o">=</span> <span class="mi">8</span>
</pre></div>
</div>
<p>define the number of nodes in the network.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">using_cpu</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">local_devices</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">platform</span> <span class="o">==</span> <span class="s1">&#39;cpu&#39;</span>
<span class="k">if</span> <span class="n">using_cpu</span><span class="p">:</span>
<span class="n">run_batches</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">pmap</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">run</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">run_batches</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">run</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</pre></div>
</div>
<p>defines the engine to be used for the simulation. If the simulation is run on
the CPU, then the simulation is parallelized over the cores of the CPU using <cite>jax.pmap</cite>.
Otherwise, the simulation is parallelized over the GPU using <cite>jax.vmap</cite>.</p>
<p>then we prepare the network and the noise samples</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># prepare network</span>
<span class="n">_</span><span class="p">,</span> <span class="n">loop</span> <span class="o">=</span> <span class="n">vb</span><span class="o">.</span><span class="n">make_sde</span><span class="p">(</span><span class="mf">0.01</span><span class="p">,</span> <span class="n">network</span><span class="p">,</span> <span class="n">noise</span><span class="p">)</span>
<span class="n">rv0</span> <span class="o">=</span> <span class="n">vb</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">n_nodes</span><span class="p">)</span>
</pre></div>
</div>
<p>The rest run the simulation on set of parameters and plot the results.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="c1"># sweep sigma but just a few values are enough</span>
<span class="n">sigmas</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]</span>
<span class="n">results</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">ng</span> <span class="o">=</span> <span class="n">vb</span><span class="o">.</span><span class="n">cores</span><span class="o">*</span><span class="mi">4</span> <span class="k">if</span> <span class="n">using_cpu</span> <span class="k">else</span> <span class="mi">32</span>

<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">sig_i</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">sigmas</span><span class="p">):</span>
<span class="c1"># create grid of k (on logarithmic scale) and eta</span>
<span class="n">log_ks</span><span class="p">,</span> <span class="n">etas</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mgrid</span><span class="p">[</span><span class="o">-</span><span class="mf">9.0</span><span class="p">:</span><span class="o">-</span><span class="mf">2.0</span><span class="p">:</span><span class="mi">1</span><span class="n">j</span><span class="o">*</span><span class="n">ng</span><span class="p">,</span> <span class="o">-</span><span class="mf">4.0</span><span class="p">:</span><span class="o">-</span><span class="mf">6.0</span><span class="p">:</span><span class="mi">1</span><span class="n">j</span><span class="o">*</span><span class="n">ng</span><span class="p">]</span>

<span class="c1"># reshape grid to big batch of values</span>
<span class="n">pars</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">c_</span><span class="p">[</span>
<span class="n">np</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">log_ks</span><span class="o">.</span><span class="n">ravel</span><span class="p">()),</span>
<span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">log_ks</span><span class="o">.</span><span class="n">size</span><span class="p">)</span><span class="o">*</span><span class="n">sig_i</span><span class="p">,</span>
<span class="n">etas</span><span class="o">.</span><span class="n">ravel</span><span class="p">()]</span><span class="o">.</span><span class="n">T</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>

<span class="c1"># cpu w/ pmap expects a chunk for each core</span>
<span class="k">if</span> <span class="n">using_cpu</span><span class="p">:</span>
<span class="n">pars</span> <span class="o">=</span> <span class="n">pars</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="n">vb</span><span class="o">.</span><span class="n">cores</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>

<span class="c1"># now run</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">run_batches</span><span class="p">(</span><span class="n">pars</span><span class="p">)</span><span class="o">.</span><span class="n">block_until_ready</span><span class="p">()</span>
<span class="n">results</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">result</span><span class="p">)</span>

<span class="n">toc</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;elapsed time for sweep </span><span class="si">{</span><span class="n">toc</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">tic</span><span class="si">:</span><span class="s1">0.1f</span><span class="si">}</span><span class="s1"> s&#39;</span><span class="p">)</span>


<span class="n">pl</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
<span class="n">pl</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span><span class="mi">2</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">sig_i</span><span class="p">,</span> <span class="n">result</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">sigmas</span><span class="p">,</span> <span class="n">results</span><span class="p">)):</span>
<span class="n">pl</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">pl</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">result</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">log_ks</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">vmin</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mf">0.7</span><span class="p">)</span>
</pre></div>
</div>
<figure class="align-default">
<a class="reference internal image-reference" href="_images/sweep.png"><img alt="_images/sweep.png" src="_images/sweep.png" style="width: 800.0px; height: 200.0px;" /></a>
</figure>
</section>

Expand Down
2 changes: 1 addition & 1 deletion examples/00_intro.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
:lines: 13
.. figure:: ../../examples/images/example1.jpg
:scale: 50 %
:scale: 75 %
"""
Expand Down
Loading

0 comments on commit 995f6ae

Please sign in to comment.