Skip to content

Commit

Permalink
Replace Gram-Schmidt method by QR for random Chain initialization (#175)
Browse files Browse the repository at this point in the history
* Replace Gram-Schmidt method by QR for random Chain initialization

* Fix typo

* Format code

* Fix function for odd numbered chain

* Change for-loop back into map

* Apply suggestions from code review

---------

Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
  • Loading branch information
jofrevalles and mofeing authored Aug 1, 2024
1 parent b516d66 commit 2260133
Showing 1 changed file with 13 additions and 30 deletions.
43 changes: 13 additions & 30 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using LinearAlgebra
using Random
using Muscle: gramschmidt!

struct Chain <: Ansatz
super::Quantum
Expand Down Expand Up @@ -223,6 +222,7 @@ function Base.rand(rng::AbstractRNG, ::Type{A}, ::Type{B}, ::Type{S}; kwargs...)
return rand(rng, ChainSampler{B,S}(; kwargs...), B, S)
end

# TODO let choose the orthogonality center
function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{State})
n = sampler.parameters.n
χ = sampler.parameters.χ
Expand All @@ -238,27 +238,19 @@ function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open},
(isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr))
end

# fix for first site
i == 1 && ((χl, χr) = (χr, 1))

# orthogonalize by Gram-Schmidt algorithm
A = gramschmidt!(rand(rng, T, χl, χr * p))
# orthogonalize by QR factorization
F = lq!(rand(rng, T, χl, p * χr))

A = reshape(A, χl, χr, p)
permutedims(A, (3, 1, 2))
reshape(Matrix(F.Q), χl, p, χr)
end

# reshape boundary sites
arrays[1] = reshape(arrays[1], p, p)
arrays[n] = reshape(arrays[n], p, p)

# normalize state
arrays[1] ./= sqrt(p)

return Chain(State(), Open(), arrays)
return Chain(State(), Open(), arrays; order=(:l, :o, :r))
end

# TODO let choose the orthogonality center
# TODO different input/output physical dims
function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{Operator})
n = sampler.parameters.n
Expand All @@ -277,26 +269,17 @@ function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open},
(isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr))
end

shape = if i == 1
(χr, ip, op)
elseif i == n
(χl, ip, op)
else
(χl, χr, ip, op)
end

# orthogonalize by Gram-Schmidt algorithm
A = gramschmidt!(rand(rng, T, shape[1], prod(shape[2:end])))
A = reshape(A, shape)

(i == 1 || i == n) ? permutedims(A, (2, 3, 1)) : permutedims(A, (3, 4, 1, 2))
# orthogonalize by QR factorization
F = lq!(rand(rng, T, χl, ip * op * χr))
reshape(Matrix(F.Q), χl, ip, op, χr)
end

# normalize
ζ = min(χ, ip * op)
arrays[1] ./= sqrt)
# reshape boundary sites
arrays[1] = reshape(arrays[1], p, p, min(χ, ip * op))
arrays[n] = reshape(arrays[n], min(χ, ip * op), p, p)

return Chain(Operator(), Open(), arrays)
# TODO order might not be the best for performance
return Chain(Operator(), Open(), arrays; order=(:l, :i, :o, :r))
end

Tenet.contract(tn::Chain, query::Symbol, args...; kwargs...) = contract!(copy(tn), Val(query), args...; kwargs...)
Expand Down

0 comments on commit 2260133

Please sign in to comment.