Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for multi-dimensional stimulus in the LoadedStimulus class #18

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions spikeometric/stimulus/loaded_stimulus.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,23 @@ def __init__(self, path: str, batch_size: int = 1):
self.register_buffer("n_neurons", torch.tensor(stimuli.shape[0], dtype=torch.int))
self.register_buffer("n_steps", torch.tensor(stimuli.shape[1], dtype=torch.int))

n_networks = stimuli.shape[2] if len(stimuli.shape) > 2 else 1
self.register_buffer("n_networks", torch.tensor(n_networks, dtype=torch.int))
if len(stimuli.shape) < 4:
n_networks = stimuli.shape[2] if len(stimuli.shape) > 2 else 1
self.register_buffer("n_networks", torch.tensor(n_networks, dtype=torch.int))
else:
n_networks = stimuli.shape[3]
self.register_buffer("n_networks", torch.tensor(n_networks, dtype=torch.int))

self.register_buffer("batch_size", torch.tensor(batch_size, dtype=torch.int))
if self.n_networks < batch_size:
raise ValueError("The number of networks in the stimulus is smaller than the batch size.")

self.register_buffer("n_batches", torch.tensor(math.ceil(n_networks / batch_size), dtype=torch.int))

if len(stimuli.shape) > 2:
if len(stimuli.shape) == 3:
stimuli = torch.concat(torch.split(stimuli, 1, dim=2), dim=0).squeeze(2)
elif len(stimuli.shape) == 4:
stimuli = torch.concat(torch.split(stimuli, 1, dim=3), dim=0).squeeze(3)

neurons_per_batch = [self.n_neurons*self.batch_size] * (self.n_batches - 1)
if n_networks % batch_size != 0:
Expand All @@ -70,16 +76,14 @@ def __call__(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
Parameters
----------
t : torch.Tensor or float
Time :math:`t` at which to compute the stimulus (ms).
Time :math:`t` at which to compute the stimulus.

Returns
-------
torch.Tensor [n_neurons, t.shape[0]] or [n_neurons]
Stimulus at time :math:`t`.
"""
if torch.is_tensor(t) and not t.dim() == 0:
result = torch.zeros((self.neurons_per_batch[self._idx], t.shape[0]))
result[:, :self.n_steps] = self.stimulus[:, :self.n_steps]
return result
return self.stimulus[..., :t.shape[0], ...]
else:
return self.stimulus[:, t] if 0 <= t < self.n_steps else torch.zeros_like(self.stimulus[:, 0])
return self.stimulus[..., t, ...] if 0 <= t < self.n_steps else torch.zeros(self.neurons_per_batch[self._idx], dtype=torch.float, device=self.stimulus.device)
Binary file added tests/test_data/stim_plan_multidim.pt
Binary file not shown.
43 changes: 24 additions & 19 deletions tests/test_stimulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,22 +204,24 @@ def test_sin_stimulus_vectorizes(sin_stimulus):
t = torch.arange(0, 1000)
assert sin_stimulus(t).shape == (20, 1000)

def test_loaded_stimulus_is_batched():
@pytest.mark.parametrize("stim_plan", ["tests/test_data/stim_plan_multidim.pt", "tests/test_data/stim_plan_4_networks.pt"])
def test_loaded_stimulus_is_batched(stim_plan):
from spikeometric.stimulus import LoadedStimulus
stimulus = LoadedStimulus("tests/test_data/stim_plan_4_networks.pt", batch_size=2)
stimulus = LoadedStimulus(stim_plan, batch_size=2)
n_steps = stimulus.n_steps.item()
assert stimulus(0).shape == (40,)
assert stimulus(n_steps).shape == (40,)
assert stimulus(0).shape[0] == 40
assert stimulus(n_steps).shape[0] == 40

def test_loaded_stimulus_can_batch_non_uneven_number_of_batches():
@pytest.mark.parametrize("stim_plan", ["tests/test_data/stim_plan_multidim.pt", "tests/test_data/stim_plan_4_networks.pt"])
def test_loaded_stimulus_can_batch_non_uneven_number_of_batches(stim_plan):
from spikeometric.stimulus import LoadedStimulus
stimulus = LoadedStimulus("tests/test_data/stim_plan_4_networks.pt", batch_size=3)
stimulus = LoadedStimulus(stim_plan, batch_size=3)
n_steps = stimulus.n_steps.item()
assert stimulus(0).shape == (60,)
assert stimulus(n_steps+1).shape == (60,)
assert stimulus(0).shape[0] == 60
assert stimulus(n_steps+1).shape[0] == 60
stimulus.next_batch()
assert stimulus(0).shape == (20,)
assert stimulus(n_steps+1).shape == (20,)
assert stimulus(0).shape[0] == 20
assert stimulus(n_steps+1).shape[0] == 20

def test_loaded_stimulus_affects_model_output(bernoulli_glm, example_data, loaded_stimulus):
initial_spikes = bernoulli_glm.simulate(example_data, n_steps=100, verbose=False)
Expand All @@ -239,9 +241,10 @@ def test_loaded_stimulus_matches_original_stimulus(loaded_stimulus):
def test_loaded_stimulus_is_zero_before_start_of_stimulus(loaded_stimulus):
assert torch.allclose(loaded_stimulus(-1), torch.zeros(20))

def test_loaded_stimulus_cycles_through_batches():
@pytest.mark.parametrize("stim_plan", ["tests/test_data/stim_plan_multidim.pt", "tests/test_data/stim_plan_4_networks.pt"])
def test_loaded_stimulus_cycles_through_batches(stim_plan):
from spikeometric.stimulus import LoadedStimulus
stimulus = LoadedStimulus("tests/test_data/stim_plan_4_networks.pt", batch_size=2)
stimulus = LoadedStimulus(stim_plan, batch_size=2)
n_steps = stimulus.n_steps
first_batch = stimulus(torch.arange(0, n_steps))
stimulus.next_batch()
Expand All @@ -251,23 +254,25 @@ def test_loaded_stimulus_cycles_through_batches():
stimulus.next_batch()
assert torch.allclose(second_batch, stimulus(torch.arange(0, n_steps)))

def test_loaded_stimulus_fails_if_batch_size_greater_than_n_networks():
@pytest.mark.parametrize("stim_plan", ["tests/test_data/stim_plan_multidim.pt", "tests/test_data/stim_plan_4_networks.pt"])
def test_loaded_stimulus_fails_if_batch_size_greater_than_n_networks(stim_plan):
from spikeometric.stimulus import LoadedStimulus
with pytest.raises(ValueError):
stimulus = LoadedStimulus("tests/test_data/stim_plan_4_networks.pt", batch_size=5)
stimulus = LoadedStimulus(stim_plan, batch_size=5)

def test_loaded_stimulus_fails_if_idx_out_of_range():
@pytest.mark.parametrize("stim_plan", ["tests/test_data/stim_plan_multidim.pt", "tests/test_data/stim_plan_4_networks.pt"])
def test_loaded_stimulus_fails_if_idx_out_of_range(stim_plan):
from spikeometric.stimulus import LoadedStimulus
stimulus = LoadedStimulus("tests/test_data/stim_plan_4_networks.pt", batch_size=2)
stimulus = LoadedStimulus(stim_plan, batch_size=2)
with pytest.raises(ValueError):
stimulus.set_batch(2)

def test_loaded_stimulus_resets():
@pytest.mark.parametrize("stim_plan", ["tests/test_data/stim_plan_multidim.pt", "tests/test_data/stim_plan_4_networks.pt"])
def test_loaded_stimulus_resets(stim_plan):
from spikeometric.stimulus import LoadedStimulus
stimulus = LoadedStimulus("tests/test_data/stim_plan_4_networks.pt", batch_size=1)
stimulus = LoadedStimulus(stim_plan, batch_size=1)
n_steps = stimulus.n_steps
first_batch = stimulus(torch.arange(0, n_steps))
stimulus.set_batch(3)
stimulus.reset()
assert torch.allclose(first_batch, stimulus(torch.arange(0, n_steps)))