Skip to content

Commit

Permalink
Merge pull request #21 from bioAI-Oslo/main
Browse files Browse the repository at this point in the history
Update paper branch to match main
  • Loading branch information
JakobSonstebo authored Sep 13, 2023
2 parents 70fade4 + c6b4961 commit 1a75b51
Show file tree
Hide file tree
Showing 25 changed files with 1,192 additions and 1,054 deletions.
Binary file added docs/_static/batch_size.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/cpu_gpu.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/isi_distribution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/synaptic_scaling.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
68 changes: 68 additions & 0 deletions docs/benchmarks/benchmarks.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
############################
Memory usage and performance
############################

.. currentmodule:: spikeometric.models

In this section, we discuss the memory usage and performance of core spikeometric functionality.
The benchmarks are run in `Google Colab <https://colab.research.google.com/drive/1tQ0ULV04g6W45QCfQheLVFwladLzVJ8v?usp=sharing>`_ and the code is available in the `examples` folder of the
`GitHub repository <https://github.com/bioAI-Oslo/Spikeometric>`_. The GPU used for the experiments in this section is the Nvidia V100 with 16GB of memory
and the CPU is the Intel Xeon with 13GB of RAM.

Memory usage
------------
The main memory consumer when running a simulation is the synaptic weights. The GLM models support time filters on the synaptic weights
to capture persisting effects of the pre-synaptic spikes over time, while the synaptic activation models (:class:`ThresholdSAM` and :class:`RectifiedSAM`)
only consider spike-events from the previous time step. In general, the weights are stored as a :math:`[N_{syn}, T]` tensor, where :math:`N_{syn}` is the number of synapses
and :math:`T` is the length of the time filter. In addition, we need an edge index to specify the pre- and post-synaptic neurons of each synapse, which is stored as a :math:`[2, N_{syn}]` tensor.
One exception is the :class:`BernoulliGLM`, which also includes a refractory period modeled as a negative self-connection, adding
an additional :math:`N_{neurons}` synapses.

We also need to store the number of spikes of each neuron per time step, which by default consumes 32 bytes.
In most cases, however, we don't expect the number of spikes per time step for any neuron to exceed 127, which means we can safely reduce the memory conumption to
8 bytes by passing :code:`torch.int8` as the :code:`store_as_dtype` argument of the :meth:`simulate` method if we need additional memory.

Concretely, the total memory usage (in bytes) of GLM models can be estimated as

.. math::
M_{GLM} = 32 \cdot T \cdot N_{syn} + 32 \cdot 2 \cdot N_{syn} + 8 \cdot N_{neurons} \cdot N_{steps}
where :math:`N_{syn}` is the number of synapses, :math:`N_{neurons}` is the number of neurons, and :math:`N_{steps}` is the
number. For :class:`BernoulliGLM`, we simply add on another :math:`32 \cdot T \cdot N_{neurons}`.
In addition to the spikes, the synaptic activation models also need to store a 32 byte floating-point state for each neuron per time step.
The memory usage of the synaptic activation models can therefore be estimated as

.. math::
M_{SAM} = 32 \cdot T \cdot N_{syn} + 32 \cdot 2 \cdot N_{syn} + 8 \cdot N_{neurons} \cdot N_{steps} + 32 \cdot N_{neurons} \cdot N_{steps}
Performance
-----------
The performance of the Spikeometric models is measured in terms of the time per iteration. Up to ~2 500 000 synapses,
the time per iteration is roughly constant at ~0.3 ms per iteration for most models, with the :class:`ThresholdSAM` being slightly slower due to
due to the cost of computing its more involved background input. If we run the models for 1000 iterations
for networks with ~2 500 000 synapses up to ~50 000 000 synapses, we can see that the run time eventually increases
approximately linearly with the number of synapses. We also run experiments where we increase
the number of time steps in the time filters up to 100 in steps of 5 for a network of ~2 500 000 synapses. The results are shown in the figure below.

.. image:: ../_static/synaptic_scaling.png


CPU vs GPU
----------

Spikeometric is designed to work well with both CPU and GPU architectures, but utilizing a GPU
brings significant performance benefits when simulating larger networks.
For networks with up to 10 000 synapses, the CPU is faster due to overhead on the GPU, but while time per iteration remains
constant at about 0.3 ms up to about 2 500 000 synapses on the GPU, it increases from 0.15 ms at 1000 synapses to 30 ms per iteration at 2 500 000 synapses on the CPU.

.. figure:: ../_static/cpu_gpu.png


Speed-up from batching
----------------------
The performance of the Spikeometric models can be further improved by batching the simulation of multiple networks. For example, using
the :class:`BernoulliGLM` model, simulating networks with 2 500 000 synapses can be improved to ~0.15 ms per iteration on average
by batching at least 10 networks at a time. The speed-up from batching is shown in the figure below.

.. image:: ../_static/batch_size.png
6 changes: 6 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ Then you can install `spikeometric` using pip:
tutorials/implement_model
tutorials/stimuli

.. toctree::
:maxdepth: 1
:caption: Benchmarks

benchmarks/benchmarks

.. toctree::
:maxdepth: 1
:caption: API Reference
Expand Down
7 changes: 5 additions & 2 deletions docs/introduction/introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ is invertible.
Introductory example
====================

.. currentmodule:: spikeometric.datasets
.. currentmodule:: spikeometric.stimulus
.. currentmodule:: spikeometric.models

.. note::
Expand Down Expand Up @@ -118,11 +120,12 @@ For each of our ten networks, we will stimulate 4 random excitatory neurons.
stimulus_masks = [torch.isin(torch.arange(n_neurons), torch.randperm(n_neurons//2)[:4]) for _ in range(10)]
stimulus = RegularStimulus(
strength=5.0,
interval=100,
duration=100_000,
period=100,
stop=100_000,
tau=10,
dt=1,
stimulus_masks=stimulus_masks,
batch_size=5
)
model.add_stimulus(stimulus)
Expand Down
9 changes: 9 additions & 0 deletions docs/modules/generated/spikeometric.stimulus.BaseStimulus.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
BaseStimulus
============

.. currentmodule:: spikeometric.stimulus

.. autoclass:: BaseStimulus
:show-inheritance:
:members:
:undoc-members:
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
LoadedStimulus
==============

.. currentmodule:: spikeometric.stimulus

.. autoclass:: LoadedStimulus
:show-inheritance:
:members:
:undoc-members:
8 changes: 5 additions & 3 deletions docs/modules/stimulus.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ Stimulus

The stimlation classes are used to define the external input to the network. They inherit from the :class:`Module` class
from torch, which allows us to tune the parameters of the stimulus with the usual torch methods. A couple of stimulus models are provided in the package, but it is easy to implement
new ones by taking inspiration from the existing ones and extending the :class:`Module` class from torch.
new ones by taking inspiration from the existing ones and extending the :class:`BaseStimulus` class.

.. autosummary::
:toctree: generated/
:nosignatures:
:template: autosummary/class.rst


BaseStimulus
PoissonStimulus
RegularStimulus
SinStimulus
SinStimulus
LoadedStimulus
42 changes: 26 additions & 16 deletions docs/tutorials/implement_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ Stimulus filter
----------------
We want the model to handle a multidimensional stimulus given to some of the neurons,
and must implement a :py:meth:`stimulus_filter` method to determine how this stimulus will be integrated into an input.
The function :py:func:`k` that we pass to the model will help us with this.
Note that we are assuming that all targeted neurons receive the same stimulus and that they have the same stimulus filter.
The function :py:func:`k` that we pass to the model will decide how the different channels in the multidimensional stimulus are weighted.
Note that we are assuming that all targeted neurons they have the same stimulus filter.

.. code-block:: python
def stimulus_filter(self, stimulus: torch.Tensor) -> torch.Tensor:
return torch.sum(stimulus*self.k(torch.arange(stimulus.shape[0])))
return torch.sum(stimulus*self.k(torch.arange(stimulus.shape[1])), dim=1)
Input
------
Expand All @@ -133,7 +133,7 @@ Non-linearity
--------------
After we have collected the input to the neurons, we need to apply a non-linearity to it. In this case,
we use a rectified linear non-linearity. The rate is scaled by the time step to ensure that increasing the resolution
of the simulation doesn't artificially increase the firing rate.
of the simulation doesn't artificially change the firing rate.

.. code-block:: python
Expand Down Expand Up @@ -196,7 +196,7 @@ Summary
return W.flip(1)
def stimulus_filter(self, stimulus: torch.Tensor) -> torch.Tensor:
return torch.sum(stimulus*self.k(torch.arange(stimulus.shape[0])))
return torch.sum(stimulus*self.k(torch.arange(stimulus.shape[1])), dim=1)
def input(self, edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, t=-1) -> torch.Tensor:
return self.synaptic_input(edge_index, W, state) + self.stimulus_input(t)
Expand All @@ -209,30 +209,39 @@ Summary
Testing the model
------------------
To test the model, we will simulate some networks with 50 neurons. We will use a random connectivity matrix with normal weights and a random periodic stimulus.
To test the model, we will simulate some networks with 50 neurons. We will use a random connectivity matrix with normal weights and a periodic stimulus
with random amplitudes in five frequency bins. Since we are batching the simulations in groups of 2 networks,
we will first generate a stimulus plan and make use of the :py:class:`LoadedStimulus` class to take care of the batching for us.

.. code-block::
n_neurons = 50
test_data = NormalGenerator(n_neurons, mean=0, std=0.1, glorot=True).generate(10)
loader = DataLoader(test_data, batch_size=5, shuffle=False)
test_data = NormalGenerator(n_neurons, mean=0, std=0.5, glorot=True).generate(10)
loader = DataLoader(test_data, batch_size=2, shuffle=False)
def r(t):
return -100.*(t < 2) + -100*torch.exp(-2*(t-2))*(t >= 2)*(t<5)
return -1000.*(t < 2) + -1000*torch.exp(-2*(t-2))*(t >= 2)*(t<5)
def w(t):
return torch.exp(-t/2)
def k(f):
return torch.exp(-f/5) # weight lower frequencies more
return torch.exp(-f)
def stimulus(t):
targets = torch.isin(torch.arange(50), torch.arange(50)[::10]).unsqueeze
return 0.1*torch.rand(5) * (t % 100 < 20)*targets
stimulus_plan = torch.zeros(50, 500, 5, 10) # 50 neurons, 500 time steps, 5 frequency bins, 10 networks
t = torch.arange(500)
for i in range(10):
# For each network, we stimulate the first 10 neurons (out of 50), for 20 time steps every 100 time steps, with random amplitudes in five frequency bins.
# The dimensions returned are [n_neurons, n_time_steps (length of t), n_frequency_bins]
stimulus_plan[:, :, :, i] = ((torch.arange(50) < 10).unsqueeze(1) * (t % 100 < 20)).unsqueeze(2)*torch.rand(50, t.shape[0], 5)
torch.save(stimulus_plan, "stimulus_plan.pt")
stimulus = LoadedStimulus("stimulus_plan.pt", batch_size=2)
model = FilRectLNP(lambda_0=1, theta=-0.01, dt=1, T=20, r=r, w=w, k=k)
Since we don't know what parameters to use, we'll make a guess and then tune them to give a firing rate of 10 Hz.
We don't know what parameters to use, so we'll make a guess and then tune them to give a firing rate of 10 Hz.
We'll use a learning rate of 1e-4 and train for 100 epochs with 500 time steps per epoch. Note that
the model is not trained to fit a set of spike trains, but rather to give a firing rate of 10 Hz.
Note also that we might need to try a few different inital parameters for the model not to immediately blow up.
Expand All @@ -250,9 +259,10 @@ We can now add the stimulus and simulate the model on the data.
.. code-block::
model.add_stimulus(stimulus)
results = torch.zeros((len(test_data), n_neurons))
n_steps = 10_000
results = torch.zeros((50*10, n_steps))
for i, data in enumerate(loader):
results[i*data.num_nodes:(i+1)*data.num_nodes] = model.simulate(data, n_steps=500)
results[i*data.num_nodes:(i+1)*data.num_nodes] = model.simulate(data, n_steps=n_steps)
And plot the results to get the following PSTH, ISI and raster plot

Expand Down
5 changes: 3 additions & 2 deletions docs/tutorials/stimuli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ Synchronizing the network activity with a stimulus
--------------------------------------------------

In this final example we'll try to synchronize the activity of a network with the frequency of a sinusoidal stimulus.
We'll use the :py:class:`SinusoidalStimulus` to model the stimulus, and tune it to achieve an average firing rate of the network
We'll use the :py:class:`SinStimulus` to model the stimulus, and tune it to achieve an average firing rate of the network
that is close to the frequency of the stimulus.

.. code-block:: python
Expand Down Expand Up @@ -162,7 +162,8 @@ that is close to the frequency of the stimulus.
amplitude=1,
period=100, # frequency = 1/period = 10 Hz
duration=n_steps,
stimulus_masks=stim_masks
stimulus_masks=stim_masks,
batch_size=10,
)
model.add_stimulus(stimulus)
Expand Down
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ channels:
dependencies:
- python
- pip
- pytorch=1.13
- pytorch
- cpuonly
- pyg
- pip:
- -r requirements.txt
- -r requirements_dev.txt
Loading

0 comments on commit 1a75b51

Please sign in to comment.