This repository contains an implementation of a Classical-Quantum (or Hybrid) Neural Network (HNN) that combines a Quantum Neural Network (QNN) with a Convolutional Neural Network (CNN) to recognize handwritten digits from datasets such as MNIST and EMNIST Digits. The model is defended against compounded adversarial attacks through adversarial training.
- West, Maxwell T., et al. "Benchmarking Adversarially Robust Quantum Machine Learning at Scale." *Physical Review Research*, vol. 5, no. 2, 23 June 2023, doi:10.1103/physrevresearch.5.023186.
- TorchAttacks documentation: https://adversarial-attacks-pytorch.readthedocs.io/en/latest/
- PyTorch tutorial, "Adversarial Example Generation": https://pytorch.org/tutorials/beginner/fgsm_tutorial.html
- Adversarial-Attacks-PyTorch (GitHub): https://github.com/Harry24k/adversarial-attacks-pytorch
| Compounded Attack | Pre-Attack Accuracy (No Defense) | Post-Attack Accuracy (No Defense) | Post-Attack Accuracy (With Defense) |
|---|---|---|---|
| FGSM + CW | 98.0% | 20.0% | 100.0% |
| FGSM + PGD | 98.0% | 20.0% | 98.0% |
| CW + PGD | 100.0% | 89.0% | 100.0% |
To run the code, ensure you have the following dependencies installed:
- Python: 3.8.18
- torch: 2.2.1
- torchvision: 0.17.1
- torchattacks: 3.5.1
- numpy: 1.23.5
- tabulate: 0.9.0
- cirq: 1.3.0
You can install the required packages using pip:
pip install torch torchvision torchattacks numpy tabulate cirq
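Because the `pip` command above installs the latest releases rather than the versions listed, a small standard-library check can confirm what an environment actually contains. This is a convenience sketch, not part of the repository: `PINNED` and `check_versions` are hypothetical names, and since CUDA builds of PyTorch report local versions such as `2.2.1+cu121`, the comparison ignores anything after a `+`.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions listed in this README (hypothetical helper, not from the repository).
PINNED = {
    "torch": "2.2.1",
    "torchvision": "0.17.1",
    "torchattacks": "3.5.1",
    "numpy": "1.23.5",
    "tabulate": "0.9.0",
    "cirq": "1.3.0",
}


def check_versions(pinned=PINNED):
    """Return {package: (wanted, installed_or_None, matches)}."""
    report = {}
    for pkg, wanted in pinned.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None
        # Drop local build tags such as "+cu121" before comparing.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[pkg] = (wanted, installed, matches)
    return report


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "(README lists 3.8.18)")
    for pkg, (wanted, installed, ok) in check_versions().items():
        status = "ok" if ok else "MISMATCH"
        print(f"{pkg:13s} wanted {wanted:8s} installed {installed or 'missing':12s} {status}")
```

Other versions may work too; the check only flags differences from the tested set.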
The code assumes that the datasets are stored on Google Drive and that Google Drive will be mounted with the default directory structure. Ensure that you have the necessary datasets (MNIST, EMNIST, SVHN, USPS, or Semeion) in the appropriate locations on your Google Drive.
The CNN architecture is designed to recognize digits 0 to 9 in the MNIST-type datasets. It consists of the following layers:
- Two convolutional layers with ReLU activation and max pooling
- Two fully connected layers with ReLU activation
- Output layer with log softmax activation
The CNN learns to extract relevant features and patterns from the input images for digit recognition.
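The layer list above maps naturally onto a small PyTorch module. This is a sketch, not the repository's code: the channel counts, kernel sizes, hidden widths, the 1x28x28 input size, and the reading of "two fully connected layers" plus a separate output layer are all assumptions chosen for MNIST-style images, and `DigitCNN` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DigitCNN(nn.Module):
    """Sketch of the described CNN for 1x28x28 inputs (assumed MNIST-style size)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Two convolutional layers (channel counts and kernel sizes are assumptions).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Two fully connected layers with ReLU (hidden widths are assumptions).
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        # Output layer: one score per digit class.
        self.out = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.log_softmax(self.out(x), dim=1)  # log-probabilities per class


if __name__ == "__main__":
    model = DigitCNN()
    log_probs = model(torch.rand(4, 1, 28, 28))
    print(log_probs.shape)  # torch.Size([4, 10])
```

Returning log-probabilities pairs with `F.nll_loss` during training, which is equivalent to cross-entropy on raw logits.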
The HNN combines the classical CNN with a quantum circuit to enhance the model's performance. The HNN takes the classical model as an input parameter and integrates it with the quantum circuit.
The HNN initializes trainable parameters (`theta` and `phi`) for the quantum circuit, which represent the angles of rotation gates. These parameters introduce additional non-linearity and expressiveness to the model.
During the forward pass, the input data is passed through the classical CNN, and the extracted features are used as input to the quantum circuit. The quantum circuit applies quantum operations based on the learned parameters to transform the input features. The output of the quantum circuit is then processed and combined with the classical model's predictions to produce the final output.
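One way to picture this forward pass is a differentiable state-vector simulation written directly in PyTorch, so that gradients reach `theta` and `phi`. This is a sketch under assumptions the README does not specify: the classical class probabilities (scaled by pi) are used as rotation angles, `theta` drives the even-indexed qubits and `phi` the odd-indexed ones, the circuit is an RY layer followed by a chain of CNOTs, and the per-qubit ⟨Z⟩ expectations are added to the classical log-probabilities before renormalizing. The repository itself builds its circuit with Cirq; `HybridNet` and the helper functions are hypothetical names.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def ry_product_state(angles: torch.Tensor) -> torch.Tensor:
    """State after RY(angle_k) on each qubit of |0...0>, shape (B, 2, ..., 2)."""
    # RY(a)|0> = [cos(a/2), sin(a/2)]; RY and CNOT keep amplitudes real.
    amps = torch.stack([torch.cos(angles / 2), torch.sin(angles / 2)], dim=-1)  # (B, n, 2)
    batch, n, _ = amps.shape
    state = amps[:, 0, :]
    for k in range(1, n):
        # Tensor product with qubit k's single-qubit state.
        state = state.unsqueeze(-1) * amps[:, k, :].reshape(batch, *([1] * k), 2)
    return state


def apply_cnot(state: torch.Tensor, control: int, target: int) -> torch.Tensor:
    """Flip `target` wherever `control` is 1 (qubit q lives on tensor dim q + 1)."""
    c0 = state.select(control + 1, 0)
    c1 = state.select(control + 1, 1)
    # select() removed the control dim, so later dims shift left by one.
    t_dim = target + 1 - (1 if target > control else 0)
    return torch.stack([c0, torch.flip(c1, [t_dim])], dim=control + 1)


def z_expectations(state: torch.Tensor) -> torch.Tensor:
    """<Z_k> = P(q_k = 0) - P(q_k = 1) for every qubit k, shape (B, n)."""
    probs = state ** 2
    batch, n = probs.shape[0], probs.dim() - 1
    out = []
    for k in range(n):
        marg = probs.transpose(1, k + 1).reshape(batch, 2, -1).sum(-1)
        out.append(marg[:, 0] - marg[:, 1])
    return torch.stack(out, dim=1)


class HybridNet(nn.Module):
    def __init__(self, classical: nn.Module, n_qubits: int = 10):
        super().__init__()
        self.classical = classical  # must return (B, n_qubits) log-probabilities
        self.n_qubits = n_qubits
        # Trainable angles: theta for even qubits, phi for odd qubits (assumed layout).
        self.theta = nn.Parameter(torch.rand((n_qubits + 1) // 2) * math.pi)
        self.phi = nn.Parameter(torch.rand(n_qubits // 2) * math.pi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        log_probs = self.classical(x)
        params = torch.stack([self.theta[k // 2] if k % 2 == 0 else self.phi[k // 2]
                              for k in range(self.n_qubits)])
        # Encode class probabilities as angles, shifted by the learned theta/phi.
        state = ry_product_state(log_probs.exp() * math.pi + params)
        for k in range(self.n_qubits - 1):  # CNOT chain on adjacent qubits
            state = apply_cnot(state, k, k + 1)
        z = z_expectations(state)  # each value in [-1, 1]
        # Assumed combination rule: add to classical log-probs, then renormalize.
        return F.log_softmax(log_probs + z, dim=1)


if __name__ == "__main__":
    # Hypothetical stand-in for the CNN, to keep the sketch self-contained.
    stand_in = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
    hnn = HybridNet(stand_in, n_qubits=10)
    out = hnn(torch.rand(3, 1, 28, 28))
    out.sum().backward()
    print(out.shape, hnn.theta.grad is not None)
```

Simulating in PyTorch rather than Cirq is purely a choice for this illustration: it keeps the whole pipeline on autograd, at the cost of memory that grows as 2^n with the number of qubits.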
The quantum circuit is implemented using the Cirq library. It consists of the following components:
- Rotation gates (`cirq.ry`) applied to each qubit, alternating between the angles `theta` and `phi`
- Entangling gates (`cirq.CNOT`) applied between pairs of adjacent qubits
The rotation gates introduce single-qubit operations that can manipulate the state of individual qubits, while the entangling gates create correlations between the states of different qubits.
The number of qubits in the quantum circuit is determined by the output dimension of the model, and the circuit is created dynamically from the learned parameters `theta` and `phi`.
The code implements compounded white-box targeted adversarial attacks using the TorchAttacks library. The available attack options include:
- FGSM + CW attack
- FGSM + PGD attack
- CW + PGD attack
- PGD + BIM attack
- FGSM + BIM attack
- CW + BIM attack
- FGSM + DeepFool attack
- PGD + DeepFool attack
- CW + DeepFool attack
- BIM + DeepFool attack
These attacks are designed to generate adversarial examples that fool the model into making incorrect predictions. The attacks are performed in a white-box setting, where the attacker has full knowledge of the model architecture and parameters. The targeted nature of the attacks means that the adversarial examples are crafted to cause the model to misclassify the input as a specific target class.
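The README does not spell out how two attacks are compounded. A minimal sketch, assuming they are applied in sequence with the second starting from the first one's output, is shown here for FGSM + PGD. To keep the mechanics visible, both attacks are hand-written in plain PyTorch rather than taken from TorchAttacks; `eps=0.3`, `alpha=0.01`, `steps=10`, the next-class target rule, and all function names are assumptions, and the model is assumed to return log-probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def targeted_fgsm(model, x, target, eps):
    """One signed-gradient step that lowers the loss w.r.t. `target` (targeted FGSM)."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.nll_loss(model(x), target)  # model returns log-probabilities
    (grad,) = torch.autograd.grad(loss, x)
    return (x - eps * grad.sign()).clamp(0, 1).detach()


def targeted_pgd(model, x_start, x_clean, target, eps, alpha, steps):
    """Iterated targeted steps, projected onto the L-inf ball of radius eps around x_clean."""
    x_adv = x_start.clone().detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.nll_loss(model(x_adv), target)
        (grad,) = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach() - alpha * grad.sign()
        # Project back into the eps-ball around the clean image, then into [0, 1].
        x_adv = torch.min(torch.max(x_adv, x_clean - eps), x_clean + eps).clamp(0, 1)
    return x_adv.detach()


def fgsm_then_pgd(model, x, labels, eps=0.3, alpha=0.01, steps=10, num_classes=10):
    """Compounded attack: the FGSM output is the starting point for PGD."""
    target = (labels + 1) % num_classes  # hypothetical target: the next digit class
    x_fgsm = targeted_fgsm(model, x, target, eps)
    return targeted_pgd(model, x_fgsm, x, target, eps, alpha, steps), target


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
    x, y = torch.rand(8, 1, 28, 28), torch.randint(0, 10, (8,))
    x_adv, target = fgsm_then_pgd(model, x, y)
    hit = (model(x_adv).argmax(1) == target).float().mean().item()
    print(f"max |delta| = {(x_adv - x).abs().max():.3f}, target hit rate = {hit:.2f}")
```

Projecting PGD around the clean image keeps the combined perturbation within one `eps` budget. Chaining two independent attack objects instead (for example `x2 = atk2(atk1(x, y), y)` with TorchAttacks instances) lets each attack project around its own input, so the perturbations can add up.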
The code includes an implementation of adversarial training as a defense mechanism against adversarial attacks. Adversarial training involves the following steps:
- Generate adversarial examples using the specified compounded attack on the clean training data.
- Combine the clean training data with the generated adversarial examples to create an augmented training dataset.
- Retrain the model on the augmented training dataset, allowing it to learn to correctly classify both clean and adversarial examples.
By exposing the model to adversarial examples during training, adversarial training helps improve the model's robustness and resilience against adversarial attacks.
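The three steps above can be sketched in PyTorch as two small functions: one builds the augmented dataset (each clean batch plus an adversarial copy that keeps the true labels), and one retrains on it. The `targeted_fgsm` stand-in, its `(y + 1) % 10` target rule, `eps=0.3`, and the optimizer settings are assumptions; in the repository the adversarial examples come from the selected compounded attack instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def targeted_fgsm(model, x, y, eps=0.3, num_classes=10):
    """Stand-in attack (assumed): one targeted FGSM step toward class (y + 1) % num_classes."""
    target = (y + 1) % num_classes
    x = x.clone().detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(F.nll_loss(model(x), target), x)
    return (x - eps * grad.sign()).clamp(0, 1).detach()


def build_augmented_dataset(model, loader, attack):
    """Steps 1-2: attack the clean data, then concatenate clean and adversarial examples."""
    model.eval()
    xs, ys = [], []
    for x, y in loader:
        xs += [x, attack(model, x, y)]
        ys += [y, y]  # adversarial copies keep their true labels
    return TensorDataset(torch.cat(xs), torch.cat(ys))


def retrain(model, dataset, epochs=5, lr=1e-3, batch_size=64):
    """Step 3: train on the augmented dataset (model must return log-probabilities)."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for x, y in DataLoader(dataset, batch_size=batch_size, shuffle=True):
            opt.zero_grad()
            loss = F.nll_loss(model(x), y)
            loss.backward()
            opt.step()
    return model


if __name__ == "__main__":
    torch.manual_seed(0)
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
    clean = TensorDataset(torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,)))
    augmented = build_augmented_dataset(model, DataLoader(clean, batch_size=32), targeted_fgsm)
    retrain(model, augmented)
    print(len(augmented))  # 128: 64 clean + 64 adversarial
```

Generating the adversarial set once, as the steps describe, is cheaper than regenerating it every batch, but the examples are crafted against the pre-training weights; an on-the-fly variant would call the attack inside the training loop instead.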
The code evaluates the model's performance on clean data, under adversarial attacks without defense, and under adversarial attacks with defense. It reports metrics such as loss, accuracy, precision, recall, F1-score, and ROC AUC score. It also provides visualizations of misclassified examples for each scenario.
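For reference, the classification metrics listed above can be computed from hard predictions with a few lines of standard-library Python. This sketch macro-averages over the classes that appear in either the labels or the predictions; the repository may compute these differently (for example with scikit-learn), and `classification_metrics` is a hypothetical name.

```python
from collections import Counter


def classification_metrics(y_true, y_pred):
    """Accuracy plus macro-averaged precision, recall, and F1 over the observed classes."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but wrongly
            fn[t] += 1  # true class t was missed

    def ratio(a, b):
        return a / b if b else 0.0

    classes = sorted(set(y_true) | set(y_pred))
    prec = [ratio(tp[c], tp[c] + fp[c]) for c in classes]
    rec = [ratio(tp[c], tp[c] + fn[c]) for c in classes]
    f1 = [ratio(2 * p * r, p + r) for p, r in zip(prec, rec)]
    k = len(classes)
    return {
        "accuracy": ratio(sum(tp.values()), len(y_true)),
        "precision": ratio(sum(prec), k),
        "recall": ratio(sum(rec), k),
        "f1": ratio(sum(f1), k),
    }


if __name__ == "__main__":
    m = classification_metrics([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0])
    print({name: round(v, 4) for name, v in m.items()})
    # {'accuracy': 0.6667, 'precision': 0.7222, 'recall': 0.6667, 'f1': 0.6556}
```

ROC AUC needs per-class scores rather than hard predictions; with scikit-learn installed, `sklearn.metrics.roc_auc_score(y_true, probs, multi_class="ovr")` is one option.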
The code summarizes the model's performance in a tabular format and displays example misclassifications for each scenario (clean, attack without defense, and attack with defense). The results show how effective the adversarial training defense is against the specified compounded white-box targeted attack (see the HNN results table).
This code makes use of the following libraries and frameworks:
- PyTorch
- TorchAttacks
- Cirq
- NumPy
- Tabulate
Disclaimer: This repository is intended for educational and research purposes.
Copyright 2024 Eric Yocam
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.