From dfcc424cd76246b30a9ced1a8c3e22f4eae461f0 Mon Sep 17 00:00:00 2001 From: Huanchen Zhai Date: Wed, 11 Dec 2024 22:24:48 -0500 Subject: [PATCH] core: csf coeffs with reference --- pyblock2/driver/core.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/pyblock2/driver/core.py b/pyblock2/driver/core.py index 2207eb89..39ff6776 100644 --- a/pyblock2/driver/core.py +++ b/pyblock2/driver/core.py @@ -6048,7 +6048,8 @@ def get_trans_6pdm(self, bra, ket, *args, **kwargs): return self.get_npdm(ket, pdm_type=6, bra=bra, *args, **kwargs) def get_csf_coefficients( - self, ket, cutoff=0.1, given_dets=None, max_print=200, fci_conv=False, iprint=1 + self, ket, cutoff=0.1, given_dets=None, max_print=200, fci_conv=False, + max_excite=None, ref_det=None, iprint=1 ): """ Find the dominant Configuration State Functions (CSFs, in SU2 mode) @@ -6064,7 +6065,7 @@ def get_csf_coefficients( that should be searched. Default is 0.1. If ``cutoff == 0.0``, will compute coefficients for all CSF/DET, which may take an exponential amount of time. - given_dets : None or list[str] + given_dets : None or list[str] or list[list[int]] If not None, will compute the coefficients for the given CSF/DET set. If ``cutoff != 0.0 and (given_dets == [] or given_dets is None)``, will consider all possible CSF/DET with the absolute value of the coefficients @@ -6084,6 +6085,11 @@ def get_csf_coefficients( Default is False. iprint : int Verbosity. Default is 1. + max_excite : None or int + If not None, will only search CSF/DET that are at most max_excite with respect to + the reference given by ref_det. Default is None. Not working for SAny/SU2. + ref_det : None or str or list[int] + The reference CSF/DET. Default is None. Returns: dets : np.ndarray[np.uint8] @@ -6100,6 +6106,7 @@ def get_csf_coefficients( bw = self.bw iprint = iprint >= 1 and (self.mpi is None or self.mpi.rank == self.mpi.root) import numpy as np, time + assert (max_excite is None) == (ref_det is None) if ket.center != 0: ket = self.copy_mps(ket, tag="CSF-TMP") @@ -6133,7 +6140,11 @@ def get_csf_coefficients( else: uniq.add(tuple(ddet)) dtrie.append(bw.b.VectorUInt8(ddet)) - dtrie.evaluate(bw.bs.UnfusedMPS(ket), cutoff) + if max_excite is not None: + refx = [ddstr.index(x) for x in ref_det] if isinstance(ref_det, str) else ref_det + dtrie.evaluate(bw.bs.UnfusedMPS(ket), cutoff, max_excite, bw.b.VectorUInt8(refx)) + else: + dtrie.evaluate(bw.bs.UnfusedMPS(ket), cutoff) if fci_conv: dtrie.convert_phase(bw.b.VectorInt(list(range(ket.n_sites)))) if iprint: