Skip to content

Commit

Permalink
quantile
Browse files Browse the repository at this point in the history
Signed-off-by: Mohit Chhaya <mohitchhaya24@gmail.com>
  • Loading branch information
mrchhaya committed Jul 11, 2024
1 parent a198332 commit 3979a96
Showing 1 changed file with 51 additions and 46 deletions.
97 changes: 51 additions & 46 deletions cpp/csp/cppnodes/statsimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
#ifdef __linux__
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#else
#include <boost/multi_index_container.hpp>
#include <boost/multi_index/ordered_index.hpp>
#include <boost/multi_index/ranked_index.hpp>
#endif

namespace csp::cppnodes
Expand Down Expand Up @@ -1099,6 +1103,9 @@ void ost_erase( ost<Comparator> &t, double & v )
auto it = t.find_by_order( rank );
t.erase( it );
}
#else
template <typename Comparator>
using ost = boost::multi_index::multi_index_container<double, boost::multi_index::indexed_by<boost::multi_index::ranked_non_unique<boost::multi_index::identity<double>, Comparator>>>;
#endif

class Quantile
Expand Down Expand Up @@ -1214,52 +1221,50 @@ class Quantile
break;
}
#else
auto it = m_tree.begin();
std::advance( it, ft );
switch ( m_interpolation )
switch (m_interpolation)
{
case LINEAR:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( 1 - target + ft ) * lower + ( 1 - ct + target ) * higher;
}
break;
case LOWER:
qtl = *it;
break;
case HIGHER:
qtl = ( ft == ct ? *it : *++it );
break;
case MIDPOINT:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( higher+lower ) / 2;
}
break;
case NEAREST:
if( target - ft <= ct - target )
{
qtl = *it;
}
else
{
qtl = *++it;
}
break;
default:
break;
case LINEAR:
if (ft == target)
{
qtl = *m_tree.get<0>().nth(ft);
}
else
{
double lower = *m_tree.get<0>().nth(ft);
double higher = *m_tree.get<0>().nth(ct);
qtl = (1 - target + ft) * lower + (1 - ct + target) * higher;
}
break;
case LOWER:
qtl = *m_tree.get<0>().nth(ft);
break;
case HIGHER:
qtl = *m_tree.get<0>().nth(ct);
break;
case MIDPOINT:
if (ft == target)
{
qtl = *m_tree.get<0>().nth(ft);
}
else
{
double lower = *m_tree.get<0>().nth(ft);
double higher = *m_tree.get<0>().nth(ct);
qtl = (higher + lower) / 2;
}
break;
case NEAREST:
if (target - ft < ct - target)
{
qtl = *m_tree.get<0>().nth(ft);
}
else
{
qtl = *m_tree.get<0>().nth(ct);
}
break;
default:
break;
}
#endif
return qtl;
Expand All @@ -1270,7 +1275,7 @@ class Quantile
#ifdef __linux__
ost<std::less_equal<double>> m_tree;
#else
std::multiset<double> m_tree;
ost<std::less<double>> m_tree;
#endif
std::vector<Dictionary::Data> m_quants;
int64_t m_interpolation;
Expand Down

0 comments on commit 3979a96

Please sign in to comment.