Skip to content

Commit

Permalink
Use pointers to reference substructs
Browse files Browse the repository at this point in the history
  • Loading branch information
le-schmidt committed Aug 10, 2023
1 parent 0583fec commit 5068bd7
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ def _print_buffer_value(self, variable: ASTVariable) -> str:
vector_parameter = ASTUtils.get_numeric_vector_size(variable)
var_name = var_name + "_" + str(vector_parameter)

return "input.inputs[" + var_name + " - MIN_SPIKE_RECEPTOR]"
return "input->inputs[" + var_name + " - MIN_SPIKE_RECEPTOR]"

if variable_symbol.is_continuous_input_port():
var_name = variable_symbol.get_symbol_name().upper()
if variable.get_vector_parameter() is not None:
vector_parameter = ASTUtils.get_numeric_vector_size(variable)
var_name = var_name + "_" + str(vector_parameter)

return "input.inputs[" + var_name + " - MIN_SPIKE_RECEPTOR]"
return "input->inputs[" + var_name + " - MIN_SPIKE_RECEPTOR]"

return variable_symbol.get_symbol_name() + '_grid_sum_'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,10 @@ static void neuron_impl_add_inputs(
// Get the neuron itself

neuron_impl_t *neuron = &neuron_array[neuron_index];
neuron_state_t state = neuron->state;
neuron_input_t input = neuron->input;
neuron_parameter_t parameter = neuron->parameter;
neuron_input_t *input = &neuron->input;

// Do something to store the inputs for the next state update
input.inputs[synapse_type_index] += weights_this_timestep;
input->inputs[synapse_type_index] += weights_this_timestep;
}

__attribute__((unused)) // Marked unused as only used sometimes
Expand All @@ -220,14 +218,14 @@ static void neuron_impl_do_timestep_update(
for (uint32_t neuron_index = 0; neuron_index < n_neurons; neuron_index++) {
// Get the neuron itself
neuron_impl_t *neuron = &neuron_array[neuron_index];
neuron_state_t state = neuron->state;
neuron_input_t input = neuron->input;
neuron_parameter_t parameter = neuron->parameter;
neuron_state_t *state = &neuron->state;
neuron_input_t *input = &neuron->input;
neuron_parameter_t *parameter = &neuron->parameter;

// Store the recorded membrane voltage
{%- for variable in neuron.get_state_symbols() %}
{%- if variable.is_recordable %}
neuron_recording_record_accum({{ variable.get_symbol_name().upper() }}_RECORDING_INDEX, neuron_index, state.{{ variable.get_symbol_name() }});
neuron_recording_record_accum({{ variable.get_symbol_name().upper() }}_RECORDING_INDEX, neuron_index, state->{{ variable.get_symbol_name() }});
{%- endif -%}
{%- endfor %}

Expand All @@ -248,10 +246,10 @@ static void neuron_impl_do_timestep_update(
// neuron_recording_record_bit(SPIKE_RECORDING_BITFIELD, neuron_index); // ?

{%- for variable in neuron.get_spike_input_ports() %}
input.inputs[{{ variable.get_symbol_name().upper() }}] = ZERO;
input->inputs[{{ variable.get_symbol_name().upper() }}] = ZERO;
{%- endfor %}
{%- for variable in neuron.get_continuous_input_ports() %}
input.inputs[{{ variable.get_symbol_name().upper() }}] = ZERO;
input->inputs[{{ variable.get_symbol_name().upper() }}] = ZERO;
{%- endfor %}
}
}
Expand All @@ -261,9 +259,9 @@ void neuron_impl_print_inputs(uint32_t n_neurons) {
log_debug("-------------------------------------\n");
for (index_t i = 0; i < n_neurons; i++) {
neuron_impl_t *neuron = &neuron_array[i];
neuron_input_t input = neuron.input;
neuron_input_t *input = &neuron.input;

log_debug("inputs: %k %k", input.inputs[0], input.inputs[1]);
log_debug("inputs: %k %k", input->inputs[0], input->inputs[1]);
}
log_debug("-------------------------------------\n");
}
Expand Down
10 changes: 5 additions & 5 deletions pynestml/codegeneration/spinnaker_code_generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ def print_symbol_origin(cls, variable_symbol: VariableSymbol, numerical_state_sy
if numerical_state_symbols and variable_symbol.get_symbol_name() in numerical_state_symbols:
return 'S_.ode_state[State_::%s]'

return 'state.%s'
return 'state->%s'

if variable_symbol.block_type == BlockType.PARAMETERS:
return 'parameter.%s'
return 'parameter->%s'

if variable_symbol.block_type == BlockType.COMMON_PARAMETERS:
return 'parameter.%s'
return 'parameter->%s'

if variable_symbol.block_type == BlockType.INTERNALS:
return 'parameter.%s'
return 'parameter->%s'

if variable_symbol.block_type == BlockType.INPUT:
return 'input.%s'
return 'input.>%s'

return ''

0 comments on commit 5068bd7

Please sign in to comment.