Skip to content

Commit

Permalink
fix read/write wfn
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Nov 24, 2024
1 parent 080ba43 commit 062ab25
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 21 deletions.
22 changes: 9 additions & 13 deletions pyblock2/driver/readwfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
if len(sys.argv) > 1:
arg_dic = {}
for i in range(1, len(sys.argv)):
if sys.argv[i] in ['-su2', '-sz', '-expect', '-reduntant']:
if sys.argv[i] in ['-su2', '-sz', '-expect', '-redundant']:
arg_dic[sys.argv[i][1:]] = ''
elif sys.argv[i].startswith('-'):
arg_dic[sys.argv[i][1:]] = sys.argv[i + 1]
Expand All @@ -36,7 +36,7 @@
(A) python readwfn.py -config dmrg.conf -out ./out
(B) python readwfn.py dmrg.conf
(C) python readwfn.py dmrg.conf -expect
(D) python readwfn.py dmrg.conf -reduntant
(D) python readwfn.py dmrg.conf -redundant
(E) python readwfn.py -integral FCIDUMP -prefix ./scratch -dot 2 -su2
(F) python readwfn.py -integral FCIDUMP -prefix ./scratch -dot 2 -sz
(G) python readwfn.py ... -sym c1
Expand All @@ -46,9 +46,9 @@
out: dir for storing block2 MPS
expect: if given, the energy expectation value of MPS
is calculated using block2 and printed at the end
reduntant: if given, the reduntant parameter in the
redundant: if given, the redundant parameter in the
StackBlock MPS will be retained.
Note that removing reduntant parameters do not
Note that removing redundant parameters do not
affect quality of MPS.
when no config file is given/available:
Expand Down Expand Up @@ -79,7 +79,7 @@
hf_occ = None
out_dir = "./out"
expect = "expect" in arg_dic
redunt = "reduntant" in arg_dic
redunt = "redundant" in arg_dic
mps_tags = ["KET"]
if "config" in arg_dic:
config = arg_dic["config"]
Expand Down Expand Up @@ -285,12 +285,10 @@ def swap_order_left(idx):
l, m, r = mps_info.left_dims[idx], hamil.basis[idx], mps_info.left_dims[idx + 1]
clm = StateInfo.get_connection_info(l, m, r)
for ik in range(r.n):
bbed = clm.n if ik == r.n - 1 else clm.n_states[ik + 1]
dx = []
g = 0
for bb in range(clm.n_states[ik], bbed):
ibba = clm.quanta[bb].data >> 16
ibbb = clm.quanta[bb].data & 0xFFFF
for bb in range(clm.acc_n_states[ik], clm.acc_n_states[ik + 1]):
ibba, ibbb = clm.ij_indices[bb]
nx = l.n_states[ibba] * m.n_states[ibbb]
dx.append((l.quanta[ibba], m.quanta[ibbb], g, g + nx))
g += nx
Expand All @@ -310,12 +308,10 @@ def swap_order_right(idx):
l, m, r = mps_info.right_dims[idx], hamil.basis[idx], mps_info.right_dims[idx + 1]
clm = StateInfo.get_connection_info(m, r, l)
for ik in range(l.n):
bbed = clm.n if ik == l.n - 1 else clm.n_states[ik + 1]
dx = []
g = 0
for bb in range(clm.n_states[ik], bbed):
ibba = clm.quanta[bb].data >> 16
ibbb = clm.quanta[bb].data & 0xFFFF
for bb in range(clm.acc_n_states[ik], clm.acc_n_states[ik + 1]):
ibba, ibbb = clm.ij_indices[bb]
nx = m.n_states[ibba] * r.n_states[ibbb]
dx.append((m.quanta[ibba], r.quanta[ibbb], g, g + nx))
g += nx
Expand Down
12 changes: 4 additions & 8 deletions pyblock2/driver/writewfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,10 @@ def swap_order_left(idx):
l, m, r = mps.info.left_dims[idx], hamil.basis[idx], mps.info.left_dims[idx + 1]
clm = B2StateInfo.get_connection_info(l, m, r)
for ik in range(r.n):
bbed = clm.n if ik == r.n - 1 else clm.n_states[ik + 1]
dx = []
g = 0
for bb in range(clm.n_states[ik], bbed):
ibba = clm.quanta[bb].data >> 16
ibbb = clm.quanta[bb].data & 0xFFFF
for bb in range(clm.acc_n_states[ik], clm.acc_n_states[ik + 1]):
ibba, ibbb = clm.ij_indices[bb]
nx = l.n_states[ibba] * m.n_states[ibbb]
dx.append((l.quanta[ibba], m.quanta[ibbb], g, g + nx))
g += nx
Expand All @@ -289,12 +287,10 @@ def swap_order_right(idx):
l, m, r = mps.info.right_dims[idx], hamil.basis[idx], mps.info.right_dims[idx + 1]
clm = B2StateInfo.get_connection_info(m, r, l)
for ik in range(l.n):
bbed = clm.n if ik == l.n - 1 else clm.n_states[ik + 1]
dx = []
g = 0
for bb in range(clm.n_states[ik], bbed):
ibba = clm.quanta[bb].data >> 16
ibbb = clm.quanta[bb].data & 0xFFFF
for bb in range(clm.acc_n_states[ik], clm.acc_n_states[ik + 1]):
ibba, ibbb = clm.ij_indices[bb]
nx = m.n_states[ibba] * r.n_states[ibbb]
dx.append((m.quanta[ibba], r.quanta[ibbb], g, g + nx))
g += nx
Expand Down

0 comments on commit 062ab25

Please sign in to comment.