Skip to content

Commit

Permalink
Add method to sum (merge) maps together
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Dec 13, 2024
1 parent 565864c commit 7392f41
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 34 deletions.
26 changes: 26 additions & 0 deletions examples/python/edit/sum_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import numpy as np
import pywavemap as wave

# Load the map
user_home = os.path.expanduser('~')
input_map_path = os.path.join(user_home, "your_map.wvmp")
your_map = wave.Map.load(input_map_path)

# Crop the map
center = np.array([-2.2, -1.4, 0.0])
radius = 3.0
wave.edit.crop_to_sphere(your_map, center, radius)

# Create a translated copy
translation = np.array([5.0, 5.0, 0.0])
rotation = wave.Rotation(w=1.0, x=0.0, y=0.0, z=0.0)
transformation = wave.Pose(rotation, translation)
your_map_translated = wave.edit.transform(your_map, transformation)

# Merge them together
wave.edit.sum(your_map, your_map_translated)

# Save the map
output_map_path = os.path.join(user_home, "your_map_merged.wvmp")
your_map.store(output_map_path)
36 changes: 20 additions & 16 deletions library/cpp/include/wavemap/core/utils/edit/impl/multiply_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void multiplyNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
// Multiply
node.data() *= multiplier;

// Recursively handle all children
// Recursively handle all child nodes
for (int child_idx = 0; child_idx < OctreeIndex::kNumChildren; ++child_idx) {
if (auto child_node = node.getChild(child_idx); child_node) {
multiplyNodeRecursive<MapT>(*child_node, multiplier);
Expand All @@ -26,22 +26,26 @@ void multiply(MapT& map, FloatingPoint multiplier,
using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;

// Process all blocks
for (auto& [block_index, block] : map.getHashMap()) {
// Indicate that the block has changed
block.setLastUpdatedStamp();
// Multiply the block's average value (wavelet scale coefficient)
FloatingPoint& root_value = block.getRootScale();
root_value *= multiplier;
// Recursively multiply all node values (wavelet detail coefficients)
NodePtrType root_node_ptr = &block.getRootNode();
if (thread_pool) {
thread_pool->add_task([root_node_ptr, multiplier]() {
detail::multiplyNodeRecursive<MapT>(*root_node_ptr, multiplier);
map.forEachBlock(
[&thread_pool, multiplier](const Index3D& /*block_index*/, auto& block) {
// Indicate that the block has changed
block.setLastUpdatedStamp();

// Multiply the block's average value (wavelet scale coefficient)
FloatingPoint& root_value = block.getRootScale();
root_value *= multiplier;

// Recursively multiply all node values (wavelet detail coefficients)
NodePtrType root_node_ptr = &block.getRootNode();
if (thread_pool) {
thread_pool->add_task([root_node_ptr, multiplier]() {
detail::multiplyNodeRecursive<MapT>(*root_node_ptr, multiplier);
});
} else {
detail::multiplyNodeRecursive<MapT>(*root_node_ptr, multiplier);
}
});
} else {
detail::multiplyNodeRecursive<MapT>(*root_node_ptr, multiplier);
}
}

// Wait for all parallel jobs to finish
if (thread_pool) {
thread_pool->wait_all();
Expand Down
63 changes: 63 additions & 0 deletions library/cpp/include/wavemap/core/utils/edit/impl/sum_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,28 @@ void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
node_details = new_details;
node_value = new_value;
}

template <typename MapT>
void sumNodeRecursive(
typename MapT::Block::OctreeType::NodeRefType node_A,
typename MapT::Block::OctreeType::NodeConstRefType node_B) {
using NodeRefType = decltype(node_A);
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;

// Sum
node_A.data() += node_B.data();

// Recursively handle all child nodes
for (NdtreeIndexRelativeChild child_idx = 0;
child_idx < OctreeIndex::kNumChildren; ++child_idx) {
NodeConstPtrType child_node_B = node_B.getChild(child_idx);
if (!child_node_B) {
continue;
}
NodeRefType child_node_A = node_A.getOrAllocateChild(child_idx);
sumNodeRecursive<MapT>(child_node_A, *child_node_B);
}
}
} // namespace detail

template <typename MapT, typename SamplingFn>
Expand Down Expand Up @@ -113,6 +135,47 @@ void sum(MapT& map, SamplingFn sampling_function,
thread_pool->wait_all();
}
}

template <typename MapT>
void sum(MapT& map_A, const MapT& map_B,
const std::shared_ptr<ThreadPool>& thread_pool) {
CHECK_EQ(map_A.getTreeHeight(), map_B.getTreeHeight());
CHECK_EQ(map_A.getMinCellWidth(), map_B.getMinCellWidth());
using NodePtrType = typename MapT::Block::OctreeType::NodePtrType;
using NodeConstPtrType = typename MapT::Block::OctreeType::NodeConstPtrType;

// Process all blocks
map_B.forEachBlock(
[&map_A, &thread_pool](const Index3D& block_index, const auto& block_B) {
auto& block_A = map_A.getOrAllocateBlock(block_index);

// Indicate that the block has changed
block_A.setLastUpdatedStamp();
block_A.setNeedsPruning();

// Sum the blocks' average values (wavelet scale coefficient)
block_A.getRootScale() += block_B.getRootScale();

// Recursively sum all node values (wavelet detail coefficients)
NodePtrType root_node_ptr_A = &block_A.getRootNode();
NodeConstPtrType root_node_ptr_B = &block_B.getRootNode();
if (thread_pool) {
thread_pool->add_task([root_node_ptr_A, root_node_ptr_B,
block_ptr_A = &block_A]() {
detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
block_ptr_A->prune();
});
} else {
detail::sumNodeRecursive<MapT>(*root_node_ptr_A, *root_node_ptr_B);
block_A.prune();
}
});

// Wait for all parallel jobs to finish
if (thread_pool) {
thread_pool->wait_all();
}
}
} // namespace wavemap::edit

#endif // WAVEMAP_CORE_UTILS_EDIT_IMPL_SUM_INL_H_
9 changes: 9 additions & 0 deletions library/cpp/include/wavemap/core/utils/edit/sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,21 @@ void sumNodeRecursive(typename MapT::Block::OctreeType::NodeRefType node,
SamplingFn&& sampling_function,
FloatingPoint min_cell_width,
IndexElement termination_height = 0);

template <typename MapT>
void sumNodeRecursive(
typename MapT::Block::OctreeType::NodeRefType node_A,
typename MapT::Block::OctreeType::NodeConstRefType node_B);
} // namespace detail

template <typename MapT, typename SamplingFn>
void sum(MapT& map, SamplingFn sampling_function,
IndexElement termination_height = 0,
const std::shared_ptr<ThreadPool>& thread_pool = nullptr);

template <typename MapT>
void sum(MapT& map_A, const MapT& map_B,
const std::shared_ptr<ThreadPool>& thread_pool = nullptr);
} // namespace wavemap::edit

#include "wavemap/core/utils/edit/impl/sum_inl.h"
Expand Down
51 changes: 33 additions & 18 deletions library/python/src/edit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,6 @@ using namespace nb::literals; // NOLINT

namespace wavemap {
void add_edit_module(nb::module_& m_edit) {
// Map cropping methods
m_edit.def(
"crop_to_sphere",
[](HashedWaveletOctree& map, const Point3D& t_W_center,
FloatingPoint radius, IndexElement termination_height) {
edit::crop_to_sphere(map, t_W_center, radius, termination_height,
std::make_shared<ThreadPool>());
},
"map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0);
m_edit.def(
"crop_to_sphere",
[](HashedChunkedWaveletOctree& map, const Point3D& t_W_center,
FloatingPoint radius, IndexElement termination_height) {
edit::crop_to_sphere(map, t_W_center, radius, termination_height,
std::make_shared<ThreadPool>());
},
"map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0);

// Map multiply methods
// NOTE: Among others, this can be used to implement exponential forgetting,
// by multiplying the map with a scalar between 0 and 1.
Expand All @@ -48,6 +30,21 @@ void add_edit_module(nb::module_& m_edit) {
},
"map"_a, "multiplier"_a);

// Map sum methods
m_edit.def(
"sum",
[](HashedWaveletOctree& map_A, const HashedWaveletOctree& map_B) {
edit::sum(map_A, map_B, std::make_shared<ThreadPool>());
},
"map_A"_a, "map_B"_a);
m_edit.def(
"sum",
[](HashedChunkedWaveletOctree& map_A,
const HashedChunkedWaveletOctree& map_B) {
edit::sum(map_A, map_B, std::make_shared<ThreadPool>());
},
"map_A"_a, "map_B"_a);

// Map transformation methods
m_edit.def(
"transform",
Expand All @@ -61,5 +58,23 @@ void add_edit_module(nb::module_& m_edit) {
return edit::transform(B_map, T_AB, std::make_shared<ThreadPool>());
},
"map"_a, "transformation"_a);

// Map cropping methods
m_edit.def(
"crop_to_sphere",
[](HashedWaveletOctree& map, const Point3D& t_W_center,
FloatingPoint radius, IndexElement termination_height) {
edit::crop_to_sphere(map, t_W_center, radius, termination_height,
std::make_shared<ThreadPool>());
},
"map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0);
m_edit.def(
"crop_to_sphere",
[](HashedChunkedWaveletOctree& map, const Point3D& t_W_center,
FloatingPoint radius, IndexElement termination_height) {
edit::crop_to_sphere(map, t_W_center, radius, termination_height,
std::make_shared<ThreadPool>());
},
"map"_a, "center_point"_a, "radius"_a, "termination_height"_a = 0);
}
} // namespace wavemap

0 comments on commit 7392f41

Please sign in to comment.