diff --git a/include/kernels/unitarize_links.cuh b/include/kernels/unitarize_links.cuh index 62d526a553..f0faef8570 100644 --- a/include/kernels/unitarize_links.cuh +++ b/include/kernels/unitarize_links.cuh @@ -7,23 +7,19 @@ #include #include #include -#include namespace quda { - template + template struct UnitarizeArg : kernel_param<> { - using Float = double; - using real = typename mapper::type; + using real = typename mapper::type; static constexpr int nColor = nColor_; static constexpr QudaReconstructType recon = recon_; - static constexpr QudaStaggeredPhase phase = phase_; - typedef typename gauge_mapper::type Gauge; + typedef typename gauge_mapper::type Gauge; Gauge out; const Gauge in; int X[4]; // grid dimensions - double tBoundary; int *fails; const int max_iter; const double unitarize_eps; @@ -50,9 +46,6 @@ namespace quda { svd_abs_error(svd_abs_error) { for (int dir=0; dir<4; ++dir) X[dir] = in.X()[dir]; - - bool last_node_in_t = (commCoords(3) == commDim(3)-1); - tBoundary = (Float)(last_node_in_t ? in.TBoundary() : QUDA_PERIODIC_T); } }; @@ -189,16 +182,6 @@ namespace quda { return true; } // unitarizeMILC - template - __host__ __device__ void specialUnitarizeLinkMILC(mat &out, const mat &in, const Arg &arg) - { - complex det = getDeterminant(in); - real r = exp(-log(abs(det)) / Arg::nColor); - real alpha = atan2(det.imag(), det.real()) / Arg::nColor; - - out = in * polar(r, -alpha); - } // specialUnitarizeLinkMILC - template __host__ __device__ bool unitarizeLinkNewton(mat &out, const mat& in, int max_iter) { @@ -235,23 +218,6 @@ namespace quda { if (arg.check_unitarization) { if (result.isUnitary(arg.max_error) == false) atomic_fetch_add(arg.fails, 1); } - - if constexpr (Arg::phase == QUDA_STAGGERED_PHASE_CHROMA) { // Special unitraize the result for Chroma convention - int x[4]; - getCoords(x, x_cb, arg.X, parity); - - double phase; - switch (mu) { - case 0: phase = getPhase<0>(x[0], x[1], x[2], x[3], arg); break; - case 1: phase = getPhase<1>(x[0], x[1], x[2], x[3], arg); break; - case 2: phase = getPhase<2>(x[0], x[1], x[2], x[3], arg); break; - case 3: phase = getPhase<3>(x[0], x[1], x[2], x[3], arg); break; - } - v = result * phase; - specialUnitarizeLinkMILC(result, v, arg); - result *= phase; - } - tmp = result; arg.out(mu, x_cb, parity) = tmp; diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index e019312324..1da16e4895 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -3919,6 +3919,19 @@ void computeKSLinkQuda(void *fatlink, void *longlink, void *ulink, void *inlink, errorQuda("Error in unitarization component of the hisq fattening: %d failures", *num_failures_h); profileFatLink.TPSTOP(QUDA_PROFILE_COMPUTE); + // project onto SU(3) if using the Chroma convention + if (param->staggered_phase_type == QUDA_STAGGERED_PHASE_CHROMA) { + profileFatLink.TPSTART(QUDA_PROFILE_COMPUTE); + *num_failures_h = 0; + const double tol = cudaUnitarizedLink->Precision() == QUDA_DOUBLE_PRECISION ? 1e-15 : 2e-6; + if (cudaUnitarizedLink->StaggeredPhaseApplied()) cudaUnitarizedLink->removeStaggeredPhase(); + projectSU3(*cudaUnitarizedLink, tol, num_failures_d); + if (!cudaUnitarizedLink->StaggeredPhaseApplied() && param->staggered_phase_applied) cudaUnitarizedLink->applyStaggeredPhase(); + if(*num_failures_h>0) + errorQuda("Error in the SU(3) unitarization: %d failures\n", *num_failures_h); + profileFatLink.TPSTOP(QUDA_PROFILE_COMPUTE); + } + cudaUnitarizedLink->saveCPUField(cpuUnitarizedLink, profileFatLink); profileFatLink.TPSTART(QUDA_PROFILE_FREE); diff --git a/lib/unitarize_links_quda.cu b/lib/unitarize_links_quda.cu index 6c4d325c4d..fa006f0b4a 100644 --- a/lib/unitarize_links_quda.cu +++ b/lib/unitarize_links_quda.cu @@ -121,21 +121,8 @@ namespace quda { void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_MILC) { - UnitarizeArg arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error); - launch(tp, stream, arg); - } else if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_CHROMA) { - UnitarizeArg arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error); - launch(tp, stream, arg); - } else if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_TIFR) { - UnitarizeArg arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error); - launch(tp, stream, arg); - } else if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_NO) { - UnitarizeArg arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error); - launch(tp, stream, arg); - } else { - errorQuda("Undefined phase type %d", in.StaggeredPhase()); - } + launch(tp, stream, + UnitarizeArg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error)); } void preTune() { if (in.Gauge_p() == out.Gauge_p()) out.backup(); }