Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
madcpf committed Aug 1, 2024
1 parent 8893ced commit bd4278b
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 18 deletions.
95 changes: 78 additions & 17 deletions unitary/alpha/quantum_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import numpy as np
from itertools import combinations
import pandas as pd
# from tabulate import tabulate
# from texttable import Texttable
# from prettytable import PrettyTable
# from terminaltables import AsciiTable


class QuantumWorld:
Expand Down Expand Up @@ -121,9 +117,9 @@ def copy(self) -> "QuantumWorld":
for remap in self.qubit_remapping_dict:
new_dict = {}
for key_obj, value_obj in remap.items():
new_dict[new_world.get_object_by_name(key_obj.name)] = (
new_world.get_object_by_name(value_obj.name)
)
new_dict[
new_world.get_object_by_name(key_obj.name)
] = new_world.get_object_by_name(value_obj.name)
new_world.qubit_remapping_dict.append(new_dict)
new_world.qubit_remapping_dict_length = self.qubit_remapping_dict_length.copy()
return new_world
Expand Down Expand Up @@ -694,38 +690,103 @@ def density_matrix(
2**num_shown_qubits, 2**num_shown_qubits
)

def measure_entanglement(self, objects: Optional[Sequence[QuantumObject]] = None) -> float:
def measure_entanglement(
self, objects: Optional[Sequence[QuantumObject]] = None
) -> float:
"""Measures the entanglement (i.e. quantum mutual information) of the given objects.
See https://en.wikipedia.org/wiki/Quantum_mutual_information for the formula.
Parameters:
objects: quantum objects among which the entanglement will be calculated
objects: quantum objects among which the entanglement will be calculated
(currently only qubits are supported). If not specified, all current
quantum objects will be used. If specified, at least two quantum
objects are expected.
Returns:
The quantum mutual information. For 2 qubits it's defined as S_1 + S_2 - S_12,
The quantum mutual information. For 2 qubits it's defined as S_1 + S_2 - S_12,
where S denotes (reduced) von Neumann entropy.
"""
num_involved_objects = len(objects) if objects is not None else len(self.object_name_dict.values())
num_involved_objects = (
len(objects) if objects is not None else len(self.object_name_dict.values())
)

if num_involved_objects < 2:
raise ValueError(f"Could not calculate entanglement for {num_involved_objects} qubit. "
"At least 2 qubits are required.")
raise ValueError(
f"Could not calculate entanglement for {num_involved_objects} qubit. "
"At least 2 qubits are required."
)

involved_objects = objects if objects is not None else list(self.object_name_dict.values())
involved_objects = (
objects if objects is not None else list(self.object_name_dict.values())
)

density_matrix = self.density_matrix(involved_objects)
reshaped_density_matrix = density_matrix.reshape(tuple([2, 2] * num_involved_objects))
reshaped_density_matrix = density_matrix.reshape((2, 2) * num_involved_objects)
result = 0.0
for comb in combinations(range(num_involved_objects), num_involved_objects - 1):
reshaped_partial_density_matrix = cirq.partial_trace(reshaped_density_matrix, list(comb))
partial_density_matrix = reshaped_partial_density_matrix.reshape(2 ** (num_involved_objects - 1), 2 ** (num_involved_objects - 1))
reshaped_partial_density_matrix = cirq.partial_trace(
reshaped_density_matrix, list(comb)
)
partial_density_matrix = reshaped_partial_density_matrix.reshape(
2 ** (num_involved_objects - 1), 2 ** (num_involved_objects - 1)
)
result += cirq.von_neumann_entropy(partial_density_matrix, validate=False)
result -= cirq.von_neumann_entropy(density_matrix, validate=False)
return result

def print_entanglement_table(self, count: int = 1000) -> None:
"""Peek the current quantum world `count` times, and calculate pair-wise entanglement
(i.e. quantum mutual information) for each pair of quantum objects.
See https://en.wikipedia.org/wiki/Quantum_mutual_information for the formula. And print
the results out in a table.
Parameters:
count: Number of measurements.
"""
objects = list(self.object_name_dict.values())
num_qubits = len(objects)
if num_qubits < 2:
raise ValueError(
f"There is only {num_qubits} qubit in the quantum world. "
"At least 2 qubits are required to calculate entanglements."
)
# Peek the current world `count` times and get the results.
histogram = self.get_correlated_histogram(objects, count)

# Get an estimate of the state vector.
state_vector = np.array([0.0] * (2**num_qubits))
for key, val in histogram.items():
state_vector += self.__to_state_vector__(key) * np.sqrt(val * 1.0 / count)
density_matrix = np.outer(state_vector, state_vector)
reshaped_density_matrix = density_matrix.reshape((2, 2) * num_qubits)

entropy = [0.0] * num_qubits
entropy_pair = np.zeros((num_qubits, num_qubits))
entanglement = np.zeros((num_qubits, num_qubits))
for i in range(num_qubits - 1):
for j in range(i + 1, num_qubits):
density_matrix_ij = cirq.partial_trace(reshaped_density_matrix, [i, j])
entropy_pair[i][j] = cirq.von_neumann_entropy(
density_matrix_ij.reshape(4, 4), validate=False
)
if i == 0:
# Fill in entropy [0]
if j == i + 1:
density_matrix_i = cirq.partial_trace(density_matrix_ij, [0])
entropy[i] = cirq.von_neumann_entropy(
density_matrix_i, validate=False
)
# Fill in entropy [1 to num_qubit - 1]
density_matrix_j = cirq.partial_trace(density_matrix_ij, [1])
entropy[j] = cirq.von_neumann_entropy(
density_matrix_j, validate=False
)
entanglement[i][j] = entropy[i] + entropy[j] - entropy_pair[i][j]
entanglement[j][i] = entanglement[i][j]
names = list(self.object_name_dict.keys())
data_frame = pd.DataFrame(entanglement, index=names, columns=names)
print(data_frame.round(1))

def __getitem__(self, name: str) -> QuantumObject:
quantum_object = self.object_name_dict.get(name, None)
if not quantum_object:
Expand Down
55 changes: 54 additions & 1 deletion unitary/alpha/quantum_world_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import unitary.alpha as alpha
import unitary.alpha.qudit_gates as qudit_gates
import io
import contextlib


class Light(enum.Enum):
Expand Down Expand Up @@ -979,5 +981,56 @@ def test_measure_entanglement(simulator, compile_to_qubits):
# Test with objects=None.
assert round(board.measure_entanglement(), 1) == 2.0
# Supplying one object would return a value error.
with pytest.raises(ValueError, match="Could not calculate entanglement for 1 qubit."):
with pytest.raises(
ValueError, match="Could not calculate entanglement for 1 qubit."
):
board.measure_entanglement([light1])


@pytest.mark.parametrize(
("simulator", "compile_to_qubits"),
[
(cirq.Simulator, False),
(cirq.Simulator, True),
# Cannot use SparseSimulator without `compile_to_qubits` due to issue #78.
(alpha.SparseSimulator, True),
],
)
def test_print_entanglement_table(simulator, compile_to_qubits):
rho_green = np.reshape([0, 0, 0, 1], (2, 2))
rho_red = np.reshape([1, 0, 0, 0], (2, 2))
light1 = alpha.QuantumObject("red1", Light.RED)
light2 = alpha.QuantumObject("green", Light.GREEN)
light3 = alpha.QuantumObject("red2", Light.RED)
board = alpha.QuantumWorld(
[light1, light2, light3],
sampler=simulator(),
compile_to_qubits=compile_to_qubits,
)
# f = io.StringIO()
# with contextlib.redirect_stdout(f):
# board.print_entanglement_table()
# assert (
# f.getvalue()
# in """
# red1 green red2
# red1 0.0 0.0 0.0
# green 0.0 0.0 0.0
# red2 0.0 0.0 0.0
# """
# )

alpha.Superposition()(light2)
alpha.quantum_if(light2).apply(alpha.Flip())(light3)
f = io.StringIO()
with contextlib.redirect_stdout(f):
board.print_entanglement_table()
assert (
f.getvalue()
in """
red1 green red2
red1 0.0 0.0 0.0
green 0.0 0.0 2.0
red2 0.0 2.0 0.0
"""
)

0 comments on commit bd4278b

Please sign in to comment.