Skip to content

Commit

Permalink
Merge pull request #333 from Point72/mrc/csp-77
Browse files Browse the repository at this point in the history
Improve implementation of statistics functions Quantile/Rank using `boost::multi_index`
  • Loading branch information
timkpaine authored Jul 17, 2024
2 parents 4114b90 + c2be3cf commit 5c81458
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 179 deletions.
243 changes: 64 additions & 179 deletions cpp/csp/cppnodes/statsimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
#include <numeric>
#include <set>
#include <type_traits>

#ifdef __linux__
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#endif
#include <boost/multi_index_container.hpp>
#include <boost/multi_index/ordered_index.hpp>
#include <boost/multi_index/ranked_index.hpp>

namespace csp::cppnodes
{
Expand Down Expand Up @@ -1087,19 +1085,8 @@ class WeightedKurtosis
bool m_excess;
};

#ifdef __linux__
// Order-statistics tree of doubles built on GNU pb_ds: a red-black tree
// (rb_tree_tag) with no mapped value (null_type) whose nodes carry the
// tree_order_statistics_node_update policy. Element ordering is given by Comparator.
// Other code in this file queries it through order_of_key() and find_by_order().
template<typename Comparator>
using ost = __gnu_pbds::tree<double, __gnu_pbds::null_type, Comparator, __gnu_pbds::rb_tree_tag,
__gnu_pbds::tree_order_statistics_node_update>;

// Erases a single element from `t`, selected by rank rather than by key lookup:
// the position reported by order_of_key( v ) is turned back into an iterator
// with find_by_order() and that element is erased.
// NOTE(review): presumably find() is avoided because the trees are instantiated
// with less_equal / greater_equal comparators -- confirm.
template<typename Comparator>
void ost_erase( ost<Comparator> &t, double & v )
{
    const int position = t.order_of_key( v );
    t.erase( t.find_by_order( position ) );
}
#endif
// Container of doubles backed by Boost.MultiIndex with a single
// ranked_non_unique index keyed on the value itself (identity<double>) and
// ordered by `Compare`. Quantile and Rank use index 0 for positional and rank
// queries via nth() and find_rank().
template <typename Compare>
using ost = boost::multi_index::multi_index_container<
    double,
    boost::multi_index::indexed_by<
        boost::multi_index::ranked_non_unique<
            boost::multi_index::identity<double>,
            Compare
        >
    >
>;

class Quantile
{
Expand Down Expand Up @@ -1142,11 +1129,7 @@ class Quantile

// Removes one stored occurrence of x from m_tree.
// On the __linux__ branch this goes through ost_erase(); otherwise the element
// located by find( x ) is erased.
// NOTE(review): the result of find( x ) is not checked against end(); presumably
// callers only remove values that were previously added -- confirm.
void remove( double x )
{
#ifdef __linux__
ost_erase( m_tree, x );
#else
m_tree.erase( m_tree.find( x ) );
#endif
}

void reset()
Expand All @@ -1165,113 +1148,60 @@ class Quantile
double target = std::get<double>( m_quants[index]._data ) * ( m_tree.size() - 1 );
int ft = floor( target );
int ct = ceil( target );
auto fIt = m_tree.get<0>().nth( ft );
auto cIt = ( ft == ct ) ? fIt : std::next( fIt );

double qtl = 0.0;
#ifdef __linux__
switch ( m_interpolation )
{
case LINEAR:
if( ft == target )
{
qtl = *m_tree.find_by_order( ft );
}
else
{
double lower = *m_tree.find_by_order( ft );
double higher = *m_tree.find_by_order( ct );
qtl = ( 1 - target + ft ) * lower + ( 1 - ct + target ) * higher;
}
break;
case LOWER:
qtl = *m_tree.find_by_order( ft );
break;
case HIGHER:
qtl = *m_tree.find_by_order( ct );
break;
case MIDPOINT:
if( ft == target )
{
qtl = *m_tree.find_by_order( ft );
}
else
{
double lower = *m_tree.find_by_order( ft );
double higher = *m_tree.find_by_order( ct );
qtl = ( higher+lower ) / 2;
}
break;
case NEAREST:
if( target - ft < ct - target )
{
qtl = *m_tree.find_by_order( ft );
}
else
{
qtl = *m_tree.find_by_order( ct );
}
break;
default:
break;
}
#else
auto it = m_tree.begin();
std::advance( it, ft );
switch ( m_interpolation )
{
case LINEAR:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( 1 - target + ft ) * lower + ( 1 - ct + target ) * higher;
}
break;
case LOWER:
qtl = *it;
break;
case HIGHER:
qtl = ( ft == ct ? *it : *++it );
break;
case MIDPOINT:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( higher+lower ) / 2;
}
break;
case NEAREST:
if( target - ft <= ct - target )
{
qtl = *it;
}
else
{
qtl = *++it;
}
break;
default:
break;
case LINEAR:
if ( ft == target )
{
qtl = *fIt;
}
else
{
double lower = *fIt;
double higher = *cIt;
qtl = ( 1 - target + ft ) * lower + ( 1 - ct + target ) * higher;
}
break;
case LOWER:
qtl = *fIt;
break;
case HIGHER:
qtl = *cIt;
break;
case MIDPOINT:
if ( ft == target )
{
qtl = *fIt;
}
else
{
double lower = *fIt;
double higher = *cIt;
qtl = ( higher + lower ) / 2;
}
break;
case NEAREST:
if ( target - ft < ct - target )
{
qtl = *fIt;
}
else
{
qtl = *cIt;
}
break;
default:
break;
}
#endif
return qtl;
}

private:

#ifdef __linux__
ost<std::less_equal<double>> m_tree;
#else
std::multiset<double> m_tree;
#endif
ost<std::less<double>> m_tree;
std::vector<Dictionary::Data> m_quants;
int64_t m_interpolation;
};
Expand Down Expand Up @@ -1359,119 +1289,74 @@ class Rank
else
{
m_lastval = x;
#ifdef __linux__
if( m_method == MAX )
m_maxtree.insert( x );
else
m_mintree.insert( x );
#else
m_tree.insert( x );
#endif
}
}

// Removes x from the tree that matches the ranking method: m_maxtree when
// m_method == MAX, m_mintree otherwise. NaN inputs are ignored.
// NOTE(review): this listing comes from a diff view and appears to interleave
// the removed ost_erase() calls with the added find()/erase() calls -- confirm
// against the actual file before relying on the exact statement order.
void remove( double x )
{
if( likely( !isnan( x ) ) )
{
#ifdef __linux__
if( m_method == MAX )
ost_erase( m_maxtree, x );
if ( m_method == MAX )
m_maxtree.erase ( m_maxtree.find( x ) );
else
ost_erase( m_mintree, x );
#else
m_tree.erase( m_tree.find( x ) );
#endif
m_mintree.erase ( m_mintree.find( x ) );
}
}

// Empties the stored values. On the __linux__ branch only the tree used by the
// current method is cleared (m_maxtree for MAX, m_mintree otherwise); the other
// branch clears m_tree.
void reset()
{
#ifdef __linux__
if( m_method == MAX )
m_maxtree.clear();
else
m_mintree.clear();
#else
m_tree.clear();
#endif
}

// Returns the rank of m_lastval among the stored values according to m_method:
//   MIN - rank of m_lastval in the ascending tree
//   MAX - size - 1 - rank of m_lastval in the descending tree
//   AVG - mean of the lowest and highest rank held by values equal to m_lastval
// Returns quiet NaN when m_lastval is NaN, when the relevant tree is empty, or
// when m_method matches none of the cases.
// NOTE(review): this listing comes from a diff view and appears to interleave
// removed lines (order_of_key / find_by_order, the std::multiset branch) with
// added lines (find_rank / nth) -- confirm against the actual file.
double compute() const
{
// Verify tree is not empty and lastValue is valid
// Last value can only ever be NaN if the "keep" nan option is used
#ifdef __linux__
if( likely( !isnan( m_lastval ) && ( ( m_method == MAX && m_maxtree.size() > 0 ) || m_mintree.size() > 0 ) ) )
{
switch( m_method )
{
case MIN:
{
if( m_mintree.size() == 1 )
if ( m_mintree.size() == 1 )
return 0;
return m_mintree.order_of_key( m_lastval );
return m_mintree.get<0>().find_rank( m_lastval );
}
case MAX:
{
if( m_maxtree.size() == 1 )
if ( m_maxtree.size() == 1 )
return 0;
return m_maxtree.size() - 1 - m_maxtree.order_of_key( m_lastval );
return m_maxtree.size() - 1 - m_maxtree.get<0>().find_rank( m_lastval );
}
case AVG:
{
// Need to iterate to find average rank
if( m_mintree.size() == 1 )
if ( m_mintree.size() == 1 )
return 0;

int min_rank = m_mintree.order_of_key( m_lastval );
int min_rank = m_mintree.get<0>().find_rank( m_lastval );
int max_rank = min_rank;
auto it = m_mintree.find_by_order( min_rank );
auto it = m_mintree.get<0>().nth( min_rank );
it++;
for( ; it != m_mintree.end() && *it == m_lastval ; it++ ) max_rank++;
for( ; it != m_mintree.end() && *it == m_lastval ; it++ ) max_rank++; // While this is O(n) in theory, in practice the loop body runs at most a few times, since duplicate values are expected to be rare.
return ( double )( min_rank + max_rank ) / 2;
}

default:
break;
}
}
#else
if( likely( !isnan( m_lastval ) && m_tree.size() > 0 ) )
{
switch( m_method )
{
case MIN:
{
return std::distance( m_tree.begin(), m_tree.find( m_lastval ) );
}
case MAX:
{
auto end_range = m_tree.equal_range( m_lastval ).second;
return std::distance( m_tree.begin(), std::prev( end_range ) );
}
case AVG:
{
auto range = m_tree.equal_range( m_lastval );
return std::distance( m_tree.begin(), range.first ) + ( double )std::distance( range.first, std::prev( range.second ) ) / 2;
}
default:
break;
}
}
#endif

return std::numeric_limits<double>::quiet_NaN();
}

private:

#ifdef __linux__
ost<std::less_equal<double>> m_mintree;
ost<std::greater_equal<double>> m_maxtree;
#else
std::multiset<double> m_tree;
#endif
ost<std::less<double>> m_mintree;
ost<std::greater<double>> m_maxtree;
double m_lastval;

int64_t m_method;
Expand Down
1 change: 1 addition & 0 deletions vcpkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"abseil",
"arrow",
"boost-beast",
"boost-multi-index",
"brotli",
"exprtk",
"gtest",
Expand Down

0 comments on commit 5c81458

Please sign in to comment.