Skip to content

Commit

Permalink
Factor out complex type
Browse files Browse the repository at this point in the history
  • Loading branch information
adamant-pwn committed Nov 20, 2024
1 parent ffd4bdd commit 094f2a7
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 14 deletions.
4 changes: 2 additions & 2 deletions cp-algo/geometry/closest_pair.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ namespace cp_algo::geometry {
}
}
}
std::map<point, std::vector<int>> neigs;
md = ceil(sqrtl(md));
std::map<point, std::vector<size_t>> neigs;
md = (int64_t)ceil(sqrt((double)md));
for(size_t i = 0; i < n; i++) {
neigs[r[i] / md].push_back(i);
}
Expand Down
8 changes: 4 additions & 4 deletions cp-algo/geometry/point.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#ifndef CP_ALGO_GEOMETRY_POINT_HPP
#define CP_ALGO_GEOMETRY_POINT_HPP
#include "../random/rng.hpp"
#include "cp-algo/util/complex.hpp"
#include "cp-algo/random/rng.hpp"
#include <iostream>
#include <complex>
namespace cp_algo::geometry {
template<typename ftype>
struct point_t: public std::complex<ftype> {
using Base = std::complex<ftype>;
struct point_t: complex<ftype> {
using Base = complex<ftype>;
using Base::Base;

point_t(Base const& t): Base(t) {}
Expand Down
14 changes: 7 additions & 7 deletions cp-algo/math/fft.hpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
#ifndef CP_ALGO_MATH_FFT_HPP
#define CP_ALGO_MATH_FFT_HPP
#include "common.hpp"
#include "../number_theory/modint.hpp"
#include "cp-algo/number_theory/modint.hpp"
#include "cp-algo/util/complex.hpp"
#include <algorithm>
#include <complex>
#include <cassert>
#include <ranges>
#include <vector>
#include <bit>
#include <experimental/simd>

namespace cp_algo::math::fft {
using ftype = double;
using point = std::complex<ftype>;
using point = complex<ftype>;
using vftype = std::experimental::native_simd<ftype>;
using vpoint = std::complex<vftype>;
using vpoint = complex<vftype>;
static constexpr size_t flen = vftype::size();


Expand Down Expand Up @@ -74,7 +74,7 @@ namespace cp_algo::math::fft {
} else {
auto arg = std::numbers::pi / (ftype)n;
if constexpr(std::is_same_v<pt, point>) {
return {(ftype)cos(k * arg), (ftype)sin(k * arg)};
return {cos((ftype)k * arg), sin((ftype)k * arg)};
} else {
return pt{vftype{[&](auto i) {return cos(ftype(k + i) * arg);}},
vftype{[&](auto i) {return sin(ftype(k + i) * arg);}}};
Expand Down Expand Up @@ -140,11 +140,11 @@ namespace cp_algo::math::fft {
const cvector cvector::roots = []() {
cvector res(pre_roots);
for(size_t n = 1; n < res.size(); n *= 2) {
auto base = std::polar(1., std::numbers::pi / (ftype)n);
auto base = polar<ftype>(1., std::numbers::pi / (ftype)n);
point cur = 1;
for(size_t k = 0; k < n; k++) {
if((k & 15) == 0) {
cur = std::polar(1., std::numbers::pi * (ftype)k / (ftype)n);
cur = polar<ftype>(1., std::numbers::pi * (ftype)k / (ftype)n);
}
res.set(n + k, cur);
cur *= base;
Expand Down
2 changes: 1 addition & 1 deletion cp-algo/number_theory/modint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ namespace cp_algo::math {
static void switch_mod(Int nm) {
m = nm;
im = m % 2 ? inv2(-m) : 0;
r2 = (typename Base::UInt2)(-1) % m + 1;
r2 = static_cast<Base::UInt>(static_cast<Base::UInt2>(-1) % m + 1);
}

// Wrapper for temp switching
Expand Down
41 changes: 41 additions & 0 deletions cp-algo/util/complex.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef CP_ALGO_UTIL_COMPLEX_HPP
#define CP_ALGO_UTIL_COMPLEX_HPP
#include <cmath>
namespace cp_algo {
template<typename T>
struct complex {
T x, y;
constexpr complex() {}
constexpr complex(T x): x(x), y(0) {}
constexpr complex(T x, T y): x(x), y(y) {}
complex& operator *= (T t) {x *= t; y *= t; return *this;}
complex& operator /= (T t) {x /= t; y /= t; return *this;}
complex operator * (T t) const {return complex(*this) *= t;}
complex operator / (T t) const {return complex(*this) /= t;}
complex& operator += (complex t) {x += t.x; y += t.y; return *this;}
complex& operator -= (complex t) {x -= t.x; y -= t.y; return *this;}
complex operator * (complex t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};}
complex operator / (complex t) const {return *this * t.conj() / t.norm();}
complex operator + (complex t) const {return complex(*this) += t;}
complex operator - (complex t) const {return complex(*this) -= t;}
complex& operator *= (complex t) {return *this = *this * t;}
complex& operator /= (complex t) {return *this = *this / t;}
complex operator - () const {return {-x, -y};}
complex conj() const {return {x, -y};}
T norm() const {return x * x + y * y;}
T abs() const {return std::sqrt(norm());}
T real() const {return x;}
T imag() const {return y;}
static complex polar(T r, T theta) {return {r * std::cos(theta), r * std::sin(theta)};}
auto operator <=> (complex const& t) const = default;
};
template<typename T>
complex<T> operator * (auto x, complex<T> y) {return y * x;}
template<typename T> complex<T> conj(complex<T> x) {return x.conj();}
template<typename T> T norm(complex<T> x) {return x.norm();}
template<typename T> T abs(complex<T> x) {return x.abs();}
template<typename T> T real(complex<T> x) {return x.real();}
template<typename T> T imag(complex<T> x) {return x.imag();}
template<typename T> complex<T> polar(T r, T theta) {return complex<T>::polar(r, theta);}
}
#endif // CP_ALGO_UTIL_COMPLEX_HPP

0 comments on commit 094f2a7

Please sign in to comment.