Skip to content

Commit

Permalink
kokkos#295: fix incorrect container conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
nmm0 committed Nov 10, 2023
1 parent 417204e commit 8e2b196
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 62 deletions.
48 changes: 8 additions & 40 deletions include/experimental/__p1684_bits/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,61 +133,29 @@ class mdarray {
) : map_(m), ctr_(container_is_array<container_type>::construct(map_))
{ }

// Constructors from container
MDSPAN_TEMPLATE_REQUIRES(
class... SizeTypes,
/* requires */ (
_MDSPAN_FOLD_AND(_MDSPAN_TRAIT( std::is_convertible, SizeTypes, index_type) /* && ... */) &&
_MDSPAN_TRAIT( std::is_constructible, extents_type, SizeTypes...) &&
_MDSPAN_TRAIT( std::is_constructible, mapping_type, extents_type)
)
)
MDSPAN_INLINE_FUNCTION
explicit constexpr mdarray(const container_type& ctr, SizeTypes... dynamic_extents)
: map_(extents_type(dynamic_extents...)), ctr_(ctr)
{ assert(ctr.size() >= static_cast<size_t>(map_.required_span_size())); }


MDSPAN_FUNCTION_REQUIRES(
(MDSPAN_INLINE_FUNCTION constexpr),
mdarray, (const container_type& ctr, const extents_type& exts), ,
mdarray, (const extents_type& exts, const container_type& ctr), ,
/* requires */ (_MDSPAN_TRAIT( std::is_constructible, mapping_type, extents_type))
) : map_(exts), ctr_(ctr)
{ assert(ctr.size() >= static_cast<size_t>(map_.required_span_size())); }

constexpr mdarray(const container_type& ctr, const mapping_type& m)
constexpr mdarray(const mapping_type& m, const container_type& ctr)
: map_(m), ctr_(ctr)
{ assert(ctr.size() >= static_cast<size_t>(map_.required_span_size())); }


// Constructors from container
MDSPAN_TEMPLATE_REQUIRES(
class... SizeTypes,
/* requires */ (
_MDSPAN_FOLD_AND(_MDSPAN_TRAIT( std::is_convertible, SizeTypes, index_type) /* && ... */) &&
_MDSPAN_TRAIT( std::is_constructible, extents_type, SizeTypes...) &&
_MDSPAN_TRAIT( std::is_constructible, mapping_type, extents_type)
)
)
MDSPAN_INLINE_FUNCTION
explicit constexpr mdarray(container_type&& ctr, SizeTypes... dynamic_extents)
: map_(extents_type(dynamic_extents...)), ctr_(std::move(ctr))
{ assert(ctr_.size() >= static_cast<size_t>(map_.required_span_size())); }


MDSPAN_FUNCTION_REQUIRES(
(MDSPAN_INLINE_FUNCTION constexpr),
mdarray, (container_type&& ctr, const extents_type& exts), ,
mdarray, (const extents_type& exts, container_type&& ctr), ,
/* requires */ (_MDSPAN_TRAIT( std::is_constructible, mapping_type, extents_type))
) : map_(exts), ctr_(std::move(ctr))
{ assert(ctr_.size() >= static_cast<size_t>(map_.required_span_size())); }

constexpr mdarray(container_type&& ctr, const mapping_type& m)
constexpr mdarray(const mapping_type& m, container_type&& ctr)
: map_(m), ctr_(std::move(ctr))
{ assert(ctr_.size() >= static_cast<size_t>(map_.required_span_size())); }



MDSPAN_TEMPLATE_REQUIRES(
class OtherElementType, class OtherExtents, class OtherLayoutPolicy, class OtherContainer,
/* requires */ (
Expand Down Expand Up @@ -229,7 +197,7 @@ class mdarray {
_MDSPAN_TRAIT( std::is_constructible, mapping_type, extents_type))
)
MDSPAN_INLINE_FUNCTION
constexpr mdarray(const container_type& ctr, const extents_type& exts, const Alloc& a)
constexpr mdarray(const extents_type& exts, const container_type& ctr, const Alloc& a)
: map_(exts), ctr_(ctr, a)
{ assert(ctr_.size() >= static_cast<size_t>(map_.required_span_size())); }

Expand All @@ -238,7 +206,7 @@ class mdarray {
/* requires */ (_MDSPAN_TRAIT( std::is_constructible, container_type, size_t, Alloc))
)
MDSPAN_INLINE_FUNCTION
constexpr mdarray(const container_type& ctr, const mapping_type& map, const Alloc& a)
constexpr mdarray(const mapping_type& map, const container_type& ctr, const Alloc& a)
: map_(map), ctr_(ctr, a)
{ assert(ctr_.size() >= static_cast<size_t>(map_.required_span_size())); }

Expand All @@ -248,7 +216,7 @@ class mdarray {
_MDSPAN_TRAIT( std::is_constructible, mapping_type, extents_type))
)
MDSPAN_INLINE_FUNCTION
constexpr mdarray(container_type&& ctr, const extents_type& exts, const Alloc& a)
constexpr mdarray(const extents_type& exts, container_type&& ctr, const Alloc& a)
: map_(exts), ctr_(std::move(ctr), a)
{ assert(ctr_.size() >= static_cast<size_t>(map_.required_span_size())); }

Expand All @@ -257,7 +225,7 @@ class mdarray {
/* requires */ (_MDSPAN_TRAIT( std::is_constructible, container_type, size_t, Alloc))
)
MDSPAN_INLINE_FUNCTION
constexpr mdarray(container_type&& ctr, const mapping_type& map, const Alloc& a)
constexpr mdarray(const mapping_type& map, container_type&& ctr, const Alloc& a)
: map_(map), ctr_(std::move(ctr), a)
{ assert(ctr_.size() >= map_.required_span_size()); }

Expand Down
44 changes: 22 additions & 22 deletions tests/test_mdarray_ctors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ TEST(TestMdarrayCtorFromContainerSizes, 1d_static) {
using mda_t = KokkosEx::mdarray<int, Kokkos::extents<unsigned,1>, Kokkos::layout_right, std::array<int,1>>;
// ptr to fill, extents, is_layout_right
mdarray_values<1>::fill(d.data(),Kokkos::extents<unsigned,1>(),true);
mda_t m(d,1);
mda_t m({}, d);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 1, 0, 1, 0, 0, 1, 0, 0, d.data(), false, true);
}
Expand All @@ -257,7 +257,7 @@ TEST(TestMdarrayCtorFromContainerSizes, 2d_static) {
std::array<int, 6> d{42,43,44,3,4,41};
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(d.data(),Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,3>, Kokkos::layout_right, std::array<int,6>> m(d,2,3);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,3>, Kokkos::layout_right, std::array<int,6>> m({},d);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 0, 2, 3, 0, 3, 1, 0, d.data(), false, true);
}
Expand All @@ -266,7 +266,7 @@ TEST(TestMdarrayCtorFromContainerSizes, 1d_dynamic) {
std::vector<int> d{42};
// ptr to fill, extents, is_layout_right
mdarray_values<1>::fill(d.data(),Kokkos::extents<int, 1>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m(d,1);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m({},d);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 1, 1, 1, 0, 0, 1, 0, 0, d.data(), false, true);
}
Expand All @@ -275,7 +275,7 @@ TEST(TestMdarrayCtorFromContainerSizes, 2d_dynamic) {
std::vector<int> d{42,1,2,3,4,41};
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(d.data(),Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(d,2,3);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m({},d);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 2, 2, 3, 0, 3, 1, 0, d.data(), false, true);
}
Expand All @@ -284,7 +284,7 @@ TEST(TestMdarrayCtorFromContainerSizes, 2d_mixed) {
std::vector<int> d{42,1,2,3,4,41};
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(d.data(),Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m(d,3);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m({},d);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 1, 2, 3, 0, 3, 1, 0, d.data(), false, true);
}
Expand All @@ -294,7 +294,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizes, 1d_static) {
std::array<int, 1> d{42};
// ptr to fill, extents, is_layout_right
mdarray_values<1>::fill(d.data(),Kokkos::extents<int, 1>(),true);
KokkosEx::mdarray<int, Kokkos::extents<int, 1>, Kokkos::layout_right, std::array<int,1>> m(std::move(d),1);
KokkosEx::mdarray<int, Kokkos::extents<int, 1>, Kokkos::layout_right, std::array<int,1>> m({},std::move(d));
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 1, 0, 1, 0, 0, 1, 0, 0, nullptr, false, true);
}
Expand All @@ -303,7 +303,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizes, 2d_static) {
std::array<int, 6> d{42,1,2,3,4,41};
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(d.data(),Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,3>, Kokkos::layout_right, std::array<int,6>> m(std::move(d),2,3);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,3>, Kokkos::layout_right, std::array<int,6>> m({},std::move(d));
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 0, 2, 3, 0, 3, 1, 0, nullptr, false, true);
}
Expand All @@ -313,7 +313,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizes, 1d_dynamic) {
auto ptr = d.data();
// ptr to fill, extents, is_layout_right
mdarray_values<1>::fill(ptr,Kokkos::extents<int, 1>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m(std::move(d),1);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m({},std::move(d));
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 1, 1, 1, 0, 0, 1, 0, 0, ptr, true, true);
}
Expand All @@ -323,7 +323,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizes, 2d_dynamic) {
auto ptr = d.data();
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(ptr,Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(std::move(d),2,3);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m({},std::move(d));
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 2, 2, 3, 0, 3, 1, 0, ptr, true, true);
}
Expand All @@ -333,7 +333,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizes, 2d_mixed) {
auto ptr = d.data();
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(ptr,Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m(std::move(d),3);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m({},std::move(d));
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 1, 2, 3, 0, 3, 1, 0, ptr, true, true);
}
Expand Down Expand Up @@ -399,7 +399,7 @@ TEST(TestMdarrayCtorFromContainerSizesAlloc, 1d_dynamic) {
std::vector<int> d{42};
// ptr to fill, extents, is_layout_right
mdarray_values<1>::fill(d.data(),Kokkos::extents<int, 1>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m(d,Kokkos::dextents<int, 1>{1}, alloc);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m(Kokkos::dextents<int, 1>{1}, d, alloc);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 1, 1, 1, 0, 0, 1, 0, 0, d.data(), false, true);
}
Expand All @@ -409,7 +409,7 @@ TEST(TestMdarrayCtorFromContainerSizesAlloc, 2d_dynamic) {
std::vector<int> d{42,1,2,3,4,41};
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(d.data(),Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(d,Kokkos::dextents<int, 2>{2,3}, alloc);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(Kokkos::dextents<int, 2>{2,3}, d, alloc);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 2, 2, 3, 0, 3, 1, 0, d.data(), false, true);
}
Expand All @@ -419,7 +419,7 @@ TEST(TestMdarrayCtorFromContainerSizesAlloc, 2d_mixed) {
std::vector<int> d{42,1,2,3,4,41};
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(d.data(),Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m(d,Kokkos::extents<int, 2,Kokkos::dynamic_extent>{3}, alloc);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m(Kokkos::extents<int, 2,Kokkos::dynamic_extent>{3}, d, alloc);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 1, 2, 3, 0, 3, 1, 0, d.data(), false, true);
}
Expand All @@ -430,7 +430,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizesAlloc, 1d_dynamic) {
auto ptr = d.data();
// ptr to fill, extents, is_layout_right
mdarray_values<1>::fill(ptr,Kokkos::extents<int, 1>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m(std::move(d),Kokkos::extents<int, 1>(), alloc);
KokkosEx::mdarray<int, Kokkos::dextents<int, 1>> m(Kokkos::extents<int, 1>(), std::move(d), alloc);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 1, 1, 1, 0, 0, 1, 0, 0, ptr, true, true);
}
Expand All @@ -441,7 +441,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizesAlloc, 2d_dynamic) {
auto ptr = d.data();
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(ptr,Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(std::move(d),Kokkos::extents<int, 2,3>(), alloc);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(Kokkos::extents<int, 2,3>(), std::move(d), alloc);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 2, 2, 3, 0, 3, 1, 0, ptr, true, true);
}
Expand All @@ -452,7 +452,7 @@ TEST(TestMdarrayCtorFromMoveContainerSizesAlloc, 2d_mixed) {
auto ptr = d.data();
// ptr to fill, extents, is_layout_right
mdarray_values<2>::fill(ptr,Kokkos::extents<int, 2,3>(),true);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m(std::move(d),Kokkos::extents<int, 2,Kokkos::dynamic_extent>(3), alloc);
KokkosEx::mdarray<int, Kokkos::extents<int, 2,Kokkos::dynamic_extent>> m(Kokkos::extents<int, 2,Kokkos::dynamic_extent>(3), std::move(d), alloc);
// mdarray, rank, rank_dynamic, ext0, ext1, ext2, stride0, stride1, stride2, ptr, ptr_matches, exhaustive
check_correctness(m, 2, 1, 2, 3, 0, 3, 1, 0, ptr, true, true);
}
Expand All @@ -475,7 +475,7 @@ TEST(TestMdarrayCtorWithPMR, 2d_mixed) {

top_container.emplace_back(3,3);
top_container.emplace_back(a.mapping());
top_container.emplace_back(a.container(), a.mapping());
top_container.emplace_back(a.mapping(), a.container());
top_container.push_back({a});
}
#endif
Expand All @@ -494,7 +494,7 @@ TEST(TestMdarrayCtorDataStdArray, test_mdarray_ctor_data_carray) {

TEST(TestMdarrayCtorDataVector, test_mdarray_ctor_data_carray) {
std::vector<int> d = {42};
KokkosEx::mdarray<int, Kokkos::extents<int, 1>, Kokkos::layout_right, std::vector<int>> m(d);
KokkosEx::mdarray<int, Kokkos::extents<int, 1>, Kokkos::layout_right, std::vector<int>> m({}, d);
ASSERT_EQ(m.rank(), 1);
ASSERT_EQ(m.rank_dynamic(), 0);
ASSERT_EQ(m.extent(0), 1);
Expand All @@ -506,7 +506,7 @@ TEST(TestMdarrayCtorDataVector, test_mdarray_ctor_data_carray) {
TEST(TestMdarrayCtorExtentsStdArrayConvertibleToSizeT, test_mdarray_ctor_extents_std_array_convertible_to_size_t) {
std::vector<int> d{42, 17, 71, 24};
std::array<int, 2> e{2, 2};
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(d, e);
KokkosEx::mdarray<int, Kokkos::dextents<int, 2>> m(e, d);
ASSERT_EQ(m.rank(), 2);
ASSERT_EQ(m.rank_dynamic(), 2);
ASSERT_EQ(m.extent(0), 2);
Expand All @@ -520,7 +520,7 @@ TEST(TestMdarrayCtorExtentsStdArrayConvertibleToSizeT, test_mdarray_ctor_extents
TEST(TestMdarrayListInitializationLayoutLeft, test_mdarray_list_initialization_layout_left) {
std::vector<int> d(16*32);
auto ptr = d.data();
KokkosEx::mdarray<int, Kokkos::extents<int, dyn, dyn>, Kokkos::layout_left> m{std::move(d), 16, 32};
KokkosEx::mdarray<int, Kokkos::extents<int, dyn, dyn>, Kokkos::layout_left> m{Kokkos::dextents<int, 2>{16, 32}, std::move(d)};
ASSERT_EQ(m.data(), ptr);
ASSERT_EQ(m.rank(), 2);
ASSERT_EQ(m.rank_dynamic(), 2);
Expand All @@ -535,7 +535,7 @@ TEST(TestMdarrayListInitializationLayoutLeft, test_mdarray_list_initialization_l
TEST(TestMdarrayListInitializationLayoutRight, test_mdarray_list_initialization_layout_right) {
std::vector<int> d(16*32);
auto ptr = d.data();
KokkosEx::mdarray<int, Kokkos::extents<int, dyn, dyn>, Kokkos::layout_right> m{std::move(d), 16, 32};
KokkosEx::mdarray<int, Kokkos::extents<int, dyn, dyn>, Kokkos::layout_right> m{Kokkos::dextents<int, 2>{16, 32}, std::move(d)};
ASSERT_EQ(m.data(), ptr);
ASSERT_EQ(m.rank(), 2);
ASSERT_EQ(m.rank_dynamic(), 2);
Expand All @@ -549,7 +549,7 @@ TEST(TestMdarrayListInitializationLayoutRight, test_mdarray_list_initialization_
TEST(TestMdarrayListInitializationLayoutStride, test_mdarray_list_initialization_layout_stride) {
std::vector<int> d(32*128);
auto ptr = d.data();
KokkosEx::mdarray<int, Kokkos::extents<int, dyn, dyn>, Kokkos::layout_stride> m{std::move(d), {Kokkos::dextents<int, 2>{16, 32}, std::array<std::size_t, 2>{1, 128}}};
KokkosEx::mdarray<int, Kokkos::extents<int, dyn, dyn>, Kokkos::layout_stride> m{{Kokkos::dextents<int, 2>{16, 32}, std::array<std::size_t, 2>{1, 128}}, std::move(d)};
ASSERT_EQ(m.data(), ptr);
ASSERT_EQ(m.rank(), 2);
ASSERT_EQ(m.rank_dynamic(), 2);
Expand Down

0 comments on commit 8e2b196

Please sign in to comment.