diff --git a/examples/python/edit/sum_map.py b/examples/python/edit/sum_map.py new file mode 100644 index 000000000..0f13a2471 --- /dev/null +++ b/examples/python/edit/sum_map.py @@ -0,0 +1,26 @@ +import os +import numpy as np +import pywavemap as wave + +# Load the map +user_home = os.path.expanduser('~') +input_map_path = os.path.join(user_home, "your_map.wvmp") +your_map = wave.Map.load(input_map_path) + +# Crop the map +center = np.array([-2.2, -1.4, 0.0]) +radius = 3.0 +wave.edit.crop_to_sphere(your_map, center, radius) + +# Create a translated copy +translation = np.array([5.0, 5.0, 0.0]) +rotation = wave.Rotation(w=1.0, x=0.0, y=0.0, z=0.0) +transformation = wave.Pose(rotation, translation) +your_map_translated = wave.edit.transform(your_map, transformation) + +# Merge them together +wave.edit.sum(your_map, your_map_translated) + +# Save the map +output_map_path = os.path.join(user_home, "your_map_merged.wvmp") +your_map.store(output_map_path) diff --git a/library/cpp/include/wavemap/core/utils/edit/impl/multiply_inl.h b/library/cpp/include/wavemap/core/utils/edit/impl/multiply_inl.h index 22fb7ad32..5ddab582b 100644 --- a/library/cpp/include/wavemap/core/utils/edit/impl/multiply_inl.h +++ b/library/cpp/include/wavemap/core/utils/edit/impl/multiply_inl.h @@ -11,7 +11,7 @@ void multiplyNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node, // Multiply node.data() *= multiplier; - // Recursively handle all children + // Recursively handle all child nodes for (int child_idx = 0; child_idx < OctreeIndex::kNumChildren; ++child_idx) { if (auto child_node = node.getChild(child_idx); child_node) { multiplyNodeRecursive(*child_node, multiplier); @@ -26,22 +26,26 @@ void multiply(MapT& map, FloatingPoint multiplier, using NodePtrType = typename MapT::Block::OctreeType::NodePtrType; // Process all blocks - for (auto& [block_index, block] : map.getHashMap()) { - // Indicate that the block has changed - block.setLastUpdatedStamp(); - // Multiply the block's average value (wavelet scale coefficient) - FloatingPoint& root_value = block.getRootScale(); - root_value *= multiplier; - // Recursively multiply all node values (wavelet detail coefficients) - NodePtrType root_node_ptr = &block.getRootNode(); - if (thread_pool) { - thread_pool->add_task([root_node_ptr, multiplier]() { - detail::multiplyNodeRecursive(*root_node_ptr, multiplier); + map.forEachBlock( + [&thread_pool, multiplier](const Index3D& /*block_index*/, auto& block) { + // Indicate that the block has changed + block.setLastUpdatedStamp(); + + // Multiply the block's average value (wavelet scale coefficient) + FloatingPoint& root_value = block.getRootScale(); + root_value *= multiplier; + + // Recursively multiply all node values (wavelet detail coefficients) + NodePtrType root_node_ptr = &block.getRootNode(); + if (thread_pool) { + thread_pool->add_task([root_node_ptr, multiplier]() { + detail::multiplyNodeRecursive(*root_node_ptr, multiplier); + }); + } else { + detail::multiplyNodeRecursive(*root_node_ptr, multiplier); + } }); - } else { - detail::multiplyNodeRecursive(*root_node_ptr, multiplier); - } - } + // Wait for all parallel jobs to finish if (thread_pool) { thread_pool->wait_all(); diff --git a/library/cpp/include/wavemap/core/utils/edit/impl/sum_inl.h b/library/cpp/include/wavemap/core/utils/edit/impl/sum_inl.h index 3752710ec..9857ad6d0 100644 --- a/library/cpp/include/wavemap/core/utils/edit/impl/sum_inl.h +++ b/library/cpp/include/wavemap/core/utils/edit/impl/sum_inl.h @@ -67,6 +67,28 @@ void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node, node_details = new_details; node_value = new_value; } + +template +void sumNodeRecursive( + typename MapT::Block::OctreeType::NodeRefType node_A, + typename MapT::Block::OctreeType::NodeConstRefType node_B) { + using NodeRefType = decltype(node_A); + using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType; + + // Sum + node_A.data() += node_B.data(); + + // Recursively handle all child nodes + for (NdtreeIndexRelativeChild child_idx = 0; + child_idx < OctreeIndex::kNumChildren; ++child_idx) { + NodeConstPtrType child_node_B = node_B.getChild(child_idx); + if (!child_node_B) { + continue; + } + NodeRefType child_node_A = node_A.getOrAllocateChild(child_idx); + sumNodeRecursive(child_node_A, *child_node_B); + } +} } // namespace detail template @@ -113,6 +135,47 @@ void sum(MapT& map, SamplingFn sampling_function, thread_pool->wait_all(); } } + +template +void sum(MapT& map_A, const MapT& map_B, + const std::shared_ptr& thread_pool) { + CHECK_EQ(map_A.getTreeHeight(), map_B.getTreeHeight()); + CHECK_EQ(map_A.getMinCellWidth(), map_B.getMinCellWidth()); + using NodePtrType = typename MapT::Block::OctreeType::NodePtrType; + using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType; + + // Process all blocks + map_B.forEachBlock( + [&map_A, &thread_pool](const Index3D& block_index, const auto& block_B) { + auto& block_A = map_A.getOrAllocateBlock(block_index); + + // Indicate that the block has changed + block_A.setLastUpdatedStamp(); + block_A.setNeedsPruning(); + + // Sum the blocks' average values (wavelet scale coefficient) + block_A.getRootScale() += block_B.getRootScale(); + + // Recursively sum all node values (wavelet detail coefficients) + NodePtrType root_node_ptr_A = &block_A.getRootNode(); + NodeConstPtrType root_node_ptr_B = &block_B.getRootNode(); + if (thread_pool) { + thread_pool->add_task([root_node_ptr_A, root_node_ptr_B, + block_ptr_A = &block_A]() { + detail::sumNodeRecursive(*root_node_ptr_A, *root_node_ptr_B); + block_ptr_A->prune(); + }); + } else { + detail::sumNodeRecursive(*root_node_ptr_A, *root_node_ptr_B); + block_A.prune(); + } + }); + + // Wait for all parallel jobs to finish + if (thread_pool) { + thread_pool->wait_all(); + } +} } // namespace wavemap::edit #endif // WAVEMAP_CORE_UTILS_EDIT_IMPL_SUM_INL_H_ diff --git a/library/cpp/include/wavemap/core/utils/edit/sum.h b/library/cpp/include/wavemap/core/utils/edit/sum.h index b19a5762e..dd9314440 100644 --- a/library/cpp/include/wavemap/core/utils/edit/sum.h +++ b/library/cpp/include/wavemap/core/utils/edit/sum.h @@ -20,12 +20,21 @@ void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node, SamplingFn&& sampling_function, FloatingPoint min_cell_width, IndexElement termination_height = 0); + +template +void sumNodeRecursive( + typename MapT::Block::OctreeType::NodeRefType node_A, + typename MapT::Block::OctreeType::NodeConstRefType node_B); } // namespace detail template void sum(MapT& map, SamplingFn sampling_function, IndexElement termination_height = 0, const std::shared_ptr& thread_pool = nullptr); + +template +void sum(MapT& map_A, const MapT& map_B, + const std::shared_ptr& thread_pool = nullptr); } // namespace wavemap::edit #include "wavemap/core/utils/edit/impl/sum_inl.h" diff --git a/library/python/src/edit.cc b/library/python/src/edit.cc index df69e465f..782696e8c 100644 --- a/library/python/src/edit.cc +++ b/library/python/src/edit.cc @@ -14,24 +14,6 @@ using namespace nb::literals; // NOLINT namespace wavemap { void add_edit_module(nb::module_& m_edit) { - // Map cropping methods - m_edit.def( - "crop_to_sphere", - [](HashedWaveletOctree& map, const Point3D& t_W_center, - FloatingPoint radius, IndexElement termination_height) { - edit::crop_to_sphere(map, t_W_center, radius, termination_height, - std::make_shared()); - }, - "map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0); - m_edit.def( - "crop_to_sphere", - [](HashedChunkedWaveletOctree& map, const Point3D& t_W_center, - FloatingPoint radius, IndexElement termination_height) { - edit::crop_to_sphere(map, t_W_center, radius, termination_height, - std::make_shared()); - }, - "map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0); - // Map multiply methods // NOTE: Among others, this can be used to implement exponential forgetting, // by multiplying the map with a scalar between 0 and 1. @@ -48,6 +30,21 @@ void add_edit_module(nb::module_& m_edit) { }, "map"_a, "multiplier"_a); + // Map sum methods + m_edit.def( + "sum", + [](HashedWaveletOctree& map_A, const HashedWaveletOctree& map_B) { + edit::sum(map_A, map_B, std::make_shared()); + }, + "map_A"_a, "map_B"_a); + m_edit.def( + "sum", + [](HashedChunkedWaveletOctree& map_A, + const HashedChunkedWaveletOctree& map_B) { + edit::sum(map_A, map_B, std::make_shared()); + }, + "map_A"_a, "map_B"_a); + // Map transformation methods m_edit.def( "transform", @@ -61,5 +58,23 @@ void add_edit_module(nb::module_& m_edit) { return edit::transform(B_map, T_AB, std::make_shared()); }, "map"_a, "transformation"_a); + + // Map cropping methods + m_edit.def( + "crop_to_sphere", + [](HashedWaveletOctree& map, const Point3D& t_W_center, + FloatingPoint radius, IndexElement termination_height) { + edit::crop_to_sphere(map, t_W_center, radius, termination_height, + std::make_shared()); + }, + "map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0); + m_edit.def( + "crop_to_sphere", + [](HashedChunkedWaveletOctree& map, const Point3D& t_W_center, + FloatingPoint radius, IndexElement termination_height) { + edit::crop_to_sphere(map, t_W_center, radius, termination_height, + std::make_shared()); + }, + "map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0); } } // namespace wavemap