Skip to content

Commit

Permalink
Use projectSU3 to implement the projection.
Browse files Browse the repository at this point in the history
Revert changes in unitarize_links_quda.cu and unitarize_links.cuh.
  • Loading branch information
SaltyChiang committed Oct 5, 2023
1 parent 7b3044d commit 6e90b49
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 52 deletions.
40 changes: 3 additions & 37 deletions include/kernels/unitarize_links.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,19 @@
#include <color_spinor.h>
#include <svd_quda.h>
#include <kernel.h>
#include <kernels/gauge_phase.cuh>

namespace quda {

template <typename Float_, int nColor_, QudaReconstructType recon_, QudaStaggeredPhase phase_>
template <typename Float, int nColor_, QudaReconstructType recon_>
struct UnitarizeArg : kernel_param<> {
using Float = double;
using real = typename mapper<Float_>::type;
using real = typename mapper<Float>::type;
static constexpr int nColor = nColor_;
static constexpr QudaReconstructType recon = recon_;
static constexpr QudaStaggeredPhase phase = phase_;
typedef typename gauge_mapper<Float_,recon>::type Gauge;
typedef typename gauge_mapper<Float,recon>::type Gauge;
Gauge out;
const Gauge in;

int X[4]; // grid dimensions
double tBoundary;
int *fails;
const int max_iter;
const double unitarize_eps;
Expand All @@ -50,9 +46,6 @@ namespace quda {
svd_abs_error(svd_abs_error)
{
for (int dir=0; dir<4; ++dir) X[dir] = in.X()[dir];

bool last_node_in_t = (commCoords(3) == commDim(3)-1);
tBoundary = (Float)(last_node_in_t ? in.TBoundary() : QUDA_PERIODIC_T);
}
};

Expand Down Expand Up @@ -189,16 +182,6 @@ namespace quda {
return true;
} // unitarizeMILC

template <typename real, typename mat, typename Arg>
__host__ __device__ void specialUnitarizeLinkMILC(mat &out, const mat &in, const Arg &arg)
{
complex<real> det = getDeterminant(in);
real r = exp(-log(abs(det)) / Arg::nColor);
real alpha = atan2(det.imag(), det.real()) / Arg::nColor;

out = in * polar(r, -alpha);
} // specialUnitarizeLinkMILC

template <typename mat>
__host__ __device__ bool unitarizeLinkNewton(mat &out, const mat& in, int max_iter)
{
Expand Down Expand Up @@ -235,23 +218,6 @@ namespace quda {
if (arg.check_unitarization) {
if (result.isUnitary(arg.max_error) == false) atomic_fetch_add(arg.fails, 1);
}

if constexpr (Arg::phase == QUDA_STAGGERED_PHASE_CHROMA) { // Special unitraize the result for Chroma convention
int x[4];
getCoords(x, x_cb, arg.X, parity);

double phase;
switch (mu) {
case 0: phase = getPhase<0>(x[0], x[1], x[2], x[3], arg); break;
case 1: phase = getPhase<1>(x[0], x[1], x[2], x[3], arg); break;
case 2: phase = getPhase<2>(x[0], x[1], x[2], x[3], arg); break;
case 3: phase = getPhase<3>(x[0], x[1], x[2], x[3], arg); break;
}
v = result * phase;
specialUnitarizeLinkMILC<double>(result, v, arg);
result *= phase;
}

tmp = result;

arg.out(mu, x_cb, parity) = tmp;
Expand Down
13 changes: 13 additions & 0 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3919,6 +3919,19 @@ void computeKSLinkQuda(void *fatlink, void *longlink, void *ulink, void *inlink,
errorQuda("Error in unitarization component of the hisq fattening: %d failures", *num_failures_h);
profileFatLink.TPSTOP(QUDA_PROFILE_COMPUTE);

// project onto SU(3) if using the Chroma convention
if (param->staggered_phase_type == QUDA_STAGGERED_PHASE_CHROMA) {
profileFatLink.TPSTART(QUDA_PROFILE_COMPUTE);
*num_failures_h = 0;
const double tol = cudaUnitarizedLink->Precision() == QUDA_DOUBLE_PRECISION ? 1e-15 : 2e-6;
if (cudaUnitarizedLink->StaggeredPhaseApplied()) cudaUnitarizedLink->removeStaggeredPhase();
projectSU3(*cudaUnitarizedLink, tol, num_failures_d);
if (!cudaUnitarizedLink->StaggeredPhaseApplied() && param->staggered_phase_applied) cudaUnitarizedLink->applyStaggeredPhase();
if(*num_failures_h>0)
errorQuda("Error in the SU(3) unitarization: %d failures\n", *num_failures_h);
profileFatLink.TPSTOP(QUDA_PROFILE_COMPUTE);
}

cudaUnitarizedLink->saveCPUField(cpuUnitarizedLink, profileFatLink);

profileFatLink.TPSTART(QUDA_PROFILE_FREE);
Expand Down
17 changes: 2 additions & 15 deletions lib/unitarize_links_quda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,8 @@ namespace quda {
void apply(const qudaStream_t &stream)
{
TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_MILC) {
UnitarizeArg<Float, nColor, recon, QUDA_STAGGERED_PHASE_MILC> arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error);
launch<Unitarize>(tp, stream, arg);
} else if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_CHROMA) {
UnitarizeArg<Float, nColor, recon, QUDA_STAGGERED_PHASE_CHROMA> arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error);
launch<Unitarize>(tp, stream, arg);
} else if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_TIFR) {
UnitarizeArg<Float, nColor, recon, QUDA_STAGGERED_PHASE_TIFR> arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error);
launch<Unitarize>(tp, stream, arg);
} else if (in.StaggeredPhase() == QUDA_STAGGERED_PHASE_NO) {
UnitarizeArg<Float, nColor, recon, QUDA_STAGGERED_PHASE_NO> arg(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error);
launch<Unitarize>(tp, stream, arg);
} else {
errorQuda("Undefined phase type %d", in.StaggeredPhase());
}
launch<Unitarize>(tp, stream,
UnitarizeArg<Float, nColor, recon>(out, in, fails, max_iter, unitarize_eps, max_error, reunit_allow_svd, reunit_svd_only, svd_rel_error, svd_abs_error));
}

void preTune() { if (in.Gauge_p() == out.Gauge_p()) out.backup(); }
Expand Down

0 comments on commit 6e90b49

Please sign in to comment.