Skip to content

Commit

Permalink
Merge pull request #3120 from jprotze/fixPerThreadBoolIndicator
Browse files Browse the repository at this point in the history
Fix data race in nestkernel/per_thread_bool_indicator.cpp
  • Loading branch information
suku248 authored Mar 12, 2024
2 parents 0aaa566 + 4b185a4 commit 6ab67c5
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 46 deletions.
4 changes: 2 additions & 2 deletions nestkernel/connection_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,13 +876,13 @@ nest::ConnectionManager::connect_( Node& source,
{
#pragma omp atomic write
has_primary_connections_ = true;
check_primary_connections_[ tid ].set_true();
check_primary_connections_.set_true( tid );
}
else if ( check_secondary_connections_[ tid ].is_false() and not is_primary )
{
#pragma omp atomic write
secondary_connections_exist_ = true;
check_secondary_connections_[ tid ].set_true();
check_secondary_connections_.set_true( tid );
}
}

Expand Down
16 changes: 8 additions & 8 deletions nestkernel/event_delivery_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
assert( not kernel().connection_manager.is_source_table_cleared() );

// assume all threads have some work to do
gather_completed_checker_[ tid ].set_false();
gather_completed_checker_.set_false( tid );
assert( gather_completed_checker_.all_false() );

const AssignedRanks assigned_ranks = kernel().vp_manager.get_assigned_ranks( tid );
Expand All @@ -802,7 +802,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
{
// assume this is the last gather round and change to false
// otherwise
gather_completed_checker_[ tid ].set_true();
gather_completed_checker_.set_true( tid );

#pragma omp master
{
Expand All @@ -819,7 +819,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
assigned_ranks, kernel().mpi_manager.get_send_recv_count_target_data_per_rank() );

const bool gather_completed = collocate_target_data_buffers_( tid, assigned_ranks, send_buffer_position );
gather_completed_checker_[ tid ].logical_and( gather_completed );
gather_completed_checker_.logical_and( tid, gather_completed );

if ( gather_completed_checker_.all_true() )
{
Expand All @@ -842,7 +842,7 @@ EventDeliveryManager::gather_target_data( const size_t tid )
#pragma omp barrier

const bool distribute_completed = distribute_target_data_buffers_( tid );
gather_completed_checker_[ tid ].logical_and( distribute_completed );
gather_completed_checker_.logical_and( tid, distribute_completed );

// resize mpi buffers, if necessary and allowed
if ( gather_completed_checker_.any_false() and kernel().mpi_manager.adaptive_target_buffers() )
Expand All @@ -864,7 +864,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
assert( not kernel().connection_manager.is_source_table_cleared() );

// assume all threads have some work to do
gather_completed_checker_[ tid ].set_false();
gather_completed_checker_.set_false( tid );
assert( gather_completed_checker_.all_false() );

const AssignedRanks assigned_ranks = kernel().vp_manager.get_assigned_ranks( tid );
Expand All @@ -874,7 +874,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
while ( gather_completed_checker_.any_false() )
{
// assume this is the last gather round and change to false otherwise
gather_completed_checker_[ tid ].set_true();
gather_completed_checker_.set_true( tid );

#pragma omp master
{
Expand All @@ -891,7 +891,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
const bool gather_completed =
collocate_target_data_buffers_compressed_( tid, assigned_ranks, send_buffer_position );

gather_completed_checker_[ tid ].logical_and( gather_completed );
gather_completed_checker_.logical_and( tid, gather_completed );

if ( gather_completed_checker_.all_true() )
{
Expand All @@ -916,7 +916,7 @@ EventDeliveryManager::gather_target_data_compressed( const size_t tid )
// all data it is responsible for to buffers. Now combine with information on whether other ranks
// have sent all their data. Note: All threads will return the same value for distribute_completed.
const bool distribute_completed = distribute_target_data_buffers_( tid );
gather_completed_checker_[ tid ].logical_and( distribute_completed );
gather_completed_checker_.logical_and( tid, distribute_completed );

// resize mpi buffers, if necessary and allowed
if ( gather_completed_checker_.any_false() and kernel().mpi_manager.adaptive_target_buffers() )
Expand Down
57 changes: 25 additions & 32 deletions nestkernel/per_thread_bool_indicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,62 +50,55 @@ PerThreadBoolIndicator::initialize( const size_t num_threads, const bool status
kernel().vp_manager.assert_single_threaded();
per_thread_status_.clear();
per_thread_status_.resize( num_threads, BoolIndicatorUInt64( status ) );
size_ = num_threads;
if ( status )
{
are_true_ = num_threads;
}
else
{
are_true_ = 0;
}
}

bool
PerThreadBoolIndicator::all_false() const
{
// We need two barriers here to ensure that no thread can continue and change the result
// before all threads have determined the result.
#pragma omp barrier
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
{
if ( it->is_true() )
{
return false;
}
}
return true;
// We need two barriers here to ensure that no thread can continue and change the result
// before all threads have determined the result.
bool ret = ( are_true_ == 0 );
#pragma omp barrier
return ret;
}

bool
PerThreadBoolIndicator::all_true() const
{
#pragma omp barrier
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
{
if ( it->is_false() )
{
return false;
}
}
return true;
bool ret = ( are_true_ == size_ );
#pragma omp barrier
return ret;
}

bool
PerThreadBoolIndicator::any_false() const
{
#pragma omp barrier
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
{
if ( it->is_false() )
{
return true;
}
}
return false;
bool ret = ( are_true_ < size_ );
#pragma omp barrier
return ret;
}

bool
PerThreadBoolIndicator::any_true() const
{
#pragma omp barrier
for ( auto it = per_thread_status_.begin(); it < per_thread_status_.end(); ++it )
{
if ( it->is_true() )
{
return true;
}
}
return false;
bool ret = ( are_true_ > 0 );
#pragma omp barrier
return ret;
}

} // namespace nest
44 changes: 43 additions & 1 deletion nestkernel/per_thread_bool_indicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#define PER_THREAD_BOOL_INDICATOR_H

// C++ includes:
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
Expand Down Expand Up @@ -52,15 +53,17 @@ class BoolIndicatorUInt64
bool is_true() const;
bool is_false() const;


protected:
void set_true();
void set_false();

void logical_and( const bool status );

private:
static constexpr std::uint_fast64_t true_uint64 = true;
static constexpr std::uint_fast64_t false_uint64 = false;
std::uint_fast64_t status_;
friend class PerThreadBoolIndicator;
};

inline bool
Expand Down Expand Up @@ -106,6 +109,36 @@ class PerThreadBoolIndicator

BoolIndicatorUInt64& operator[]( const size_t tid );

void
set_true( const size_t tid )
{
if ( per_thread_status_[ tid ].is_false() )
{
are_true_++;
per_thread_status_[ tid ].set_true();
}
}

void
set_false( const size_t tid )
{
if ( per_thread_status_[ tid ].is_true() )
{
are_true_--;
per_thread_status_[ tid ].set_false();
}
}

void
logical_and( const size_t tid, const bool status )
{
if ( per_thread_status_[ tid ].is_true() and not status )
{
are_true_--;
per_thread_status_[ tid ].set_false();
}
}

/**
* Resize to the given number of threads and set all elements to false.
*/
Expand Down Expand Up @@ -133,6 +166,15 @@ class PerThreadBoolIndicator

private:
std::vector< BoolIndicatorUInt64 > per_thread_status_;
int size_ { 0 };

/** Number of per-thread indicators currently true
*
* are_true_ == 0 -> all are false
* are_true_ == size_ -> all are true
* 0 < are_true_ < size_ -> some true, some false
*/
std::atomic< int > are_true_ { 0 };
};

} // namespace nest
Expand Down
6 changes: 3 additions & 3 deletions nestkernel/source_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ SourceTable::clear( const size_t tid )
it->clear();
}
sources_[ tid ].clear();
is_cleared_[ tid ].set_true();
is_cleared_.set_true( tid );
}

inline void
Expand Down Expand Up @@ -412,15 +412,15 @@ SourceTable::save_entry_point( const size_t tid )
assert( current_positions_[ tid ].lcid == -1 );
saved_positions_[ tid ].lcid = -1;
}
saved_entry_point_[ tid ].set_true();
saved_entry_point_.set_true( tid );
}
}

inline void
SourceTable::restore_entry_point( const size_t tid )
{
current_positions_[ tid ] = saved_positions_[ tid ];
saved_entry_point_[ tid ].set_false();
saved_entry_point_.set_false( tid );
}

inline void
Expand Down

0 comments on commit 6ab67c5

Please sign in to comment.