Skip to content

Commit

Permalink
Refactor + do not accummulate error
Browse files Browse the repository at this point in the history
  • Loading branch information
adamant-pwn committed Nov 21, 2024
1 parent 0f73e43 commit 63235e1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 41 deletions.
65 changes: 25 additions & 40 deletions cp-algo/math/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,8 @@ namespace cp_algo::math::fft {
using vpoint = complex<vftype>;
static constexpr size_t flen = vftype::size();


template<typename ft>
constexpr ft to_ft(auto x) {
return ft{} + x;
}
template<typename pt>
constexpr pt to_pt(point r) {
using ft = std::conditional_t<std::is_same_v<point, pt>, ftype, vftype>;
return {to_ft<ft>(r.real()), to_ft<ft>(r.imag())};
}
struct cvector {
static constexpr size_t pre_roots = 1 << 17;
static constexpr size_t pre_roots = 1 << 19;
std::vector<vftype> x, y;
cvector(size_t n) {
n = std::max(flen, std::bit_ceil(n));
Expand Down Expand Up @@ -67,32 +57,28 @@ namespace cp_algo::math::fft {
}
}
static const cvector roots;
template<class pt = point>
static pt root(size_t n, size_t k) {
if(n < pre_roots) {
template<class pt = point, bool precalc = false>
static pt root(size_t n, size_t k, auto &&arg) {
if(n < pre_roots && !precalc) {
return roots.get<pt>(n + k);
} else {
auto arg = std::numbers::pi / (ftype)n;
if constexpr(std::is_same_v<pt, point>) {
return {cos((ftype)k * arg), sin((ftype)k * arg)};
} else {
return pt{vftype{[&](auto i) {return cos(ftype(k + i) * arg);}},
vftype{[&](auto i) {return sin(ftype(k + i) * arg);}}};
}
return polar<typename pt::value_type>(1., arg);
}
}
template<class pt = point>
template<class pt = point, bool precalc = false>
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
ftype arg = std::numbers::pi / (ftype)n;
size_t step = sizeof(pt) / sizeof(point);
pt cur;
pt arg = to_pt<pt>(root<point>(n, step));
for(size_t i = 0; i < m; i += step) {
if(i % 32 == 0 || n < pre_roots) {
cur = root<pt>(n, i);
using ft = pt::value_type;
auto k = [&]() {
if constexpr(std::is_same_v<pt, point>) {
return ft{};
} else {
cur *= arg;
return ft{[](auto i) {return i;}};
}
callback(i, cur);
}();
for(size_t i = 0; i < m; i += step, k += (ftype)step) {
callback(i, root<pt, precalc>(n, i, arg * k));
}
}

Expand All @@ -106,15 +92,15 @@ namespace cp_algo::math::fft {
set(k + i, get<pt>(k) - t);
set(k, get<pt>(k) + t);
};
if(2 * i <= flen) {
if(i < flen) {
exec_on_roots(i, i, butterfly);
} else {
exec_on_roots<vpoint>(i, i, butterfly);
}
}
}
for(size_t k = 0; k < n; k += flen) {
set(k, get<vpoint>(k) /= to_pt<vpoint>((ftype)n));
set(k, get<vpoint>(k) /= (ftype)n);
}
}
void fft() {
Expand All @@ -128,7 +114,7 @@ namespace cp_algo::math::fft {
set(k, A);
set(k + i, B * rt);
};
if(2 * i <= flen) {
if(i < flen) {
exec_on_roots(i, i, butterfly);
} else {
exec_on_roots<vpoint>(i, i, butterfly);
Expand All @@ -140,14 +126,13 @@ namespace cp_algo::math::fft {
const cvector cvector::roots = []() {
cvector res(pre_roots);
for(size_t n = 1; n < res.size(); n *= 2) {
auto base = polar<ftype>(1., std::numbers::pi / (ftype)n);
point cur = 1;
for(size_t k = 0; k < n; k++) {
if((k & 15) == 0) {
cur = polar<ftype>(1., std::numbers::pi * (ftype)k / (ftype)n);
}
res.set(n + k, cur);
cur *= base;
auto propagate = [&](size_t k, auto rt) {
res.set(n + k, rt);
};
if(n < flen) {
res.exec_on_roots<point, true>(n, n, propagate);
} else {
res.exec_on_roots<vpoint, true>(n, n, propagate);
}
}
return res;
Expand Down
3 changes: 2 additions & 1 deletion cp-algo/util/complex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
namespace cp_algo {
template<typename T>
struct complex {
using value_type = T;
T x, y;
constexpr complex() {}
constexpr complex(T x): x(x), y(0) {}
Expand All @@ -26,7 +27,7 @@ namespace cp_algo {
T abs() const {return std::sqrt(norm());}
T real() const {return x;}
T imag() const {return y;}
static complex polar(T r, T theta) {return {r * std::cos(theta), r * std::sin(theta)};}
static complex polar(T r, T theta) {return {r * cos(theta), r * sin(theta)};}
auto operator <=> (complex const& t) const = default;
};
template<typename T>
Expand Down

0 comments on commit 63235e1

Please sign in to comment.