Skip to content

Commit

Permalink
core: csf coeffs with reference
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Dec 12, 2024
1 parent 5cc88cd commit dfcc424
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6048,7 +6048,8 @@ def get_trans_6pdm(self, bra, ket, *args, **kwargs):
return self.get_npdm(ket, pdm_type=6, bra=bra, *args, **kwargs)

def get_csf_coefficients(
self, ket, cutoff=0.1, given_dets=None, max_print=200, fci_conv=False, iprint=1
self, ket, cutoff=0.1, given_dets=None, max_print=200, fci_conv=False,
max_excite=None, ref_det=None, iprint=1
):
"""
Find the dominant Configuration State Functions (CSFs, in SU2 mode)
Expand All @@ -6064,7 +6065,7 @@ def get_csf_coefficients(
that should be searched. Default is 0.1.
If ``cutoff == 0.0``, will compute coefficients for all CSF/DET,
which may take an exponential amount of time.
given_dets : None or list[str]
given_dets : None or list[str] or list[list[int]]
If not None, will compute the coefficients for the given CSF/DET set.
If ``cutoff != 0.0 and (given_dets == [] or given_dets is None)``,
will consider all possible CSF/DET with the absolute value of the coefficients
Expand All @@ -6084,6 +6085,11 @@ def get_csf_coefficients(
Default is False.
iprint : int
Verbosity. Default is 1.
max_excite : None or int
If not None, will only search CSF/DET that are at most max_excite with respect to
the reference given by ref_det. Default is None. Not working for SAny/SU2.
ref_det : None or str or list[int]
The reference CSF/DET. Default is None.
Returns:
dets : np.ndarray[np.uint8]
Expand All @@ -6100,6 +6106,7 @@ def get_csf_coefficients(
bw = self.bw
iprint = iprint >= 1 and (self.mpi is None or self.mpi.rank == self.mpi.root)
import numpy as np, time
assert (max_excite is None) == (ref_det is None)

if ket.center != 0:
ket = self.copy_mps(ket, tag="CSF-TMP")
Expand Down Expand Up @@ -6133,7 +6140,11 @@ def get_csf_coefficients(
else:
uniq.add(tuple(ddet))
dtrie.append(bw.b.VectorUInt8(ddet))
dtrie.evaluate(bw.bs.UnfusedMPS(ket), cutoff)
if max_excite is not None:
refx = [ddstr.index(x) for x in ref_det] if isinstance(ref_det, str) else ref_det
dtrie.evaluate(bw.bs.UnfusedMPS(ket), cutoff, max_excite, bw.b.VectorUInt8(refx))
else:
dtrie.evaluate(bw.bs.UnfusedMPS(ket), cutoff)
if fci_conv:
dtrie.convert_phase(bw.b.VectorInt(list(range(ket.n_sites))))
if iprint:
Expand Down

0 comments on commit dfcc424

Please sign in to comment.