Skip to content

Commit

Permalink
Get rid of sendAcrossNetwork and simplify doPostsAndWaits
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Jun 8, 2024
1 parent 81487c4 commit 35bbe1d
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 236 deletions.
82 changes: 7 additions & 75 deletions src/details/ArborX_DetailsDistributedTreeUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
#include <ArborX_DetailsPriorityQueue.hpp>
#include <ArborX_DetailsUtils.hpp> // create_layout*

#include <Kokkos_Core.hpp>
#include <Kokkos_Profiling_ScopedRegion.hpp>
Expand All @@ -30,73 +29,6 @@
namespace ArborX::Details::DistributedTree
{

template <typename ExecutionSpace, typename Distributor, typename View>
std::enable_if_t<Kokkos::is_view<View>::value>
sendAcrossNetwork(ExecutionSpace const &space, Distributor const &distributor,
View exports, typename View::non_const_type imports)
{
Kokkos::Profiling::ScopedRegion guard(
"ArborX::DistributedTree::sendAcrossNetwork (" + exports.label() + ")");

ARBORX_ASSERT((exports.extent(0) == distributor.getTotalSendLength()) &&
(imports.extent(0) == distributor.getTotalReceiveLength()) &&
(exports.extent(1) == imports.extent(1)) &&
(exports.extent(2) == imports.extent(2)) &&
(exports.extent(3) == imports.extent(3)) &&
(exports.extent(4) == imports.extent(4)) &&
(exports.extent(5) == imports.extent(5)) &&
(exports.extent(6) == imports.extent(6)) &&
(exports.extent(7) == imports.extent(7)));

auto const num_packets = exports.extent(1) * exports.extent(2) *
exports.extent(3) * exports.extent(4) *
exports.extent(5) * exports.extent(6) *
exports.extent(7);

using NonConstValueType = typename View::non_const_value_type;

#ifndef ARBORX_ENABLE_GPU_AWARE_MPI
using MirrorSpace = typename View::host_mirror_space;
typename MirrorSpace::execution_space const execution_space;
#else
using MirrorSpace = typename View::device_type::memory_space;
auto const &execution_space = space;
#endif

auto imports_layout_right = create_layout_right_mirror_view_no_init(
execution_space, MirrorSpace{}, imports);

#ifndef ARBORX_ENABLE_GPU_AWARE_MPI
execution_space.fence();
#endif

Kokkos::View<NonConstValueType *, MirrorSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>
import_buffer(imports_layout_right.data(), imports_layout_right.size());

distributor.doPostsAndWaits(space, exports, num_packets, import_buffer);

constexpr bool can_skip_copy =
(View::rank == 1 &&
(std::is_same_v<typename View::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename View::array_layout, Kokkos::LayoutRight>));
if constexpr (can_skip_copy)
{
// For 1D non-strided views, we can directly copy to the original location,
// as layout is the same
Kokkos::deep_copy(space, imports, imports_layout_right);
}
else
{
// For multi-dimensional views, we need to first copy into a separate
// storage because of a different layout
auto tmp_view = Kokkos::create_mirror_view_and_copy(
Kokkos::view_alloc(space, typename ExecutionSpace::memory_space{}),
imports_layout_right);
Kokkos::deep_copy(space, imports, tmp_view);
}
}

template <typename ExecutionSpace, typename QueryIdsView, typename OffsetView>
void countResults(ExecutionSpace const &space, int n_queries,
QueryIdsView const &query_ids, OffsetView &offset)
Expand Down Expand Up @@ -152,7 +84,7 @@ void forwardQueries(MPI_Comm comm, ExecutionSpace const &space,
"ArborX::DistributedTree::query::forwardQueries::import_ranks"),
n_imports);

sendAcrossNetwork(space, distributor, export_ranks, import_ranks);
distributor.doPostsAndWaits(space, export_ranks, import_ranks);
fwd_ranks = import_ranks;
}

Expand All @@ -177,7 +109,7 @@ void forwardQueries(MPI_Comm comm, ExecutionSpace const &space,
"ArborX::DistributedTree::query::forwardQueries::imports"),
n_imports);

sendAcrossNetwork(space, distributor, exports, imports);
distributor.doPostsAndWaits(space, exports, imports);
fwd_queries = imports;
}

Expand All @@ -202,7 +134,7 @@ void forwardQueries(MPI_Comm comm, ExecutionSpace const &space,
"ArborX::DistributedTree::query::forwardQueries::import_ids"),
n_imports);

sendAcrossNetwork(space, distributor, export_ids, import_ids);
distributor.doPostsAndWaits(space, export_ids, import_ids);
fwd_ids = import_ids;
}
}
Expand Down Expand Up @@ -245,7 +177,7 @@ void communicateResultsBack(MPI_Comm comm, ExecutionSpace const &space,
Kokkos::view_alloc(space, Kokkos::WithoutInitializing, ranks.label()),
n_imports);

sendAcrossNetwork(space, distributor, export_ranks, import_ranks);
distributor.doPostsAndWaits(space, export_ranks, import_ranks);
ranks = import_ranks;
}

Expand All @@ -267,7 +199,7 @@ void communicateResultsBack(MPI_Comm comm, ExecutionSpace const &space,
Kokkos::view_alloc(space, Kokkos::WithoutInitializing, ids.label()),
n_imports);

sendAcrossNetwork(space, distributor, export_ids, import_ids);
distributor.doPostsAndWaits(space, export_ids, import_ids);
ids = import_ids;
}

Expand All @@ -278,7 +210,7 @@ void communicateResultsBack(MPI_Comm comm, ExecutionSpace const &space,
Kokkos::view_alloc(space, Kokkos::WithoutInitializing, out.label()),
n_imports);

sendAcrossNetwork(space, distributor, export_out, import_out);
distributor.doPostsAndWaits(space, export_out, import_out);
out = import_out;
}
}
Expand Down Expand Up @@ -311,7 +243,7 @@ void forwardQueriesAndCommunicateResults(
// Communicate results back
communicateResultsBack(comm, space, values, offset, ranks, ids);

Kokkos::Profiling::pushRegion(prefix + "postprocess_results");
Kokkos::Profiling::pushRegion(prefix + "::postprocess_results");

// Merge results
int const n_predicates = predicates.size();
Expand Down
Loading

0 comments on commit 35bbe1d

Please sign in to comment.