diff --git a/unitary/alpha/quantum_world.py b/unitary/alpha/quantum_world.py index 5cd70617..28b291e4 100644 --- a/unitary/alpha/quantum_world.py +++ b/unitary/alpha/quantum_world.py @@ -69,12 +69,22 @@ def clear(self) -> None: """ self.circuit = cirq.Circuit() self.effect_history: List[Tuple[cirq.Circuit, Dict[QuantumObject, int]]] = [] + # This variable is used to save the length of current effect history before each move is made, + # so that if we later undo we know how many effects we need to pop out, since each move could + # consist of several effects. + self.effect_history_length: List[int] = [] self.object_name_dict: Dict[str, QuantumObject] = {} self.ancilla_names: Set[str] = set() # When `compile_to_qubits` is True, this tracks the mapping of the # original qudits to the compiled qubits. self.compiled_qubits: Dict[cirq.Qid, List[cirq.Qid]] = {} self.post_selection: Dict[QuantumObject, int] = {} + # This variable is used to save the qubit remapping dictionary before each move, so that if + # we later undo we know how to reverse the mapping. + self.qubit_remapping_dict: List[Dict[cirq.Qid, cirq.Qid]] = [] + # This variable is used to save the length of qubit_remapping_dict before each move is made, + # so that if we later undo we know how to remap the qubits. + self.qubit_remapping_dict_length: List[int] = [] def copy(self) -> "QuantumWorld": new_objects = [] @@ -95,7 +105,17 @@ def copy(self) -> "QuantumWorld": (circuit.copy(), copy.copy(post_selection)) for circuit, post_selection in self.effect_history ] + new_world.effect_history_length = self.effect_history_length.copy() new_world.post_selection = new_post_selection + # copy qubit_remapping_dict + for remap in self.qubit_remapping_dict: + new_dict = {} + for key_obj, value_obj in remap.items(): + new_dict[ + new_world.get_object_by_name(key_obj.name) + ] = new_world.get_object_by_name(value_obj.name) + new_world.qubit_remapping_dict.append(new_dict) + new_world.qubit_remapping_dict_length = self.qubit_remapping_dict_length.copy() return new_world def add_object(self, obj: QuantumObject): @@ -257,7 +277,8 @@ def add_effect(self, op_list: List[cirq.Operation]): self._append_op(op) def undo_last_effect(self): - """Restores the `QuantumWorld` to the state before the last effect. + """Restores the circuit and post selection dictionary of `QuantumWorld` to the + state before the last effect. Note that pop() is considered to be an effect for the purposes of this call. @@ -269,6 +290,57 @@ def undo_last_effect(self): raise IndexError("No effects to undo") self.circuit, self.post_selection = self.effect_history.pop() + def save_snapshot(self) -> None: + """Saves the current length of the effect history and qubit_remapping_dict. + + Normally this could default to be called after every move made by player of your + game, so that later if the player choose to undo his last move, we could use + `restore_last_snapshot` to restore the quantum properties to the snapshot. + """ + self.effect_history_length.append(len(self.effect_history)) + self.qubit_remapping_dict_length.append(len(self.qubit_remapping_dict)) + + def restore_last_snapshot(self) -> None: + """Restores the `QuantumWorld` to the last snapshot (which was saved after the last move + finished), which includes + - reversing the mapping of qubits, if any, + - restoring the post selection dictionary, + - restoring the circuit. + """ + if ( + len(self.effect_history_length) <= 1 + or len(self.qubit_remapping_dict_length) <= 1 + ): + # length == 1 corresponds to the initial state, and no more restore could be made. + raise ValueError("Unable to restore any more.") + + # Recover the mapping of qubits to the last snapshot, and remove any related post selection memory. + # Note that this need to be done before calling `undo_last_effect()`, otherwise the remapping does not + # work as expected. + self.qubit_remapping_dict_length.pop() + last_length = self.qubit_remapping_dict_length[-1] + while len(self.qubit_remapping_dict) > last_length: + qubit_remapping_dict = self.qubit_remapping_dict.pop() + if len(qubit_remapping_dict) == 0: + continue + # Reverse the mapping. + self.circuit = self.circuit.transform_qubits( + lambda q: qubit_remapping_dict.get(q, q) + ) + # Clear relevant qubits from the post selection dictionary. + # TODO(): rethink if this is necessary, given that undo_last_effect() + # will also restore post selection dictionary. + for obj in qubit_remapping_dict.keys(): + if obj in self.post_selection: + self.post_selection.pop(obj) + + # Recover the effects up to the last snapshot by popping effects out of the + # effect history of the board until its length equals the last snapshot's length. + self.effect_history_length.pop() + last_length = self.effect_history_length[-1] + while len(self.effect_history) > last_length: + self.undo_last_effect() + def _suggest_num_reps(self, sample_size: int) -> int: """Guess the number of raw samples needed to get sample_size results. Assume that each post-selection is about 50/50. @@ -323,6 +395,7 @@ def unhook(self, object: QuantumObject) -> None: object.qubit: new_ancilla.qubit, new_ancilla.qubit: object.qubit, } + self.qubit_remapping_dict.append(qubit_remapping_dict) self.circuit = self.circuit.transform_qubits( lambda q: qubit_remapping_dict.get(q, q) ) @@ -348,7 +421,7 @@ def force_measurement( qubit_remapping_dict.update( {*zip(obj_qubits, new_obj_qubits), *zip(new_obj_qubits, obj_qubits)} ) - + self.qubit_remapping_dict.append(qubit_remapping_dict) self.circuit = self.circuit.transform_qubits( lambda q: qubit_remapping_dict.get(q, q) ) diff --git a/unitary/alpha/quantum_world_test.py b/unitary/alpha/quantum_world_test.py index 30b02941..508c3901 100644 --- a/unitary/alpha/quantum_world_test.py +++ b/unitary/alpha/quantum_world_test.py @@ -326,6 +326,7 @@ def test_copy(simulator, compile_to_qubits): alpha.Flip()(light2) assert board.pop([light1])[0] == Light.RED assert board.pop([light2])[0] == Light.GREEN + board.save_snapshot() board2 = board.copy() @@ -345,9 +346,15 @@ def test_copy(simulator, compile_to_qubits): assert board.circuit is not board2.circuit assert board.effect_history == board2.effect_history assert board.effect_history is not board2.effect_history + assert board.effect_history_length == board2.effect_history_length + assert board.qubit_remapping_dict_length == board2.qubit_remapping_dict_length assert board.ancilla_names == board2.ancilla_names assert board.ancilla_names is not board2.ancilla_names assert len(board2.post_selection) == 2 + assert [key.name for key in board2.qubit_remapping_dict[-1].keys()] == [ + "l2", + "ancilla_l2_0", + ], "Failed to copy qubit_remapping_dict correctly." # Assert that they now evolve independently board2.undo_last_effect() @@ -775,3 +782,99 @@ def test_get_correlated_histogram_with_entangled_qobjects(simulator, compile_to_ histogram = world.get_correlated_histogram() assert histogram.keys() == {(0, 0, 1, 1, 0), (0, 1, 0, 0, 1)} + + +@pytest.mark.parametrize( + ("simulator", "compile_to_qubits"), + [ + (cirq.Simulator, False), + (cirq.Simulator, True), + # Cannot use SparseSimulator without `compile_to_qubits` due to issue #78. + (alpha.SparseSimulator, True), + ], +) +def test_save_and_restore_snapshot(simulator, compile_to_qubits): + light1 = alpha.QuantumObject("l1", Light.GREEN) + light2 = alpha.QuantumObject("l2", Light.RED) + light3 = alpha.QuantumObject("l3", Light.RED) + light4 = alpha.QuantumObject("l4", Light.RED) + light5 = alpha.QuantumObject("l5", Light.RED) + + # Initial state. + world = alpha.QuantumWorld( + [light1, light2, light3, light4, light5], + sampler=simulator(), + compile_to_qubits=compile_to_qubits, + ) + # Snapshot #0 + world.save_snapshot() + circuit_0 = world.circuit.copy() + # one effect from Flip() + assert world.effect_history_length == [1] + assert world.qubit_remapping_dict_length == [0] + assert world.post_selection == {} + + # First move. + alpha.Split()(light1, light2, light3) + # Snapshot #1 + world.save_snapshot() + circuit_1 = world.circuit.copy() + # one more effect from Split() + assert world.effect_history_length == [1, 2] + assert world.qubit_remapping_dict_length == [0, 0] + assert world.post_selection == {} + + # Second move, which includes multiple effects and post selection. + alpha.Flip()(light2) + alpha.Split()(light3, light4, light5) + world.force_measurement(light4, Light.RED) + world.unhook(light5) + # Snapshot #2 + world.save_snapshot() + # 2 more effects from Flip() and Split() + assert world.effect_history_length == [1, 2, 4] + # 2 mapping from force_measurement() and unhook() + assert world.qubit_remapping_dict_length == [0, 0, 2] + # 1 post selection from force_measurement + assert len(world.post_selection) == 1 + results = world.peek( + [light1, light2, light3, light4, light5], count=200, convert_to_enum=False + ) + assert all(result[0] == 0 for result in results) + assert not all(result[1] == 0 for result in results) + assert all(result[2] == 0 for result in results) + assert all(result[3] == 0 for result in results) + assert all(result[4] == 0 for result in results) + + # Restore to snapshot #1 + world.restore_last_snapshot() + assert world.effect_history_length == [1, 2] + assert world.qubit_remapping_dict_length == [0, 0] + assert world.circuit == circuit_1 + assert world.post_selection == {} + results = world.peek( + [light1, light2, light3, light4, light5], count=200, convert_to_enum=False + ) + assert all(result[0] == 0 for result in results) + assert all(result[1] != result[2] for result in results) + assert all(result[3] == 0 for result in results) + assert all(result[4] == 0 for result in results) + + # Restore to snapshot #0 + world.restore_last_snapshot() + assert world.effect_history_length == [1] + assert world.qubit_remapping_dict_length == [0] + assert world.circuit == circuit_0 + assert world.post_selection == {} + results = world.peek( + [light1, light2, light3, light4, light5], count=200, convert_to_enum=False + ) + assert all(result[0] == 1 for result in results) + assert all(result[1] == 0 for result in results) + assert all(result[2] == 0 for result in results) + assert all(result[3] == 0 for result in results) + assert all(result[4] == 0 for result in results) + + # Further restore would return a value error. + with pytest.raises(ValueError, match="Unable to restore any more."): + world.restore_last_snapshot() diff --git a/unitary/examples/quantum_chinese_chess/chess.py b/unitary/examples/quantum_chinese_chess/chess.py index 924da2ce..58cc2656 100644 --- a/unitary/examples/quantum_chinese_chess/chess.py +++ b/unitary/examples/quantum_chinese_chess/chess.py @@ -69,7 +69,7 @@ def print_welcome(self) -> None: self.players_name.append("Player_1" if len(name_1) == 0 else name_1) def __init__(self): - self.players_name = [] + self.players_name: List[str] = [] self.print_welcome() self.board = Board.from_fen() self.board.set_language(self.lang) @@ -77,6 +77,9 @@ def __init__(self): self.game_state = GameState.CONTINUES self.current_player = self.board.current_player self.debug_level = 3 + # This variable is used to save the classical properties of the whole board before each move is + # made, so that if we later undo we could recover the earlier classical state. + self.classical_properties_history: List[List[List[int]]] = [] @staticmethod def parse_input_string(str_to_parse: str) -> Tuple[List[str], List[str]]: @@ -442,15 +445,10 @@ def next_move(self) -> Tuple[bool, str]: # TODO(): make it look like the normal board. Right now it's only for debugging purposes. print(self.board.board.peek(convert_to_enum=False)) elif input_str.lower() == "undo": - output = "Undo last quantum effect." - # Right now it's only for debugging purposes, since it has following problems: - # TODO(): there are several problems here: - # 1) the classical piece information is not reversed back. - # ==> we may need to save the change of classical piece information of each step. - # 2) last move involved multiple effects. - # ==> we may need to save number of effects per move, and undo that number of times. - self.board.board.undo_last_effect() - return True, output + if self.undo(): + return True, "Undoing." + return False, "Failed to undo." + else: try: # The move is success if no ValueError is raised. @@ -496,6 +494,60 @@ def game_over(self) -> None: # TODO(): add the following checks # - If player 0 made N repeatd back-and_forth moves in a row. + def save_snapshot(self) -> None: + """Saves the current length of the effect history, qubit_remapping_dict, and the current classical states of all pieces.""" + # Save the current length of the effect history and qubit_remapping_dict. + self.board.board.save_snapshot() + + # Save the classical states of all pieces. + snapshot = [] + for row in range(10): + for col in "abcdefghi": + piece = self.board.board[f"{col}{row}"] + snapshot.append( + [piece.type_.value, piece.color.value, piece.is_entangled] + ) + self.classical_properties_history.append(snapshot) + + def undo(self) -> bool: + """Undo the last move, which includes reset quantum effects and classical properties, and remapping + qubits. + + Returns True if the undo is success, and False otherwise. + """ + world = self.board.board + if ( + len(world.effect_history_length) <= 1 + or len(world.qubit_remapping_dict_length) <= 1 + or len(self.classical_properties_history) <= 1 + ): + # length == 1 corresponds to the initial state, and no more undo could be made. + return False + + # Recover the mapping of qubits to the last snapshot, remove any related post selection memory, + # and recover the effects up to the last snapshot (which was saved after the last move finished). + try: + world.restore_last_snapshot() + except ValueError as err: + print(err) + return False + except Exception: + print("Unexpected error during undo.") + raise + + # Recover the classical properties of all pieces to the last snapshot. + self.classical_properties_history.pop() + snapshot = self.classical_properties_history[-1] + index = 0 + for row in range(10): + for col in "abcdefghi": + piece = world[f"{col}{row}"] + piece.type_ = Type(snapshot[index][0]) + piece.color = Color(snapshot[index][1]) + piece.is_entangled = snapshot[index][2] + index += 1 + return True + def play(self) -> None: """The loop where each player takes turn to play.""" while True: @@ -507,11 +559,15 @@ def play(self) -> None: print("\nPlease re-enter your move.") continue print(output) - # TODO(): maybe we should not check game_over() when an undo is made. - # Check if the game is over. - self.game_over() - # TODO(): no need to do sampling if the last move was CLASSICAL. - self.update_board_by_sampling() + if output != "Undoing.": + # Check if the game is over. + self.game_over() + # Update any empty or occupied pieces' classical state. + # TODO(): no need to do sampling if the last move was CLASSICAL. + probs = self.update_board_by_sampling() + # Save the current states. + self.save_snapshot() + # TODO(): pass probs into the following method to print probabilities. print(self.board) if self.game_state == GameState.CONTINUES: # If the game continues, switch the player. diff --git a/unitary/examples/quantum_chinese_chess/chess_test.py b/unitary/examples/quantum_chinese_chess/chess_test.py index b278fe35..e8adc19e 100644 --- a/unitary/examples/quantum_chinese_chess/chess_test.py +++ b/unitary/examples/quantum_chinese_chess/chess_test.py @@ -14,6 +14,13 @@ import pytest import io import sys +from unitary.examples.quantum_chinese_chess.test_utils import ( + set_board, + assert_sample_distribution, + locations_to_bitboard, + assert_samples_in, +) +from unitary import alpha from unitary.examples.quantum_chinese_chess.chess import QuantumChineseChess from unitary.examples.quantum_chinese_chess.piece import Piece from unitary.examples.quantum_chinese_chess.enums import ( @@ -380,3 +387,75 @@ def test_update_board_by_sampling(monkeypatch): # Verify that the method would set a1 to classically occupied. game.update_board_by_sampling() assert board["a1"].is_entangled == False + + +def test_undo_single_effect_per_move(monkeypatch): + inputs = iter(["y", "Bob", "Ben"]) + monkeypatch.setattr("builtins.input", lambda _: next(inputs)) + game = QuantumChineseChess() + board = set_board(["a1", "b1", "c1"]) + game.board = board + world = board.board + game.save_snapshot() + alpha.PhasedSplit()(world["a1"], world["a2"], world["a3"]) + game.save_snapshot() + alpha.PhasedSplit()(world["b1"], world["b2"], world["b3"]) + game.save_snapshot() + + assert_sample_distribution( + board, + { + locations_to_bitboard(["a2", "b2", "c1"]): 1.0 / 4, + locations_to_bitboard(["a2", "b3", "c1"]): 1.0 / 4, + locations_to_bitboard(["a3", "b2", "c1"]): 1.0 / 4, + locations_to_bitboard(["a3", "b3", "c1"]): 1.0 / 4, + }, + ) + + game.undo() + assert_sample_distribution( + board, + { + locations_to_bitboard(["a2", "b1", "c1"]): 1.0 / 2, + locations_to_bitboard(["a3", "b1", "c1"]): 1.0 / 2, + }, + ) + + # One more undo to return to the initial state. + game.undo() + assert_samples_in(board, {locations_to_bitboard(["a1", "b1", "c1"]): 1.0}) + + # More undos should have no effect. + game.undo() + assert_samples_in(board, {locations_to_bitboard(["a1", "b1", "c1"]): 1.0}) + + +def test_undo_multiple_effects_per_move(monkeypatch): + inputs = iter(["y", "Bob", "Ben"]) + monkeypatch.setattr("builtins.input", lambda _: next(inputs)) + game = QuantumChineseChess() + board = set_board(["a1", "b1", "c1"]) + game.board = board + world = board.board + game.save_snapshot() + alpha.PhasedSplit()(world["a1"], world["a2"], world["a3"]) + alpha.PhasedSplit()(world["b1"], world["b2"], world["b3"]) + game.save_snapshot() + + assert_sample_distribution( + board, + { + locations_to_bitboard(["a2", "b2", "c1"]): 1.0 / 4, + locations_to_bitboard(["a2", "b3", "c1"]): 1.0 / 4, + locations_to_bitboard(["a3", "b2", "c1"]): 1.0 / 4, + locations_to_bitboard(["a3", "b3", "c1"]): 1.0 / 4, + }, + ) + + # One more undo to return to the initial state. + game.undo() + assert_samples_in(board, {locations_to_bitboard(["a1", "b1", "c1"]): 1.0}) + + # More undos should have no effect. + game.undo() + assert_samples_in(board, {locations_to_bitboard(["a1", "b1", "c1"]): 1.0})