Skip to content

Commit

Permalink
Merge pull request #173 from madcpf/snap
Browse files Browse the repository at this point in the history
[Quantum Chinese Chess] Add save_snapshot() and restore_last_snapshot() to QuantumWorld
  • Loading branch information
madcpf authored Dec 15, 2023
2 parents e54c205 + 3903474 commit eaa8963
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 17 deletions.
77 changes: 75 additions & 2 deletions unitary/alpha/quantum_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,22 @@ def clear(self) -> None:
"""
self.circuit = cirq.Circuit()
self.effect_history: List[Tuple[cirq.Circuit, Dict[QuantumObject, int]]] = []
# This variable is used to save the length of current effect history before each move is made,
# so that if we later undo we know how many effects we need to pop out, since each move could
# consist of several effects.
self.effect_history_length: List[int] = []
self.object_name_dict: Dict[str, QuantumObject] = {}
self.ancilla_names: Set[str] = set()
# When `compile_to_qubits` is True, this tracks the mapping of the
# original qudits to the compiled qubits.
self.compiled_qubits: Dict[cirq.Qid, List[cirq.Qid]] = {}
self.post_selection: Dict[QuantumObject, int] = {}
# This variable is used to save the qubit remapping dictionary before each move, so that if
# we later undo we know how to reverse the mapping.
self.qubit_remapping_dict: List[Dict[cirq.Qid, cirq.Qid]] = []
# This variable is used to save the length of qubit_remapping_dict before each move is made,
# so that if we later undo we know how to remap the qubits.
self.qubit_remapping_dict_length: List[int] = []

def copy(self) -> "QuantumWorld":
new_objects = []
Expand All @@ -95,7 +105,17 @@ def copy(self) -> "QuantumWorld":
(circuit.copy(), copy.copy(post_selection))
for circuit, post_selection in self.effect_history
]
new_world.effect_history_length = self.effect_history_length.copy()
new_world.post_selection = new_post_selection
# copy qubit_remapping_dict
for remap in self.qubit_remapping_dict:
new_dict = {}
for key_obj, value_obj in remap.items():
new_dict[
new_world.get_object_by_name(key_obj.name)
] = new_world.get_object_by_name(value_obj.name)
new_world.qubit_remapping_dict.append(new_dict)
new_world.qubit_remapping_dict_length = self.qubit_remapping_dict_length.copy()
return new_world

def add_object(self, obj: QuantumObject):
Expand Down Expand Up @@ -257,7 +277,8 @@ def add_effect(self, op_list: List[cirq.Operation]):
self._append_op(op)

def undo_last_effect(self):
"""Restores the `QuantumWorld` to the state before the last effect.
"""Restores the circuit and post selection dictionary of `QuantumWorld` to the
state before the last effect.
Note that pop() is considered to be an effect for the purposes
of this call.
Expand All @@ -269,6 +290,57 @@ def undo_last_effect(self):
raise IndexError("No effects to undo")
self.circuit, self.post_selection = self.effect_history.pop()

def save_snapshot(self) -> None:
"""Saves the current length of the effect history and qubit_remapping_dict.
Normally this could default to be called after every move made by player of your
game, so that later if the player choose to undo his last move, we could use
`restore_last_snapshot` to restore the quantum properties to the snapshot.
"""
self.effect_history_length.append(len(self.effect_history))
self.qubit_remapping_dict_length.append(len(self.qubit_remapping_dict))

def restore_last_snapshot(self) -> None:
"""Restores the `QuantumWorld` to the last snapshot (which was saved after the last move
finished), which includes
- reversing the mapping of qubits, if any,
- restoring the post selection dictionary,
- restoring the circuit.
"""
if (
len(self.effect_history_length) <= 1
or len(self.qubit_remapping_dict_length) <= 1
):
# length == 1 corresponds to the initial state, and no more restore could be made.
raise ValueError("Unable to restore any more.")

# Recover the mapping of qubits to the last snapshot, and remove any related post selection memory.
# Note that this need to be done before calling `undo_last_effect()`, otherwise the remapping does not
# work as expected.
self.qubit_remapping_dict_length.pop()
last_length = self.qubit_remapping_dict_length[-1]
while len(self.qubit_remapping_dict) > last_length:
qubit_remapping_dict = self.qubit_remapping_dict.pop()
if len(qubit_remapping_dict) == 0:
continue
# Reverse the mapping.
self.circuit = self.circuit.transform_qubits(
lambda q: qubit_remapping_dict.get(q, q)
)
# Clear relevant qubits from the post selection dictionary.
# TODO(): rethink if this is necessary, given that undo_last_effect()
# will also restore post selection dictionary.
for obj in qubit_remapping_dict.keys():
if obj in self.post_selection:
self.post_selection.pop(obj)

# Recover the effects up to the last snapshot by popping effects out of the
# effect history of the board until its length equals the last snapshot's length.
self.effect_history_length.pop()
last_length = self.effect_history_length[-1]
while len(self.effect_history) > last_length:
self.undo_last_effect()

def _suggest_num_reps(self, sample_size: int) -> int:
"""Guess the number of raw samples needed to get sample_size results.
Assume that each post-selection is about 50/50.
Expand Down Expand Up @@ -323,6 +395,7 @@ def unhook(self, object: QuantumObject) -> None:
object.qubit: new_ancilla.qubit,
new_ancilla.qubit: object.qubit,
}
self.qubit_remapping_dict.append(qubit_remapping_dict)
self.circuit = self.circuit.transform_qubits(
lambda q: qubit_remapping_dict.get(q, q)
)
Expand All @@ -348,7 +421,7 @@ def force_measurement(
qubit_remapping_dict.update(
{*zip(obj_qubits, new_obj_qubits), *zip(new_obj_qubits, obj_qubits)}
)

self.qubit_remapping_dict.append(qubit_remapping_dict)
self.circuit = self.circuit.transform_qubits(
lambda q: qubit_remapping_dict.get(q, q)
)
Expand Down
103 changes: 103 additions & 0 deletions unitary/alpha/quantum_world_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def test_copy(simulator, compile_to_qubits):
alpha.Flip()(light2)
assert board.pop([light1])[0] == Light.RED
assert board.pop([light2])[0] == Light.GREEN
board.save_snapshot()

board2 = board.copy()

Expand All @@ -345,9 +346,15 @@ def test_copy(simulator, compile_to_qubits):
assert board.circuit is not board2.circuit
assert board.effect_history == board2.effect_history
assert board.effect_history is not board2.effect_history
assert board.effect_history_length == board2.effect_history_length
assert board.qubit_remapping_dict_length == board2.qubit_remapping_dict_length
assert board.ancilla_names == board2.ancilla_names
assert board.ancilla_names is not board2.ancilla_names
assert len(board2.post_selection) == 2
assert [key.name for key in board2.qubit_remapping_dict[-1].keys()] == [
"l2",
"ancilla_l2_0",
], "Failed to copy qubit_remapping_dict correctly."

# Assert that they now evolve independently
board2.undo_last_effect()
Expand Down Expand Up @@ -775,3 +782,99 @@ def test_get_correlated_histogram_with_entangled_qobjects(simulator, compile_to_

histogram = world.get_correlated_histogram()
assert histogram.keys() == {(0, 0, 1, 1, 0), (0, 1, 0, 0, 1)}


@pytest.mark.parametrize(
("simulator", "compile_to_qubits"),
[
(cirq.Simulator, False),
(cirq.Simulator, True),
# Cannot use SparseSimulator without `compile_to_qubits` due to issue #78.
(alpha.SparseSimulator, True),
],
)
def test_save_and_restore_snapshot(simulator, compile_to_qubits):
light1 = alpha.QuantumObject("l1", Light.GREEN)
light2 = alpha.QuantumObject("l2", Light.RED)
light3 = alpha.QuantumObject("l3", Light.RED)
light4 = alpha.QuantumObject("l4", Light.RED)
light5 = alpha.QuantumObject("l5", Light.RED)

# Initial state.
world = alpha.QuantumWorld(
[light1, light2, light3, light4, light5],
sampler=simulator(),
compile_to_qubits=compile_to_qubits,
)
# Snapshot #0
world.save_snapshot()
circuit_0 = world.circuit.copy()
# one effect from Flip()
assert world.effect_history_length == [1]
assert world.qubit_remapping_dict_length == [0]
assert world.post_selection == {}

# First move.
alpha.Split()(light1, light2, light3)
# Snapshot #1
world.save_snapshot()
circuit_1 = world.circuit.copy()
# one more effect from Split()
assert world.effect_history_length == [1, 2]
assert world.qubit_remapping_dict_length == [0, 0]
assert world.post_selection == {}

# Second move, which includes multiple effects and post selection.
alpha.Flip()(light2)
alpha.Split()(light3, light4, light5)
world.force_measurement(light4, Light.RED)
world.unhook(light5)
# Snapshot #2
world.save_snapshot()
# 2 more effects from Flip() and Split()
assert world.effect_history_length == [1, 2, 4]
# 2 mapping from force_measurement() and unhook()
assert world.qubit_remapping_dict_length == [0, 0, 2]
# 1 post selection from force_measurement
assert len(world.post_selection) == 1
results = world.peek(
[light1, light2, light3, light4, light5], count=200, convert_to_enum=False
)
assert all(result[0] == 0 for result in results)
assert not all(result[1] == 0 for result in results)
assert all(result[2] == 0 for result in results)
assert all(result[3] == 0 for result in results)
assert all(result[4] == 0 for result in results)

# Restore to snapshot #1
world.restore_last_snapshot()
assert world.effect_history_length == [1, 2]
assert world.qubit_remapping_dict_length == [0, 0]
assert world.circuit == circuit_1
assert world.post_selection == {}
results = world.peek(
[light1, light2, light3, light4, light5], count=200, convert_to_enum=False
)
assert all(result[0] == 0 for result in results)
assert all(result[1] != result[2] for result in results)
assert all(result[3] == 0 for result in results)
assert all(result[4] == 0 for result in results)

# Restore to snapshot #0
world.restore_last_snapshot()
assert world.effect_history_length == [1]
assert world.qubit_remapping_dict_length == [0]
assert world.circuit == circuit_0
assert world.post_selection == {}
results = world.peek(
[light1, light2, light3, light4, light5], count=200, convert_to_enum=False
)
assert all(result[0] == 1 for result in results)
assert all(result[1] == 0 for result in results)
assert all(result[2] == 0 for result in results)
assert all(result[3] == 0 for result in results)
assert all(result[4] == 0 for result in results)

# Further restore would return a value error.
with pytest.raises(ValueError, match="Unable to restore any more."):
world.restore_last_snapshot()
86 changes: 71 additions & 15 deletions unitary/examples/quantum_chinese_chess/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,17 @@ def print_welcome(self) -> None:
self.players_name.append("Player_1" if len(name_1) == 0 else name_1)

def __init__(self):
self.players_name = []
self.players_name: List[str] = []
self.print_welcome()
self.board = Board.from_fen()
self.board.set_language(self.lang)
print(self.board)
self.game_state = GameState.CONTINUES
self.current_player = self.board.current_player
self.debug_level = 3
# This variable is used to save the classical properties of the whole board before each move is
# made, so that if we later undo we could recover the earlier classical state.
self.classical_properties_history: List[List[List[int]]] = []

@staticmethod
def parse_input_string(str_to_parse: str) -> Tuple[List[str], List[str]]:
Expand Down Expand Up @@ -442,15 +445,10 @@ def next_move(self) -> Tuple[bool, str]:
# TODO(): make it look like the normal board. Right now it's only for debugging purposes.
print(self.board.board.peek(convert_to_enum=False))
elif input_str.lower() == "undo":
output = "Undo last quantum effect."
# Right now it's only for debugging purposes, since it has following problems:
# TODO(): there are several problems here:
# 1) the classical piece information is not reversed back.
# ==> we may need to save the change of classical piece information of each step.
# 2) last move involved multiple effects.
# ==> we may need to save number of effects per move, and undo that number of times.
self.board.board.undo_last_effect()
return True, output
if self.undo():
return True, "Undoing."
return False, "Failed to undo."

else:
try:
# The move is success if no ValueError is raised.
Expand Down Expand Up @@ -496,6 +494,60 @@ def game_over(self) -> None:
# TODO(): add the following checks
# - If player 0 made N repeatd back-and_forth moves in a row.

def save_snapshot(self) -> None:
"""Saves the current length of the effect history, qubit_remapping_dict, and the current classical states of all pieces."""
# Save the current length of the effect history and qubit_remapping_dict.
self.board.board.save_snapshot()

# Save the classical states of all pieces.
snapshot = []
for row in range(10):
for col in "abcdefghi":
piece = self.board.board[f"{col}{row}"]
snapshot.append(
[piece.type_.value, piece.color.value, piece.is_entangled]
)
self.classical_properties_history.append(snapshot)

def undo(self) -> bool:
"""Undo the last move, which includes reset quantum effects and classical properties, and remapping
qubits.
Returns True if the undo is success, and False otherwise.
"""
world = self.board.board
if (
len(world.effect_history_length) <= 1
or len(world.qubit_remapping_dict_length) <= 1
or len(self.classical_properties_history) <= 1
):
# length == 1 corresponds to the initial state, and no more undo could be made.
return False

# Recover the mapping of qubits to the last snapshot, remove any related post selection memory,
# and recover the effects up to the last snapshot (which was saved after the last move finished).
try:
world.restore_last_snapshot()
except ValueError as err:
print(err)
return False
except Exception:
print("Unexpected error during undo.")
raise

# Recover the classical properties of all pieces to the last snapshot.
self.classical_properties_history.pop()
snapshot = self.classical_properties_history[-1]
index = 0
for row in range(10):
for col in "abcdefghi":
piece = world[f"{col}{row}"]
piece.type_ = Type(snapshot[index][0])
piece.color = Color(snapshot[index][1])
piece.is_entangled = snapshot[index][2]
index += 1
return True

def play(self) -> None:
"""The loop where each player takes turn to play."""
while True:
Expand All @@ -507,11 +559,15 @@ def play(self) -> None:
print("\nPlease re-enter your move.")
continue
print(output)
# TODO(): maybe we should not check game_over() when an undo is made.
# Check if the game is over.
self.game_over()
# TODO(): no need to do sampling if the last move was CLASSICAL.
self.update_board_by_sampling()
if output != "Undoing.":
# Check if the game is over.
self.game_over()
# Update any empty or occupied pieces' classical state.
# TODO(): no need to do sampling if the last move was CLASSICAL.
probs = self.update_board_by_sampling()
# Save the current states.
self.save_snapshot()
# TODO(): pass probs into the following method to print probabilities.
print(self.board)
if self.game_state == GameState.CONTINUES:
# If the game continues, switch the player.
Expand Down
Loading

0 comments on commit eaa8963

Please sign in to comment.