From d543e4c67de5cb243065861a795c3220b50cd43e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Pr=C3=A9vost?= Date: Fri, 2 Feb 2024 11:34:24 +0100 Subject: [PATCH] Manage keys in KME level: keys are now added for specific KMEs as pre-init, and fully initialized for a SAE pair when SAE do the request. Add directory watching for automatically add new QKD keys --- Cargo.toml | 3 + README.md | 2 +- src/lib.rs | 13 +- src/main.rs | 126 +++++-- src/qkd_manager/http_response_obj.rs | 9 +- src/qkd_manager/init_qkd_database.sql | 24 ++ src/qkd_manager/key_handler.rs | 461 ++++++++++++++++++++------ src/qkd_manager/mod.rs | 215 ++++++++---- src/qkd_manager/router.rs | 26 ++ src/routes/request_context.rs | 4 +- src/routes/sae/info.rs | 7 +- tests/common/mod.rs | 17 +- tests/data/key_status.json | 6 +- tests/data/sae_info_me.json | 3 +- tests/dec_keys.rs | 26 ++ tests/enc_keys.rs | 19 +- tests/key_status.rs | 2 +- 17 files changed, 739 insertions(+), 224 deletions(-) create mode 100644 src/qkd_manager/init_qkd_database.sql create mode 100644 src/qkd_manager/router.rs diff --git a/Cargo.toml b/Cargo.toml index 93c7458..a5b9460 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,9 @@ base64 = "0.21.5" async-trait = "0.1.75" rustls-pemfile = "2.0.0" log = "0.4.20" +simple_logger = "4.3.3" +notify = "6.1.1" +clap = { version = "4.4.18", features = ["derive"] } [dev-dependencies] assert_cmd = "2.0.12" diff --git a/README.md b/README.md index aec974f..2f48397 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Response example: ```json { "source_KME_ID": "1", - "target_KME_ID": "?? TODO", + "target_KME_ID": "2", "master_SAE_ID": "1", "slave_SAE_ID": "2", "key_size": 256, diff --git a/src/lib.rs b/src/lib.rs index 6ba1f2f..40498ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,9 @@ fn io_err(e: &str) -> io::Error { /// The size of the QKD key in bytes. This is equal to QKD_MIN_KEY_SIZE_BITS and QKD_MAX_KEY_SIZE_BITS, as we offer no flexibility in key size pub const QKD_KEY_SIZE_BITS: usize = 256; +/// The size of the QKD key in bytes +pub const QKD_KEY_SIZE_BYTES: usize = QKD_KEY_SIZE_BITS / 8; + /// The minimum size of the QKD key in bits, returned in HTTP responses pub const QKD_MIN_KEY_SIZE_BITS: usize = 256; @@ -50,8 +53,14 @@ pub const CLIENT_CERT_SERIAL_SIZE_BYTES: usize = 20; /// Location of the SQLite database file used by the KME to store keys, use ":memory:" for in-memory database pub const MEMORY_SQLITE_DB_PATH: &'static str = ":memory:"; -/// The ID of this KME, used to identify the KME in the database and across the network -pub const THIS_KME_ID: i64 = 1; // TODO: change +/// The type of SAE ID +pub type SaeId = i64; + +/// The type of KME ID +pub type KmeId = i64; + +/// Type for QKD encryption key: basically a byte array +pub type QkdEncKey = [u8; QKD_KEY_SIZE_BYTES]; #[cfg(test)] mod test { diff --git a/src/main.rs b/src/main.rs index b04725e..4ad34aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,21 @@ +use std::io::{BufReader, Read}; +use std::path::Path; +use std::sync::Arc; use log::error; -use qkd_kme_server::qkd_manager::{QkdKey, QkdManager}; +use clap::Parser; +use notify::{EventKind, RecursiveMode, Watcher}; +use notify::event::{AccessKind, AccessMode}; +use qkd_kme_server::qkd_manager::{PreInitQkdKeyWrapper, QkdManager}; use qkd_kme_server::routes::QKDKMERoutesV1; #[tokio::main] async fn main() { + simple_logger::SimpleLogger::new().init().unwrap(); + let args = Args::parse(); + + println!("{:?}", args); + + let server = qkd_kme_server::server::Server { listen_addr: "127.0.0.1:3000".to_string(), ca_client_cert_path: "certs/CA-zone1.crt".to_string(), @@ -11,56 +23,108 @@ async fn main() { server_key_path: "certs/kme1.key".to_string(), }; - let qkd_manager = QkdManager::new(qkd_kme_server::MEMORY_SQLITE_DB_PATH); + let qkd_manager = Arc::new(QkdManager::new(qkd_kme_server::MEMORY_SQLITE_DB_PATH, args.this_kme_id)); if qkd_manager.add_sae(1, - &[0x70, 0xf4, 0x4f, 0x56, 0x0c, 0x3f, 0x27, 0xd4, 0xb2, 0x11, 0xa4, 0x78, 0x13, 0xaf, 0xd0, 0x3c, 0x03, 0x81, 0x3b, 0x8e] + 1, + &Some([0x70, 0xf4, 0x4f, 0x56, 0x0c, 0x3f, 0x27, 0xd4, 0xb2, 0x11, 0xa4, 0x78, 0x13, 0xaf, 0xd0, 0x3c, 0x03, 0x81, 0x3b, 0x8e]) ).is_err() { error!("Error adding SAE to QKD manager"); return; } if qkd_manager.add_sae(2, - &[0x70, 0xf4, 0x4f, 0x56, 0x0c, 0x3f, 0x27, 0xd4, 0xb2, 0x11, 0xa4, 0x78, 0x13, 0xaf, 0xd0, 0x3c, 0x03, 0x81, 0x3b, 0x92] + 1, + &Some([0x70, 0xf4, 0x4f, 0x56, 0x0c, 0x3f, 0x27, 0xd4, 0xb2, 0x11, 0xa4, 0x78, 0x13, 0xaf, 0xd0, 0x3c, 0x03, 0x81, 0x3b, 0x92]) ).is_err() { error!("Error adding SAE to QKD manager"); return; } - let qkd_key_1 = match QkdKey::new( - 1, - 2, - b"this_is_secret_key_1_of_32_bytes", - ) { - Ok(qkd_key) => qkd_key, - Err(_) => { - error!("Error creating QKD key"); + let mut watchers: Vec = Vec::new(); + + for kme_dir in args.dirs_to_watch_other_kme_ids { + extract_all_keys_from_dir(&kme_dir.dir, kme_dir.kme_id, &qkd_manager); + let kme_id = kme_dir.kme_id; + let qkd_manager = Arc::clone(&qkd_manager); + watchers.push(match notify::recommended_watcher(move |res: Result| { + match res { + Ok(event) => { + if let EventKind::Access(AccessKind::Close(AccessMode::Write)) = event.kind { + extract_all_keys_from_file(&event.paths[0].to_str().unwrap(), kme_id, &qkd_manager); + } + } + Err(e) => { + println!("Watch error: {:?}", e); + return; + } + } + }) { + Ok(watcher) => watcher, + Err(e) => { + error!("Error creating watcher: {:?}", e); + return; + } + }); + if watchers.iter_mut().last().unwrap().watch(Path::new(&kme_dir.dir), RecursiveMode::NonRecursive).is_err() { + error!("Error watching directory: {:?}", kme_dir.dir); return; } - }; + } - if qkd_manager.add_qkd_key(qkd_key_1).is_err() { - error!("Error adding key to QKD manager"); + if server.run::(&qkd_manager).await.is_err() { + error!("Error running HTTP server"); return; } +} - let qkd_key_2 = match QkdKey::new( - 1, - 1, - b"this_is_secret_key_1_of_32_bytes", - ) { - Ok(qkd_key) => qkd_key, - Err(_) => { - error!("Error creating QKD key"); - return; +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + this_kme_id: i64, + #[arg(long("kme_id,dir_watch"))] + dirs_to_watch_other_kme_ids: Vec, +} + +#[derive(Debug, Clone)] +struct DirWatchOtherKmesArgs { + dir: String, + kme_id: i64, +} + +impl std::str::FromStr for DirWatchOtherKmesArgs { + type Err = std::io::Error; + fn from_str(s: &str) -> Result { + let kme_id_and_dir = s.split(",").collect::>(); + if kme_id_and_dir.len() != 2 { + return Err(std::io::Error::other("Invalid format")); } - }; + let kme_id = i64::from_str(kme_id_and_dir[0]).map_err(|_| std::io::Error::other("Invalid format"))?; + Ok(Self { + dir: kme_id_and_dir[1].to_string(), + kme_id, + }) + } +} - if qkd_manager.add_qkd_key(qkd_key_2).is_err() { - error!("Error adding key to QKD manager"); - return; +// TODO: move to QKD manager struct +fn extract_all_keys_from_file(file_path: &str, other_kme_id: i64, qkd_manager: &QkdManager) { + let file = std::fs::File::open(file_path).unwrap(); + let mut reader = BufReader::with_capacity(32, file); + let mut buffer = [0; 32]; + while let Ok(_) = reader.read_exact(&mut buffer) { + let qkd_key = PreInitQkdKeyWrapper::new( + other_kme_id, + &buffer, + ).unwrap(); + qkd_manager.add_pre_init_qkd_key(qkd_key).unwrap(); } +} - if server.run::(&qkd_manager).await.is_err() { - error!("Error running HTTP server"); - return; +fn extract_all_keys_from_dir(dir_path: &str, other_kme_id: i64, qkd_manager: &QkdManager) { + let paths = std::fs::read_dir(dir_path).unwrap(); + for path in paths { + let path = path.unwrap().path(); + if path.is_file() { + extract_all_keys_from_file(path.to_str().unwrap(), other_kme_id, qkd_manager); + } } } \ No newline at end of file diff --git a/src/qkd_manager/http_response_obj.rs b/src/qkd_manager/http_response_obj.rs index e66c83e..ee837a5 100644 --- a/src/qkd_manager/http_response_obj.rs +++ b/src/qkd_manager/http_response_obj.rs @@ -1,6 +1,7 @@ //! Objects serialized to HTTP response body use std::io; +use crate::{KmeId, SaeId}; /// Trait to be implemented by objects that can be serialized to JSON pub(crate) trait HttpResponseBody where Self: serde::Serialize { @@ -82,8 +83,9 @@ pub(crate) struct ResponseQkdKey { #[allow(non_snake_case)] pub(crate) struct ResponseQkdSAEInfo { /// SAE ID of the SAE - pub(crate) SAE_ID: i64, - // TODO: KME ID ? + pub(crate) SAE_ID: SaeId, + /// KME ID SAE belongs to + pub(crate) KME_ID: KmeId, } impl HttpResponseBody for ResponseQkdSAEInfo {} // can't use Derive macro because of the generic constraint @@ -144,8 +146,9 @@ mod test { fn test_serialize_response_qkd_sae_info() { let response_qkd_sae_info = super::ResponseQkdSAEInfo { SAE_ID: 1, + KME_ID: 1, }; let response_qkd_sae_info_json = response_qkd_sae_info.to_json().unwrap(); - assert_eq!(response_qkd_sae_info_json, "{\n \"SAE_ID\": 1\n}"); + assert_eq!(response_qkd_sae_info_json, "{\n \"SAE_ID\": 1,\n \"KME_ID\": 1\n}"); } } \ No newline at end of file diff --git a/src/qkd_manager/init_qkd_database.sql b/src/qkd_manager/init_qkd_database.sql new file mode 100644 index 0000000..b9c4c8c --- /dev/null +++ b/src/qkd_manager/init_qkd_database.sql @@ -0,0 +1,24 @@ +/* Uninitialized keys, available for SAEs in this KME and other_kme */ +CREATE TABLE IF NOT EXISTS uninit_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + key_uuid TEXT NOT NULL, + key BLOB NOT NULL, + other_kme_id INTEGER NOT NULL +); + +/* Keys assigned to SAEs */ +CREATE TABLE IF NOT EXISTS keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + key_uuid TEXT NOT NULL, + key BLOB NOT NULL, + origin_sae_id INTEGER NOT NULL, + target_sae_id INTEGER NOT NULL, + FOREIGN KEY (origin_sae_id) REFERENCES saes(sae_id), + FOREIGN KEY (target_sae_id) REFERENCES saes(sae_id) +); + +CREATE TABLE IF NOT EXISTS saes ( + sae_id INTEGER PRIMARY KEY NOT NULL, + sae_certificate_serial BLOB, + kme_id INTEGER NOT NULL +); \ No newline at end of file diff --git a/src/qkd_manager/key_handler.rs b/src/qkd_manager/key_handler.rs index 0a8b5d2..8ab2a39 100644 --- a/src/qkd_manager/key_handler.rs +++ b/src/qkd_manager/key_handler.rs @@ -4,8 +4,8 @@ use std::convert::identity; use std::io; use uuid::Bytes; use x509_parser::nom::AsBytes; -use crate::qkd_manager; -use crate::qkd_manager::{QkdKey, QkdManagerCommand, QkdManagerResponse, SAEInfo}; +use crate::{io_err, KmeId, qkd_manager, SaeId}; +use crate::qkd_manager::{KMEInfo, PreInitQkdKeyWrapper, QkdManagerCommand, QkdManagerResponse, SAEInfo}; use base64::{engine::general_purpose, Engine as _}; use log::{error, info, warn}; use crate::qkd_manager::http_response_obj::{ResponseQkdKey, ResponseQkdKeysList}; @@ -19,6 +19,8 @@ pub(super) struct KeyHandler { response_tx: crossbeam_channel::Sender, /// Connection to the sqlite database (in memory or on disk) sqlite_db: sqlite::Connection, + /// The ID of this KME + this_kme_id: KmeId, } impl KeyHandler { @@ -32,7 +34,9 @@ impl KeyHandler { /// A new key handler /// # Errors /// If the sqlite database cannot be opened or if the tables cannot be created - pub(super) fn new(sqlite_db_path: &str, command_rx: crossbeam_channel::Receiver, response_tx: crossbeam_channel::Sender) -> Result { + pub(super) fn new(sqlite_db_path: &str, command_rx: crossbeam_channel::Receiver, response_tx: crossbeam_channel::Sender, this_kme_id: i64) -> Result { + const DATABASE_INIT_REQ: &'static str = include_str!("init_qkd_database.sql"); + let key_handler = Self { command_rx, response_tx, @@ -40,20 +44,10 @@ impl KeyHandler { sqlite_db: sqlite::open(sqlite_db_path).map_err(|e| { io::Error::new(io::ErrorKind::NotConnected, format!("Error opening sqlite database: {:?}", e)) })?, + this_kme_id, }; // Create the tables if they do not exist - key_handler.sqlite_db.execute( - "CREATE TABLE IF NOT EXISTS keys ( - id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, - key_uuid TEXT NOT NULL, - key BLOB NOT NULL, - origin_sae_id INTEGER NOT NULL, - target_sae_id INTEGER NOT NULL, - FOREIGN KEY (origin_sae_id) REFERENCES saes(sae_id), - FOREIGN KEY (target_sae_id) REFERENCES saes(sae_id)); - CREATE TABLE IF NOT EXISTS saes ( - sae_id INTEGER PRIMARY KEY NOT NULL, - sae_certificate_serial BLOB NOT NULL);").map_err(|e| { + key_handler.sqlite_db.execute(DATABASE_INIT_REQ).map_err(|e| { io::Error::new(io::ErrorKind::InvalidInput, format!("Error creating sqlite tables: {:?}", e)) })?; Ok(key_handler) @@ -70,9 +64,9 @@ impl KeyHandler { Ok(cmd) => { match cmd { // Insert a key into the database, each time a QKD exchange occurs - QkdManagerCommand::AddKey(key) => { - info!("Adding key for SAE ID {}", key.target_sae_id); - if self.response_tx.send(self.add_key(key).unwrap_or_else(identity)).is_err() { + QkdManagerCommand::AddPreInitKey(key) => { + info!("Adding key for KME ID {} and {}", self.this_kme_id, key.other_kme_id); + if self.response_tx.send(self.add_preinit_qkd_key(key).unwrap_or_else(identity)).is_err() { error!("Error QKD manager sending response"); } }, @@ -91,9 +85,9 @@ impl KeyHandler { } }, // Add a new SAE ID to the database - QkdManagerCommand::AddSae(sae_id, sae_certificate_serial) => { + QkdManagerCommand::AddSae(sae_id, kme_id, sae_certificate_serial) => { info!("Adding SAE ID {}", sae_id); - if self.response_tx.send(self.add_sae(sae_id, &sae_certificate_serial).unwrap_or_else(identity)).is_err() { + if self.response_tx.send(self.add_sae(sae_id, kme_id, &sae_certificate_serial).unwrap_or_else(identity)).is_err() { error!("Error QKD manager sending response"); } }, @@ -105,15 +99,20 @@ impl KeyHandler { } }, QkdManagerCommand::GetSaeInfoFromCertificate(sae_certificate) => { - info!("Getting SAE ID from certificate"); - let sae_id = self.get_sae_id_from_certificate(&sae_certificate); - let response = match sae_id { - Some(sae_id) => QkdManagerResponse::SaeInfo(SAEInfo { - sae_id, - sae_certificate_serial: sae_certificate, + info!("Getting SAE info from certificate"); + let sae_info_response = self.get_sae_infos_from_certificate(&sae_certificate).unwrap_or_else(identity); + if self.response_tx.send(sae_info_response).is_err() { + error!("Error QKD manager sending response"); + } + }, + QkdManagerCommand::GetKmeIdFromSaeId(sae_id) => { + let kme_id = self.get_kme_id_from_sae_id(sae_id); + let response = match kme_id { + Some(kme_id) => QkdManagerResponse::KmeInfo(KMEInfo { + kme_id, }), None => { - warn!("SAE certificate not found in database"); + warn!("Get KME ID from SAE ID: SAE ID not found in database"); QkdManagerResponse::NotFound }, }; @@ -130,27 +129,50 @@ impl KeyHandler { } } - fn add_sae(&self, sae_id: i64, sae_certificate_serial: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]) -> Result { - const PREPARED_STATEMENT: &'static str = "INSERT INTO saes (sae_id, sae_certificate_serial) VALUES (?, ?);"; + /// Add a new SAE ID to the database + /// # Arguments + /// * `sae_id` - The SAE ID to add + /// * `kme_id` - The KME ID to associate with the SAE ID + /// * `sae_certificate_serial` - The SAE certificate serial number, None if the SAE isn't supposed to authenticate to this KME + fn add_sae(&self, sae_id: SaeId, kme_id: KmeId, sae_certificate_serial: &Option<[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]>) -> Result { + const PREPARED_STATEMENT_KNOWN_CERT: &'static str = "INSERT INTO saes (sae_id, kme_id, sae_certificate_serial) VALUES (?, ?, ?);"; + const PREPARED_STATEMENT_NO_CERT: &'static str = "INSERT INTO saes (sae_id, kme_id) VALUES (?, ?);"; + + let has_provided_certificate = sae_certificate_serial.is_some(); + let is_this_kme = kme_id == self.this_kme_id; + // Has given certificate and doesn't belong to this KME, or doesn't have certificate and belongs to this KME + if has_provided_certificate != is_this_kme { + return Err(QkdManagerResponse::InconsistentSaeData); + } - let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, PREPARED_STATEMENT); + let statement = match sae_certificate_serial { + Some(_) => PREPARED_STATEMENT_KNOWN_CERT, + None => PREPARED_STATEMENT_NO_CERT, + }; + + let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, statement); stmt.bind((1, sae_id)).map_err(|_| { QkdManagerResponse::Ko })?; - stmt.bind((2, sae_certificate_serial.as_bytes())).map_err(|_| { + stmt.bind((2, kme_id)).map_err(|_| { QkdManagerResponse::Ko })?; + if sae_certificate_serial.is_some() { + stmt.bind((3, sae_certificate_serial.unwrap().as_bytes())).map_err(|_| { + QkdManagerResponse::Ko + })?; + } stmt.next().map_err(|_| { QkdManagerResponse::Ko })?; Ok(QkdManagerResponse::Ok) } - fn add_key(&self, key: QkdKey) -> Result { - const PREPARED_STATEMENT: &'static str = "INSERT INTO keys (key_uuid, key, origin_sae_id, target_sae_id) VALUES (?, ?, ?, ?);"; + fn add_preinit_qkd_key(&self, pre_init_key: PreInitQkdKeyWrapper) -> Result { + const PREPARED_STATEMENT: &'static str = "INSERT INTO uninit_keys (key_uuid, key, other_kme_id) VALUES (?, ?, ?);"; let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, PREPARED_STATEMENT); - let uuid_bytes = Bytes::try_from(key.key_uuid).map_err(|_| { + let uuid_bytes = Bytes::try_from(pre_init_key.key_uuid).map_err(|_| { error!("Error converting UUID to bytes"); QkdManagerResponse::Ko })?; @@ -158,13 +180,10 @@ impl KeyHandler { stmt.bind((1, uuid_str.as_str())).map_err(|_| { QkdManagerResponse::Ko })?; - stmt.bind((2, key.key.as_bytes())).map_err(|_| { - QkdManagerResponse::Ko - })?; - stmt.bind((3, key.origin_sae_id)).map_err(|_| { + stmt.bind((2, pre_init_key.key.as_bytes())).map_err(|_| { QkdManagerResponse::Ko })?; - stmt.bind((4, key.target_sae_id)).map_err(|_| { + stmt.bind((3, pre_init_key.other_kme_id)).map_err(|_| { QkdManagerResponse::Ko })?; stmt.next().map_err(|_| { @@ -173,21 +192,19 @@ impl KeyHandler { Ok(QkdManagerResponse::Ok) } - fn get_sae_status(&self, origin_sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], target_sae_id: i64) -> Result { - const PREPARED_STATEMENT: &'static str = "SELECT COUNT(*) FROM keys WHERE target_sae_id = ? and origin_sae_id = ?;"; + fn get_sae_status(&self, origin_sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], target_sae_id: SaeId) -> Result { + const PREPARED_STATEMENT: &'static str = "SELECT COUNT(*) FROM uninit_keys WHERE other_kme_id = ?;"; + + let target_kme_id = self.get_kme_id_from_sae_id(target_sae_id).ok_or(QkdManagerResponse::NotFound)?; // Ensure the origin (master) SAE ID is valid, and get its SAE id let origin_sae_id = self.get_sae_id_from_certificate(origin_sae_certificate).ok_or(QkdManagerResponse::AuthenticationError)?; let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, PREPARED_STATEMENT); - stmt.bind((1, target_sae_id)).map_err(|_| { + stmt.bind((1, target_kme_id)).map_err(|_| { error!("Error binding target SAE ID"); QkdManagerResponse::Ko })?; - stmt.bind((2, origin_sae_id)).map_err(|_| { - error!("Error binding origin SAE ID"); - QkdManagerResponse::Ko - })?; stmt.next().map_err(|_| { error!("Error executing SQL statement"); QkdManagerResponse::Ko @@ -197,10 +214,12 @@ impl KeyHandler { QkdManagerResponse::Ko })?; + let source_kme_id = self.this_kme_id; // This KME + // Create key exchange status response object let response_qkd_key_status = qkd_manager::http_response_obj::ResponseQkdKeysStatus { - source_KME_ID: crate::THIS_KME_ID.to_string(), - target_KME_ID: "?? TODO".to_string(), + source_KME_ID: source_kme_id.to_string(), + target_KME_ID: target_kme_id.to_string(), master_SAE_ID: origin_sae_id.to_string(), slave_SAE_ID: target_sae_id.to_string(), key_size: crate::QKD_KEY_SIZE_BITS, @@ -215,19 +234,17 @@ impl KeyHandler { Ok(QkdManagerResponse::Status(response_qkd_key_status)) } - fn get_sae_keys(&self, origin_sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], target_sae_id: i64) -> Result { - const PREPARED_STATEMENT: &'static str = "SELECT key_uuid, key FROM keys WHERE target_sae_id = ? and origin_sae_id = ? LIMIT 1;"; + fn get_sae_keys(&self, origin_sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], target_sae_id: SaeId) -> Result { + const FETCH_PREINIT_KEY_PREPARED_STATEMENT: &'static str = "SELECT id, key_uuid, key, other_kme_id FROM uninit_keys WHERE other_kme_id = ? LIMIT 1;"; // Ensure the origin (master) SAE ID is valid, and get its SAE id let origin_sae_id = self.get_sae_id_from_certificate(origin_sae_certificate).ok_or(QkdManagerResponse::AuthenticationError)?; + let origin_kme_id = self.this_kme_id; + let target_kme_id = self.get_kme_id_from_sae_id(target_sae_id).ok_or(QkdManagerResponse::NotFound)?; - let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, PREPARED_STATEMENT); - stmt.bind((1, target_sae_id)).map_err(|_| { - error!("Error binding target SAE ID"); - QkdManagerResponse::Ko - })?; - stmt.bind((2, origin_sae_id)).map_err(|_| { - error!("Error binding origin SAE ID"); + let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, FETCH_PREINIT_KEY_PREPARED_STATEMENT); + stmt.bind((1, target_kme_id)).map_err(|_| { + error!("Error binding target KME ID"); QkdManagerResponse::Ko })?; let sql_execution_state = stmt.next().map_err(|_| { @@ -235,19 +252,55 @@ impl KeyHandler { QkdManagerResponse::Ko })?; - // /!\ We only want 1 key here, as multiple keys response isn't supported yet - if sql_execution_state != sqlite::State::Row { return Err(QkdManagerResponse::NotFound); // TODO: we could return an empty array instead } - let key_uuid: String = stmt.read::(0).map_err(|_| { + + let id = stmt.read::(0).map_err(|_| { error!("Error reading SQL statement result"); QkdManagerResponse::Ko })?; - let key: Vec = stmt.read::, usize>(1).map_err(|_| { + let key_uuid: String = stmt.read::(1).map_err(|_| { error!("Error reading SQL statement result"); QkdManagerResponse::Ko })?; + let key: Vec = stmt.read::, usize>(2).map_err(|_| { + error!("Error reading SQL statement result"); + QkdManagerResponse::Ko + })?; + if origin_kme_id != target_kme_id { + // send key to other KME TODO + } + + self.delete_pre_init_key_with_id(id).map_err(|_| { + error!("Error deleting pre-init key {}", id); + QkdManagerResponse::Ko + })?; + + info!("Saving key {} in init keys", key_uuid); + const INSERT_INIT_KEY_PREPARED_STATEMENT: &'static str = "INSERT INTO keys (key_uuid, key, origin_sae_id, target_sae_id) VALUES (?, ?, ?, ?);"; + + let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, INSERT_INIT_KEY_PREPARED_STATEMENT); + stmt.bind((1, key_uuid.as_str())).map_err(|_| { + error!("Error binding key UUID"); + QkdManagerResponse::Ko + })?; + stmt.bind((2, key.as_slice())).map_err(|_| { + error!("Error binding key"); + QkdManagerResponse::Ko + })?; + stmt.bind((3, origin_sae_id)).map_err(|_| { + error!("Error binding origin SAE ID"); + QkdManagerResponse::Ko + })?; + stmt.bind((4, target_sae_id)).map_err(|_| { + error!("Error binding target SAE ID"); + QkdManagerResponse::Ko + })?; + stmt.next().map_err(|_| { + error!("Error executing SQL statement"); + QkdManagerResponse::Ko + })?; // Encode the key in base64 let response_qkd_key = ResponseQkdKey { @@ -261,7 +314,32 @@ impl KeyHandler { })) } - fn get_sae_keys_with_ids(&self, current_sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], origin_sae_id: i64, keys_uuids: Vec) -> Result { + /// Delete a pre-init key from the pre-init keys database + /// Called when master SAE requested the key: it becomes an init key + /// So that the same key isn't requested again by a master SAE + /// # Arguments + /// * `key_id` - The ID of the pre init key to delete + /// # Returns + /// Ok if the key was deleted, an error otherwise + fn delete_pre_init_key_with_id(&self, key_id: i64) -> Result<(), io::Error> { + const PREPARED_STATEMENT: &'static str = "DELETE FROM uninit_keys WHERE id = ?;"; + + let mut stmt = match self.sqlite_db.prepare(PREPARED_STATEMENT) { + Ok(stmt) => stmt, + Err(_) => { + return Err(io_err("Error preparing SQL statement")); + } + }; + stmt.bind((1, key_id)).map_err(|_| { + io_err("Error binding key ID") + })?; + stmt.next().map_err(|_| { + io_err("Error executing SQL statement, maybe key ID not found in pre init keys database?") + })?; + Ok(()) + } + + fn get_sae_keys_with_ids(&self, current_sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], origin_sae_id: SaeId, keys_uuids: Vec) -> Result { const PREPARED_STATEMENT: &'static str = "SELECT key_uuid, key FROM keys WHERE target_sae_id = ? AND origin_sae_id = ? AND key_uuid = ? LIMIT 1;"; // Ensure the caller (slave) SAE ID is valid and authenticated, and get its SAE id @@ -319,7 +397,7 @@ impl KeyHandler { /// * `sae_certificate` - The client certificate serial number /// # Returns /// The SAE ID if the certificate serial number is found in the database, None otherwise - fn get_sae_id_from_certificate(&self, sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]) -> Option { + fn get_sae_id_from_certificate(&self, sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]) -> Option { const PREPARED_STATEMENT: &'static str = "SELECT sae_id FROM saes WHERE sae_certificate_serial = ? LIMIT 1;"; let mut stmt = match self.sqlite_db.prepare(PREPARED_STATEMENT) { Ok(stmt) => stmt, @@ -346,6 +424,73 @@ impl KeyHandler { }).ok()?; Some(sae_id) } + + /// Get the KME ID from associated SAE ID + /// # Arguments + /// * `sae_id` - The SAE ID + /// # Returns + /// The KME ID if the SAE ID is found in the database, None otherwise + fn get_kme_id_from_sae_id(&self, sae_id: SaeId) -> Option { + const PREPARED_STATEMENT: &'static str = "SELECT kme_id FROM saes WHERE sae_id = ? LIMIT 1;"; + let mut stmt = match self.sqlite_db.prepare(PREPARED_STATEMENT) { + Ok(stmt) => stmt, + Err(_) => { + error!("Error preparing SQL statement"); + return None; + } + }; + stmt.bind((1, sae_id)).map_err(|_| { + error!("Error binding SAE ID"); + () + }).ok()?; + let sql_execution_state = stmt.next().map_err(|_| { + error!("Error executing SQL statement"); + () + }).ok()?; + if sql_execution_state != sqlite::State::Row { + info!("SAE ID not found in database"); + return None; + } + let kme_id: i64 = stmt.read::(0).map_err(|_| { + error!("Error reading SQL statement result"); + () + }).ok()?; + Some(kme_id) + } + + /// Directly fetch SAE info from the certificate serial number, including the SAE ID and KME ID + /// # Arguments + /// * `sae_certificate` - The client SAE certificate serial number + /// # Returns + /// The SAE info, including KME ID, if the certificate serial number is found in the database, an error otherwise + fn get_sae_infos_from_certificate(&self, sae_certificate: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]) -> Result { + const PREPARED_STATEMENT: &'static str = "SELECT sae_id, kme_id FROM saes WHERE sae_certificate_serial = ? LIMIT 1;"; + let mut stmt = ensure_prepared_statement_ok!(self.sqlite_db, PREPARED_STATEMENT); + stmt.bind((1, sae_certificate.as_bytes())).map_err(|_| { + error!("Error binding SAE certificate serial"); + QkdManagerResponse::Ko + })?; + let sql_execution_state = stmt.next().map_err(|_| { + error!("Error executing SQL statement"); + QkdManagerResponse::Ko + })?; + if sql_execution_state != sqlite::State::Row { + return Err(QkdManagerResponse::NotFound); + } + let sae_id: SaeId = stmt.read::(0).map_err(|_| { + error!("Error reading SQL statement result"); + QkdManagerResponse::Ko + })?; + let kme_id: KmeId = stmt.read::(1).map_err(|_| { + error!("Error reading SQL statement result"); + QkdManagerResponse::Ko + })?; + Ok(QkdManagerResponse::SaeInfo(SAEInfo { + sae_id, + kme_id, + sae_certificate_serial: *sae_certificate, + })) + } } /// Check SQL statement preparation and return the statement @@ -373,10 +518,11 @@ mod tests { fn test_get_sae_id_from_certificate() { let (_, command_channel_rx) = crossbeam_channel::unbounded(); let (response_channel_tx, _) = crossbeam_channel::unbounded(); - let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx).unwrap(); + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); let sae_id = 1; + let kme_id = 1; let sae_certificate_serial = [0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; - key_handler.add_sae(sae_id, &sae_certificate_serial).unwrap(); + key_handler.add_sae(sae_id, kme_id, &Some(sae_certificate_serial)).unwrap(); assert_eq!(key_handler.get_sae_id_from_certificate(&sae_certificate_serial).unwrap(), sae_id); let fake_sae_certificate_serial = [1u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; @@ -384,27 +530,26 @@ mod tests { } #[test] - fn test_add_key() { + fn test_add_preinit_key() { let (_, command_channel_rx) = crossbeam_channel::unbounded(); let (response_channel_tx, _) = crossbeam_channel::unbounded(); - let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx).unwrap(); - let key = crate::qkd_manager::QkdKey { + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 1, key_uuid: *uuid::Uuid::from_bytes([0u8; 16]).as_bytes(), key: [0u8; crate::QKD_KEY_SIZE_BITS / 8], - origin_sae_id: 1, - target_sae_id: 2, }; - key_handler.add_key(key).unwrap(); + key_handler.add_preinit_qkd_key(key).unwrap(); } #[test] fn test_get_sae_status() { let (_, command_channel_rx) = crossbeam_channel::unbounded(); let (response_channel_tx, _) = crossbeam_channel::unbounded(); - let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx).unwrap(); + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); let sae_id = 1; let sae_certificate_serial = [0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; - key_handler.add_sae(sae_id, &sae_certificate_serial).unwrap(); + key_handler.add_sae(sae_id, 1, &Some(sae_certificate_serial)).unwrap(); let qkd_manager_response = key_handler.get_sae_status(&sae_certificate_serial, sae_id).unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::Status(_))); let response_status = match qkd_manager_response { @@ -413,17 +558,20 @@ mod tests { panic!("Unexpected response"); } }; - assert_eq!(response_status.to_json().unwrap(), "{\n \"source_KME_ID\": \"1\",\n \"target_KME_ID\": \"?? TODO\",\n \"master_SAE_ID\": \"1\",\n \"slave_SAE_ID\": \"1\",\n \"key_size\": 256,\n \"stored_key_count\": 0,\n \"max_key_count\": 10,\n \"max_key_per_request\": 1,\n \"max_key_size\": 256,\n \"min_key_size\": 256,\n \"max_SAE_ID_count\": 0\n}"); + assert_eq!(response_status.to_json().unwrap(), "{\n \"source_KME_ID\": \"1\",\n \"target_KME_ID\": \"1\",\n \"master_SAE_ID\": \"1\",\n \"slave_SAE_ID\": \"1\",\n \"key_size\": 256,\n \"stored_key_count\": 0,\n \"max_key_count\": 10,\n \"max_key_per_request\": 1,\n \"max_key_size\": 256,\n \"min_key_size\": 256,\n \"max_SAE_ID_count\": 0\n}"); - // add key for another SAE id - let key = crate::qkd_manager::QkdKey { + // add key for another KME id + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 2, key_uuid: *uuid::Uuid::from_bytes([0u8; 16]).as_bytes(), key: [0u8; crate::QKD_KEY_SIZE_BITS / 8], - origin_sae_id: 1, - target_sae_id: 3, }; - key_handler.add_key(key).unwrap(); + key_handler.add_preinit_qkd_key(key).unwrap(); + let qkd_manager_response = key_handler.get_sae_status(&sae_certificate_serial, 2); + assert!(matches!(qkd_manager_response, Err(QkdManagerResponse::NotFound))); + + key_handler.add_sae(2, 1, &Some([1u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); let qkd_manager_response = key_handler.get_sae_status(&sae_certificate_serial, 2).unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::Status(_))); let response_status = match qkd_manager_response { @@ -432,16 +580,15 @@ mod tests { panic!("Unexpected response"); } }; - assert_eq!(response_status.to_json().unwrap(), "{\n \"source_KME_ID\": \"1\",\n \"target_KME_ID\": \"?? TODO\",\n \"master_SAE_ID\": \"1\",\n \"slave_SAE_ID\": \"2\",\n \"key_size\": 256,\n \"stored_key_count\": 0,\n \"max_key_count\": 10,\n \"max_key_per_request\": 1,\n \"max_key_size\": 256,\n \"min_key_size\": 256,\n \"max_SAE_ID_count\": 0\n}"); + assert_eq!(response_status.to_json().unwrap(), "{\n \"source_KME_ID\": \"1\",\n \"target_KME_ID\": \"1\",\n \"master_SAE_ID\": \"1\",\n \"slave_SAE_ID\": \"2\",\n \"key_size\": 256,\n \"stored_key_count\": 0,\n \"max_key_count\": 10,\n \"max_key_per_request\": 1,\n \"max_key_size\": 256,\n \"min_key_size\": 256,\n \"max_SAE_ID_count\": 0\n}"); // add key - let key = crate::qkd_manager::QkdKey { + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 1, key_uuid: *uuid::Uuid::from_bytes([0u8; 16]).as_bytes(), key: [0u8; crate::QKD_KEY_SIZE_BITS / 8], - origin_sae_id: 1, - target_sae_id: 2, }; - key_handler.add_key(key).unwrap(); + key_handler.add_preinit_qkd_key(key).unwrap(); let qkd_manager_response = key_handler.get_sae_status(&sae_certificate_serial, 2).unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::Status(_))); let response_status = match qkd_manager_response { @@ -450,28 +597,41 @@ mod tests { panic!("Unexpected response"); } }; - assert_eq!(response_status.to_json().unwrap(), "{\n \"source_KME_ID\": \"1\",\n \"target_KME_ID\": \"?? TODO\",\n \"master_SAE_ID\": \"1\",\n \"slave_SAE_ID\": \"2\",\n \"key_size\": 256,\n \"stored_key_count\": 1,\n \"max_key_count\": 10,\n \"max_key_per_request\": 1,\n \"max_key_size\": 256,\n \"min_key_size\": 256,\n \"max_SAE_ID_count\": 0\n}"); + assert_eq!(response_status.to_json().unwrap(), "{\n \"source_KME_ID\": \"1\",\n \"target_KME_ID\": \"1\",\n \"master_SAE_ID\": \"1\",\n \"slave_SAE_ID\": \"2\",\n \"key_size\": 256,\n \"stored_key_count\": 1,\n \"max_key_count\": 10,\n \"max_key_per_request\": 1,\n \"max_key_size\": 256,\n \"min_key_size\": 256,\n \"max_SAE_ID_count\": 0\n}"); } #[test] fn test_get_sae_keys() { let (_, command_channel_rx) = crossbeam_channel::unbounded(); let (response_channel_tx, _) = crossbeam_channel::unbounded(); - let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx).unwrap(); + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); let sae_id = 1; + let kme_id = 1; let sae_certificate_serial = [0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; - key_handler.add_sae(sae_id, &sae_certificate_serial).unwrap(); + key_handler.add_sae(sae_id, kme_id, &Some(sae_certificate_serial)).unwrap(); let qkd_manager_response = key_handler.get_sae_keys(&sae_certificate_serial, sae_id); assert!(matches!(qkd_manager_response, Err(QkdManagerResponse::NotFound))); // add key - let key = crate::qkd_manager::QkdKey { + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 1, key_uuid: *uuid::Uuid::from_bytes([0u8; 16]).as_bytes(), key: [0u8; crate::QKD_KEY_SIZE_BITS / 8], - origin_sae_id: 1, - target_sae_id: 2, }; - key_handler.add_key(key).unwrap(); + key_handler.add_preinit_qkd_key(key).unwrap(); + + // add key + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 1, + key_uuid: *uuid::Uuid::from_bytes([1u8; 16]).as_bytes(), + key: [1u8; crate::QKD_KEY_SIZE_BITS / 8], + }; + key_handler.add_preinit_qkd_key(key).unwrap(); + + let qkd_manager_response = key_handler.get_sae_keys(&sae_certificate_serial, 2); + assert!(matches!(qkd_manager_response, Err(QkdManagerResponse::NotFound))); + + key_handler.add_sae(2, kme_id, &Some([1u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); let qkd_manager_response = key_handler.get_sae_keys(&sae_certificate_serial, 2).unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::Keys(_))); let response_keys = match qkd_manager_response { @@ -484,32 +644,53 @@ mod tests { assert_eq!(response_keys.keys[0].key_ID, "00000000-0000-0000-0000-000000000000"); assert_eq!(response_keys.keys[0].key, "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="); assert_eq!(response_keys.to_json().unwrap(), "{\n \"keys\": [\n {\n \"key_ID\": \"00000000-0000-0000-0000-000000000000\",\n \"key\": \"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"\n }\n ]\n}"); + + + // Same request + let qkd_manager_response = key_handler.get_sae_keys(&sae_certificate_serial, 2).unwrap(); + assert!(matches!(qkd_manager_response, QkdManagerResponse::Keys(_))); + let response_keys = match qkd_manager_response { + QkdManagerResponse::Keys(keys) => keys, + _ => { + panic!("Unexpected response"); + } + }; + assert_eq!(response_keys.keys.len(), 1); + // Not the same key + assert_eq!(response_keys.keys[0].key_ID, "01010101-0101-0101-0101-010101010101"); + assert_eq!(response_keys.keys[0].key, "AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE="); + assert_eq!(response_keys.to_json().unwrap(), "{\n \"keys\": [\n {\n \"key_ID\": \"01010101-0101-0101-0101-010101010101\",\n \"key\": \"AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE=\"\n }\n ]\n}"); } #[test] fn test_get_sae_keys_with_ids() { let (_, command_channel_rx) = crossbeam_channel::unbounded(); let (response_channel_tx, _) = crossbeam_channel::unbounded(); - let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx).unwrap(); + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); let sae_id = 1; + let kme_id = 1; let sae_1_certificate_serial = [0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; let sae_2_certificate_serial = [1u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; - key_handler.add_sae(sae_id, &sae_1_certificate_serial).unwrap(); - key_handler.add_sae(2, &sae_2_certificate_serial).unwrap(); + key_handler.add_sae(sae_id, kme_id, &Some(sae_1_certificate_serial)).unwrap(); + key_handler.add_sae(2, kme_id, &Some(sae_2_certificate_serial)).unwrap(); let qkd_manager_response = key_handler.get_sae_keys_with_ids(&sae_1_certificate_serial, sae_id, vec!["00000000-0000-0000-0000-000000000000".to_string()]); assert!(matches!(qkd_manager_response, Err(QkdManagerResponse::NotFound))); // add key - let key = crate::qkd_manager::QkdKey { + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 1, key_uuid: *uuid::Uuid::from_bytes([0u8; 16]).as_bytes(), key: [0u8; crate::QKD_KEY_SIZE_BITS / 8], - origin_sae_id: 1, - target_sae_id: 2, }; - key_handler.add_key(key).unwrap(); + key_handler.add_preinit_qkd_key(key).unwrap(); + // SAE1 has to pre fetch the key first + let qkd_manager_response = key_handler.get_sae_keys_with_ids(&sae_2_certificate_serial, 1, vec!["00000000-0000-0000-0000-000000000000".to_string()]); + assert!(matches!(qkd_manager_response, Err(QkdManagerResponse::NotFound))); + assert!(matches!(key_handler.get_sae_keys(&sae_1_certificate_serial, 2).unwrap(), QkdManagerResponse::Keys(_))); let qkd_manager_response = key_handler.get_sae_keys_with_ids(&sae_2_certificate_serial, 1, vec!["00000000-0000-0000-0000-000000000000".to_string()]).unwrap(); + assert!(matches!(qkd_manager_response, QkdManagerResponse::Keys(_))); let response_keys = match qkd_manager_response { QkdManagerResponse::Keys(keys) => keys, @@ -526,17 +707,69 @@ mod tests { assert!(matches!(qkd_manager_response, Err(QkdManagerResponse::NotFound))); } + #[test] + fn test_get_kme_id_from_sae() { + let (_, command_channel_rx) = crossbeam_channel::unbounded(); + let (response_channel_tx, _) = crossbeam_channel::unbounded(); + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); + let sae_id = 1; + let kme_id = 1; + let sae_1_certificate_serial = [0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; + key_handler.add_sae(sae_id, kme_id, &Some(sae_1_certificate_serial)).unwrap(); + let kme_id = key_handler.get_kme_id_from_sae_id(sae_id).unwrap(); + assert_eq!(kme_id, 1); + let kme_id = key_handler.get_kme_id_from_sae_id(2); + assert!(matches!(kme_id, None)); + } + + #[test] + fn test_get_sae_infos_from_certificate() { + let (_, command_channel_rx) = crossbeam_channel::unbounded(); + let (response_channel_tx, _) = crossbeam_channel::unbounded(); + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); + let sae_id = 1; + let kme_id = 1; + + let sae_info = key_handler.get_sae_infos_from_certificate(&[0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]); + assert!(matches!(sae_info, Err(QkdManagerResponse::NotFound))); + + key_handler.add_sae(sae_id, kme_id, &Some([0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + let sae_info = key_handler.get_sae_infos_from_certificate(&[0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); + assert!(matches!(sae_info, QkdManagerResponse::SaeInfo(_))); + assert_eq!(sae_info, QkdManagerResponse::SaeInfo(super::SAEInfo { + sae_id, + kme_id, + sae_certificate_serial: [0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], + })); + } + + #[test] + fn test_delete_pre_init_key_with_id() { + let (_, command_channel_rx) = crossbeam_channel::unbounded(); + let (response_channel_tx, _) = crossbeam_channel::unbounded(); + let key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 1, + key_uuid: *uuid::Uuid::from_bytes([0u8; 16]).as_bytes(), + key: [0u8; crate::QKD_KEY_SIZE_BITS / 8], + }; + key_handler.add_preinit_qkd_key(key).unwrap(); + let key_id = 1; // As it's the first key, we can assume it's the ID + key_handler.delete_pre_init_key_with_id(key_id).unwrap(); + } + #[test] fn test_run() { let (command_tx, command_channel_rx) = crossbeam_channel::unbounded(); let (response_channel_tx, response_rx) = crossbeam_channel::unbounded(); - let mut key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx).unwrap(); + let mut key_handler = super::KeyHandler::new(":memory:", command_channel_rx, response_channel_tx, 1).unwrap(); let sae_id = 1; + let kme_id = 1; let sae_certificate_serial = [0u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]; let _ = thread::spawn(move || { key_handler.run(); }); - command_tx.send(super::QkdManagerCommand::AddSae(sae_id, sae_certificate_serial)).unwrap(); + command_tx.send(super::QkdManagerCommand::AddSae(sae_id, kme_id, Some(sae_certificate_serial))).unwrap(); let qkd_manager_response = response_rx.recv().unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::Ok)); @@ -545,13 +778,12 @@ mod tests { assert!(matches!(qkd_manager_response, QkdManagerResponse::NotFound)); // add key - let key = crate::qkd_manager::QkdKey { + let key = crate::qkd_manager::PreInitQkdKeyWrapper { + other_kme_id: 1, key_uuid: *uuid::Uuid::from_bytes([0u8; 16]).as_bytes(), key: [0u8; crate::QKD_KEY_SIZE_BITS / 8], - origin_sae_id: 1, - target_sae_id: 2, }; - command_tx.send(super::QkdManagerCommand::AddKey(key)).unwrap(); + command_tx.send(super::QkdManagerCommand::AddPreInitKey(key)).unwrap(); let qkd_manager_response = response_rx.recv().unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::Ok)); @@ -559,6 +791,14 @@ mod tests { let qkd_manager_response = response_rx.recv().unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::NotFound)); + command_tx.send(super::QkdManagerCommand::GetStatus(sae_certificate_serial, 2)).unwrap(); + let qkd_manager_response = response_rx.recv().unwrap(); + assert!(matches!(qkd_manager_response, QkdManagerResponse::NotFound)); + + command_tx.send(super::QkdManagerCommand::AddSae(2, kme_id, Some([1u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]))).unwrap(); + let qkd_manager_response = response_rx.recv().unwrap(); + assert!(matches!(qkd_manager_response, QkdManagerResponse::Ok)); + command_tx.send(super::QkdManagerCommand::GetStatus(sae_certificate_serial, 2)).unwrap(); let qkd_manager_response = response_rx.recv().unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::Status(_))); @@ -568,10 +808,21 @@ mod tests { assert!(matches!(qkd_manager_response, QkdManagerResponse::SaeInfo(_))); assert_eq!(qkd_manager_response, QkdManagerResponse::SaeInfo(super::SAEInfo { sae_id, + kme_id, sae_certificate_serial, })); - command_tx.send(super::QkdManagerCommand::GetSaeInfoFromCertificate([1u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + command_tx.send(super::QkdManagerCommand::GetSaeInfoFromCertificate([2u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + let qkd_manager_response = response_rx.recv().unwrap(); + assert!(matches!(qkd_manager_response, QkdManagerResponse::NotFound)); + + command_tx.send(super::QkdManagerCommand::GetKmeIdFromSaeId(sae_id)).unwrap(); + let qkd_manager_response = response_rx.recv().unwrap(); + assert!(matches!(qkd_manager_response, QkdManagerResponse::KmeInfo(_))); + assert_eq!(qkd_manager_response, QkdManagerResponse::KmeInfo(super::KMEInfo { + kme_id, + })); + command_tx.send(super::QkdManagerCommand::GetKmeIdFromSaeId(3)).unwrap(); let qkd_manager_response = response_rx.recv().unwrap(); assert!(matches!(qkd_manager_response, QkdManagerResponse::NotFound)); } diff --git a/src/qkd_manager/mod.rs b/src/qkd_manager/mod.rs index 21857d0..8b0341a 100644 --- a/src/qkd_manager/mod.rs +++ b/src/qkd_manager/mod.rs @@ -3,12 +3,14 @@ mod key_handler; pub(crate) mod http_response_obj; pub(crate) mod http_request_obj; +mod router; use std::{io, thread}; use log::error; use sha1::Digest; use crate::qkd_manager::http_response_obj::ResponseQkdKeysList; use crate::qkd_manager::QkdManagerResponse::TransmissionError; +use crate::{KmeId, QkdEncKey, SaeId}; /// QKD manager interface, can be cloned for instance in each request handler task #[derive(Clone)] @@ -27,7 +29,7 @@ impl QkdManager { /// A new QKD manager handler /// # Notes /// This function spawns a new thread to handle the QKD manager - pub fn new(sqlite_db_path: &str) -> Self { + pub fn new(sqlite_db_path: &str, this_kme_id: i64) -> Self { // crossbeam_channel allows cloning the sender and receiver let (command_tx, command_rx) = crossbeam_channel::unbounded::(); let (response_tx, response_rx) = crossbeam_channel::unbounded::(); @@ -35,7 +37,7 @@ impl QkdManager { // Spawn a new thread to handle the QKD manager thread::spawn(move || { - let mut key_handler = match key_handler::KeyHandler::new(&sqlite_db_path, command_rx, response_tx) { + let mut key_handler = match key_handler::KeyHandler::new(&sqlite_db_path, command_rx, response_tx, this_kme_id) { Ok(handler) => handler, Err(_) => { error!("Error creating key handler"); @@ -55,8 +57,8 @@ impl QkdManager { /// * `key` - The QKD key to add (key + origin SAE ID + target SAE ID) /// # Returns /// Ok if the key was added successfully, an error otherwise - pub fn add_qkd_key(&self, key: QkdKey) -> Result { - self.command_tx.send(QkdManagerCommand::AddKey(key)).map_err(|_| { + pub fn add_pre_init_qkd_key(&self, key: PreInitQkdKeyWrapper) -> Result { + self.command_tx.send(QkdManagerCommand::AddPreInitKey(key)).map_err(|_| { TransmissionError })?; match self.response_rx.recv().map_err(|_| { @@ -116,8 +118,8 @@ impl QkdManager { /// Ok if the SAE was added successfully, an error otherwise /// # Notes /// It will fail if the SAE ID is already in the database - pub fn add_sae(&self, sae_id: i64, sae_certificate_serial: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]) -> Result { - self.command_tx.send(QkdManagerCommand::AddSae(sae_id, *sae_certificate_serial)).map_err(|_| { + pub fn add_sae(&self, sae_id: SaeId, kme_id: KmeId, sae_certificate_serial: &Option<[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]>) -> Result { + self.command_tx.send(QkdManagerCommand::AddSae(sae_id, kme_id, *sae_certificate_serial)).map_err(|_| { TransmissionError })?; match self.response_rx.recv().map_err(|_| { @@ -134,7 +136,7 @@ impl QkdManager { /// * `target_sae_id` - The ID of the target (slave) SAE, to which master SAE wants to communicate /// # Returns /// The status of the key exchange if the key exchange was found and the caller is authorized to retrieve it, an error otherwise - pub fn get_qkd_key_status(&self, origin_sae_certificate_serial: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], target_sae_id: i64) -> Result { + pub fn get_qkd_key_status(&self, origin_sae_certificate_serial: &[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], target_sae_id: SaeId) -> Result { self.command_tx.send(QkdManagerCommand::GetStatus(*origin_sae_certificate_serial, target_sae_id)).map_err(|_| { TransmissionError })?; @@ -162,27 +164,39 @@ impl QkdManager { qkd_response_error => Err(qkd_response_error), // Likely not found } } + + /// GET the KME ID from belonging SAE ID + /// # Arguments + /// * `sae_id` - The ID of the SAE + /// # Returns + /// The KME ID if the SAE was found, None otherwise + pub fn get_kme_id_from_sae_id(&self, sae_id: SaeId) -> Option { + self.command_tx.send(QkdManagerCommand::GetKmeIdFromSaeId(sae_id)).map_err(|_| { + TransmissionError + }).ok()?; + match self.response_rx.recv().map_err(|_| { + TransmissionError + }).ok()? { + QkdManagerResponse::KmeInfo(kme_info) => Some(kme_info), // KmeInfo is the QkdManagerResponse expected here + _ => None, + } + } } -/// A QKD key, with its origin and target SAE IDs -/// # Note -/// This is not the key serialized in HTTP response, which is [ResponseQkdKey](http_response_obj::ResponseQkdKey) +/// A Pre-init QKD key, with its origin and target KME IDs +/// This is the key supposed to be added to the database +/// A master SAE, belonging to the KME, will then request the key +/// Its status will become "initialized", meaning that instead of being associated to a KME, it will be associated to a pair of SAEs +/// The slave SAE would then request the key to this KME or another KME, depending on the KME the slave SAE belong to #[derive(Debug, Clone)] -pub struct QkdKey { - /// The ID of the origin (master) SAE - pub(crate) origin_sae_id: i64, - /// The ID of the target (slave) SAE, to which master SAE wants to communicate - pub(crate) target_sae_id: i64, - /// The QKD key, of size [QKD_KEY_SIZE_BITS](crate::QKD_KEY_SIZE_BITS) bits - pub(crate) key: [u8; Self::QKD_KEY_SIZE_BYTES], - /// The UUID of the key, generated from the key hash (sha1) +pub struct PreInitQkdKeyWrapper { + pub(crate) other_kme_id: KmeId, + pub(crate) key: QkdEncKey, pub(crate) key_uuid: uuid::Bytes, } -impl QkdKey { - const QKD_KEY_SIZE_BYTES: usize = crate::QKD_KEY_SIZE_BITS / 8; - - /// Create a new QKD key for a communication between SAEs +impl PreInitQkdKeyWrapper { + /// Create a new pre init QKD key for a future communication between SAEs /// # Arguments /// * `origin_sae_id` - The ID of the origin (master) SAE /// * `target_sae_id` - The ID of the target (slave) SAE, to which master SAE wants to communicate @@ -191,20 +205,21 @@ impl QkdKey { /// A new QKD key /// # Errors /// If key UUID generation fails, which should never happen - pub fn new(origin_sae_id: i64, target_sae_id: i64, key: &[u8; Self::QKD_KEY_SIZE_BYTES]) -> Result { + pub fn new(other_kme_id: KmeId, key: &QkdEncKey) -> Result { Ok(Self { - origin_sae_id, - target_sae_id, + other_kme_id, key: *key, key_uuid: Self::generate_key_uuid(key)?, }) } /// Generate a UUID from a key sha1 hash - fn generate_key_uuid(key: &[u8; Self::QKD_KEY_SIZE_BYTES]) -> Result { + fn generate_key_uuid(key: &QkdEncKey) -> Result { + const UUID_SIZE_BYTES: usize = 16; + let mut hasher = sha1::Sha1::new(); hasher.update(key); - let hash_sub_bytes = uuid::Bytes::try_from(&hasher.finalize()[..16]).map_err(|_| { + let hash_sub_bytes = uuid::Bytes::try_from(&hasher.finalize()[..UUID_SIZE_BYTES]).map_err(|_| { io::Error::new(io::ErrorKind::Other, "Error creating key UUID from key hash") })?; Ok(uuid::Builder::from_sha1_bytes(hash_sub_bytes).as_uuid().to_bytes_le()) @@ -222,34 +237,44 @@ impl QkdKey { #[derive(Debug, Clone, PartialEq)] pub struct SAEInfo { /// The ID of the SAE - pub(crate) sae_id: i64, + pub(crate) sae_id: SaeId, + /// The ID of the KME the SAE belongs to + pub(crate) kme_id: KmeId, /// The serial number of the client certificate identifying the SAE pub(crate) sae_certificate_serial: [u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], } +/// Describes information about a KME +#[derive(Debug, Clone, PartialEq)] +pub struct KMEInfo { + pub(crate) kme_id: KmeId, +} + /// All possible commands to the QKD manager /// # Note /// For QKD manager internal usage, interface should be managed from [QkdManager](QkdManager) implementation functions enum QkdManagerCommand { - /// Add a new QKD key to the database - AddKey(QkdKey), + /// Add a new pre-init QKD key to the database: it will be available for SAEs to request it if there are connected to the right KMEs + AddPreInitKey(PreInitQkdKeyWrapper), /// Get a QKD key from the database (shall be called by the master SAE) - GetKeys([u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], i64), // origin certificate + target id + GetKeys([u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], SaeId), // origin certificate + target id /// Get a list of QKD keys from the database (shall be called by the slave SAE) - GetKeysWithIds([u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], i64, Vec), // origin certificate + target id + GetKeysWithIds([u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], SaeId, Vec), // origin certificate + target id /// Get the status of a key exchange between two SAEs (shall be called by the master SAE) - GetStatus([u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], i64), // origin certificate + target id + GetStatus([u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES], SaeId), // origin certificate + target id /// Add a new SAE to the database (shall be called before SAEs start requesting KME) - AddSae(i64, [u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]), // target id + target certificate + AddSae(SaeId, KmeId, Option<[u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]>), // target id + KME id + target certificate /// Get information about a SAE from its client auth certificate GetSaeInfoFromCertificate([u8; crate::CLIENT_CERT_SERIAL_SIZE_BYTES]), // caller's certificate + /// Returns the KME ID from belonging SAE ID + GetKmeIdFromSaeId(SaeId), // SAE id } /// All possible responses from the QKD manager #[allow(private_interfaces)] #[derive(Debug, PartialEq)] pub enum QkdManagerResponse { - /// The operation was successful, ne more information is provided (e.g. after adding a key or a SAE into the database) + /// The operation was successful, no more information is provided (e.g. after adding a key or a SAE into the database) Ok, /// The operation was not successful, the reason is unknown Ko, @@ -257,6 +282,8 @@ pub enum QkdManagerResponse { NotFound, /// Error during transmission between the QKD manager and the key handler, should never happen TransmissionError, + /// The operation was not successful, the provided SAE data is inconsistent (like an authentication key if the SAE doesn't belong to the KME) + InconsistentSaeData, /// Caller authentication error (likely the provided client certificate serial is not in the database or not authorized to retrieve specified resources) AuthenticationError, /// The operation was successful, the requested key(s) are returned @@ -265,6 +292,8 @@ pub enum QkdManagerResponse { Status(http_response_obj::ResponseQkdKeysStatus), /// The operation was successful, the requested SAE information is returned (for example if GetSaeInfoFromCertificate is called) SaeInfo(SAEInfo), + /// The operation was successful, the requested KME information is returned (for example if GetKmeIdFromSaeId is called) + KmeInfo(KMEInfo), } @@ -275,9 +304,9 @@ mod test { #[test] fn test_add_qkd_key() { const SQLITE_DB_PATH: &'static str = ":memory:"; - let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH); - let key = super::QkdKey::new(1, 2, &[0; super::QkdKey::QKD_KEY_SIZE_BYTES]).unwrap(); - let response = qkd_manager.add_qkd_key(key); + let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH, 1); + let key = super::PreInitQkdKeyWrapper::new(1, &[0; crate::QKD_KEY_SIZE_BYTES]).unwrap(); + let response = qkd_manager.add_pre_init_qkd_key(key); assert!(response.is_ok()); assert_eq!(response.unwrap(), super::QkdManagerResponse::Ok); } @@ -285,59 +314,104 @@ mod test { #[test] fn test_add_sae() { const SQLITE_DB_PATH: &'static str = ":memory:"; - let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH); - let response = qkd_manager.add_sae(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); + let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH, 1); + let response = qkd_manager.add_sae(1, 1, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])); assert!(response.is_ok()); assert_eq!(response.unwrap(), super::QkdManagerResponse::Ok); // Duplicate SAE ID - let response = qkd_manager.add_sae(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); + let response = qkd_manager.add_sae(1, 1, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])); assert!(response.is_err()); assert_eq!(response.unwrap_err(), super::QkdManagerResponse::Ko); + + // Add SAE with key if it doesn't belong to KME1 + let response = qkd_manager.add_sae(2, 2, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])); + assert!(response.is_err()); + assert_eq!(response.unwrap_err(), super::QkdManagerResponse::InconsistentSaeData); + + // Add SAE without key if it doesn't belong to KME1 + let response = qkd_manager.add_sae(2, 2, &None); + assert!(response.is_ok()); + assert_eq!(response.unwrap(), super::QkdManagerResponse::Ok); + + // Add SAE without key if it belongs to KME1 + let response = qkd_manager.add_sae(1, 1, &None); + assert!(response.is_err()); + assert_eq!(response.unwrap_err(), super::QkdManagerResponse::InconsistentSaeData); } #[test] fn test_get_qkd_key() { const SQLITE_DB_PATH: &'static str = ":memory:"; - let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH); - qkd_manager.add_sae(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); - qkd_manager.add_sae(2, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); - let key = super::QkdKey::new(1, 2, &[0; super::QkdKey::QKD_KEY_SIZE_BYTES]).unwrap(); - qkd_manager.add_qkd_key(key).unwrap(); - let response = qkd_manager.get_qkd_key(2, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES]); + let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH, 1); + qkd_manager.add_sae(1, 1, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + qkd_manager.add_sae(2, 2, &None).unwrap(); // No certificate as this SAE doesn't belong to KME1 + let key = super::PreInitQkdKeyWrapper::new(1, &[0; crate::QKD_KEY_SIZE_BYTES]).unwrap(); + qkd_manager.add_pre_init_qkd_key(key).unwrap(); + let response = qkd_manager.get_qkd_key(2, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); assert!(response.is_err()); assert_eq!(response.unwrap_err(), super::QkdManagerResponse::NotFound); - let response = qkd_manager.get_qkd_key(2, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); + let response = qkd_manager.get_qkd_key(1, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES]); + assert!(response.is_err()); + assert_eq!(response.unwrap_err(), super::QkdManagerResponse::AuthenticationError); + + + let response = qkd_manager.get_qkd_key(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); assert!(response.is_ok()); } #[test] fn test_get_qkd_keys_with_ids() { const SQLITE_DB_PATH: &'static str = ":memory:"; - let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH); - qkd_manager.add_sae(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); - qkd_manager.add_sae(2, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); - let key = super::QkdKey::new(1, 2, &[0; super::QkdKey::QKD_KEY_SIZE_BYTES]).unwrap(); + let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH, 1); + qkd_manager.add_sae(1, 1, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + qkd_manager.add_sae(2, 1, &Some([1; CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + let key = super::PreInitQkdKeyWrapper::new(1,&[0; crate::QKD_KEY_SIZE_BYTES]).unwrap(); let key_uuid = key.get_uuid(); let key_uuid_str = uuid::Uuid::from_bytes(key_uuid).to_string(); - qkd_manager.add_qkd_key(key).unwrap(); + qkd_manager.add_pre_init_qkd_key(key).unwrap(); let response = qkd_manager.get_qkd_keys_with_ids(2, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES], vec![key_uuid_str.clone()]); assert!(response.is_err()); assert_eq!(response.unwrap_err(), super::QkdManagerResponse::NotFound); + let response = qkd_manager.get_qkd_keys_with_ids(1, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES], vec![key_uuid_str.clone()]); + assert!(response.is_err()); + assert_eq!(response.unwrap_err(), super::QkdManagerResponse::NotFound); + + let response = qkd_manager.get_qkd_key(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); + assert!(response.is_ok()); + let response = qkd_manager.get_qkd_keys_with_ids(1, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES], vec![key_uuid_str.clone()]); + assert!(response.is_err()); + assert_eq!(response.unwrap_err(), super::QkdManagerResponse::NotFound); + + let response = qkd_manager.get_qkd_key(2, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); + assert!(response.is_err()); + let response = qkd_manager.get_qkd_keys_with_ids(1, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES], vec![key_uuid_str.clone()]); + assert!(response.is_err()); + assert_eq!(response.unwrap_err(), super::QkdManagerResponse::NotFound); + + + let key = super::PreInitQkdKeyWrapper::new(1,&[1; crate::QKD_KEY_SIZE_BYTES]).unwrap(); + let key_uuid = key.get_uuid(); + let key_uuid_str = uuid::Uuid::from_bytes(key_uuid).to_string(); + qkd_manager.add_pre_init_qkd_key(key).unwrap(); + + let response = qkd_manager.get_qkd_key(2, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); + assert!(response.is_ok()); let response = qkd_manager.get_qkd_keys_with_ids(1, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES], vec![key_uuid_str.clone()]); assert!(response.is_ok()); + assert!(matches!(response.unwrap(), super::QkdManagerResponse::Keys(_))); } #[test] fn test_get_qkd_key_status() { const SQLITE_DB_PATH: &'static str = ":memory:"; - let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH); - qkd_manager.add_sae(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); - qkd_manager.add_sae(2, &[1; CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); - let key = super::QkdKey::new(1, 2, &[0; super::QkdKey::QKD_KEY_SIZE_BYTES]).unwrap(); - qkd_manager.add_qkd_key(key).unwrap(); + let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH, 1); + qkd_manager.add_sae(1, 1, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + qkd_manager.add_sae(2, 1, &Some([1; CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + let key = super::PreInitQkdKeyWrapper::new(1, &[0; crate::QKD_KEY_SIZE_BYTES]).unwrap(); + qkd_manager.add_pre_init_qkd_key(key).unwrap(); let response = qkd_manager.get_qkd_key_status(&[1; CLIENT_CERT_SERIAL_SIZE_BYTES], 2); assert!(response.is_ok()); assert!(matches!(response.unwrap(), super::QkdManagerResponse::Status(_))); @@ -345,7 +419,7 @@ mod test { #[test] fn test_key_uuid() { - let key = super::QkdKey::new(1, 2, &[0; super::QkdKey::QKD_KEY_SIZE_BYTES]).unwrap(); + let key = super::PreInitQkdKeyWrapper::new(1, &[0; crate::QKD_KEY_SIZE_BYTES]).unwrap(); let key_uuid = key.get_uuid(); let key_uuid_str = uuid::Uuid::from_bytes(key_uuid).to_string(); assert_eq!(key_uuid_str, "7b848ade-8cff-3d54-a9b8-53a215e6ee77"); @@ -354,15 +428,32 @@ mod test { #[test] fn test_get_sae_info_from_client_auth_certificate() { const SQLITE_DB_PATH: &'static str = ":memory:"; - let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH); - qkd_manager.add_sae(1, &[0; CLIENT_CERT_SERIAL_SIZE_BYTES]).unwrap(); + let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH, 1); + qkd_manager.add_sae(1, 1, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); let response = qkd_manager.get_sae_info_from_client_auth_certificate(&[0; CLIENT_CERT_SERIAL_SIZE_BYTES]); assert!(response.is_ok()); - assert_eq!(response.unwrap().sae_id, 1); + let response = response.unwrap(); + assert_eq!(response.sae_id, 1); + assert_eq!(response.kme_id, 1); // SAE certificate not present in database let response = qkd_manager.get_sae_info_from_client_auth_certificate(&[1; CLIENT_CERT_SERIAL_SIZE_BYTES]); assert!(response.is_err()); assert_eq!(response.unwrap_err(), super::QkdManagerResponse::NotFound); } + + #[test] + fn test_get_kme_id_from_sae_id() { + const SQLITE_DB_PATH: &'static str = ":memory:"; + let qkd_manager = super::QkdManager::new(SQLITE_DB_PATH, 1); + qkd_manager.add_sae(1, 1, &Some([0; CLIENT_CERT_SERIAL_SIZE_BYTES])).unwrap(); + let response = qkd_manager.get_kme_id_from_sae_id(1); + assert!(response.is_some()); + let response = response.unwrap(); + assert_eq!(response.kme_id, 1); + + // SAE ID not present in database + let response = qkd_manager.get_kme_id_from_sae_id(2); + assert!(response.is_none()); + } } \ No newline at end of file diff --git a/src/qkd_manager/router.rs b/src/qkd_manager/router.rs new file mode 100644 index 0000000..1a6b494 --- /dev/null +++ b/src/qkd_manager/router.rs @@ -0,0 +1,26 @@ +//! QKD network routing manager, get route to SAE + +use std::collections::HashMap; + +#[allow(dead_code)] +#[derive(Clone)] +pub(super) struct QkdRouter { + sae_to_kme_associations: HashMap, +} + +#[allow(dead_code)] +impl QkdRouter { + pub(super) fn new() -> Self { + Self { + sae_to_kme_associations: HashMap::new(), + } + } + + pub(super) fn add_sae_to_kme_association(&mut self, sae_id: i64, kme_id: i64) { + self.sae_to_kme_associations.insert(sae_id, kme_id); + } + + pub(super) fn get_kme_id_from_sae_id(&self, sae_id: i64) -> Option<&i64> { + self.sae_to_kme_associations.get(&sae_id) + } +} \ No newline at end of file diff --git a/src/routes/request_context.rs b/src/routes/request_context.rs index e335981..9ba092f 100644 --- a/src/routes/request_context.rs +++ b/src/routes/request_context.rs @@ -102,7 +102,7 @@ mod test { #[test] fn test_context_no_cert() { - let context = super::RequestContext::new(None, crate::qkd_manager::QkdManager::new(":memory:")).unwrap(); + let context = super::RequestContext::new(None, crate::qkd_manager::QkdManager::new(":memory:", 1)).unwrap(); assert!(!context.has_client_certificate()); assert!(context.get_client_certificate_cn().is_err()); assert!(context.get_client_certificate_serial_as_string().is_err()); @@ -114,7 +114,7 @@ mod test { const CERT_FILENAME: &'static str = "certs/kme1.crt"; let certs = load_cert(CERT_FILENAME).unwrap(); assert_eq!(certs.len(), 1); - let context = super::RequestContext::new(Some(&certs[0]), crate::qkd_manager::QkdManager::new(":memory:")).unwrap(); + let context = super::RequestContext::new(Some(&certs[0]), crate::qkd_manager::QkdManager::new(":memory:", 1)).unwrap(); assert!(context.has_client_certificate()); assert_eq!(context.get_client_certificate_cn().unwrap(), "localhost"); assert_eq!(context.get_client_certificate_serial_as_string().unwrap(), "70:f4:4f:56:0c:3f:27:d4:b2:11:a4:78:13:af:d0:3c:03:81:3b:8d"); diff --git a/src/routes/sae/info.rs b/src/routes/sae/info.rs index 545a7b9..227cf22 100644 --- a/src/routes/sae/info.rs +++ b/src/routes/sae/info.rs @@ -20,8 +20,8 @@ pub(in crate::routes) async fn route_get_info_me(rcx: &RequestContext<'_>, _req: } }; // Retrieve the SAE ID from the QKD manager, given the client certificate serial - let sae_id = match rcx.qkd_manager.get_sae_info_from_client_auth_certificate(client_cert_serial) { - Ok(sae_id) => sae_id.sae_id, + let sae_info = match rcx.qkd_manager.get_sae_info_from_client_auth_certificate(client_cert_serial) { + Ok(sae_info) => sae_info, Err(_) => { // Client certificate serial isn't registered in the QKD manager return crate::routes::QKDKMERoutesV1::not_found() @@ -30,7 +30,8 @@ pub(in crate::routes) async fn route_get_info_me(rcx: &RequestContext<'_>, _req: // Create the response object let sae_info_response_obj = crate::qkd_manager::http_response_obj::ResponseQkdSAEInfo { - SAE_ID: sae_id, + SAE_ID: sae_info.sae_id, + KME_ID: sae_info.kme_id, }; match sae_info_response_obj.to_json() { Ok(json) => { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 54ef51b..a22c3c8 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -2,7 +2,7 @@ use std::fs::File; use std::io::Read; -use qkd_kme_server::qkd_manager::{QkdKey, QkdManager}; +use qkd_kme_server::qkd_manager::{PreInitQkdKeyWrapper, QkdManager}; use qkd_kme_server::routes::QKDKMERoutesV1; pub const HOST_PORT: &'static str = "localhost:3000"; @@ -15,24 +15,23 @@ pub fn setup() { server_key_path: "certs/kme1.key".to_string(), }; - let qkd_manager = QkdManager::new(":memory:"); + let qkd_manager = QkdManager::new(":memory:", 1); qkd_manager.add_sae(1, - &[0x70, 0xf4, 0x4f, 0x56, 0x0c, 0x3f, 0x27, 0xd4, 0xb2, 0x11, 0xa4, 0x78, 0x13, 0xaf, 0xd0, 0x3c, 0x03, 0x81, 0x3b, 0x8e] + 1, + &Some([0x70, 0xf4, 0x4f, 0x56, 0x0c, 0x3f, 0x27, 0xd4, 0xb2, 0x11, 0xa4, 0x78, 0x13, 0xaf, 0xd0, 0x3c, 0x03, 0x81, 0x3b, 0x8e]) ).unwrap(); - let qkd_key_1 = QkdKey::new( + let qkd_key_1 = PreInitQkdKeyWrapper::new( 1, - 2, b"this_is_secret_key_1_of_32_bytes", ).unwrap(); - qkd_manager.add_qkd_key(qkd_key_1).unwrap(); - let qkd_key_2 = QkdKey::new( - 1, + qkd_manager.add_pre_init_qkd_key(qkd_key_1).unwrap(); + let qkd_key_2 = PreInitQkdKeyWrapper::new( 1, b"this_is_secret_key_1_of_32_bytes", ).unwrap(); - qkd_manager.add_qkd_key(qkd_key_2).unwrap(); + qkd_manager.add_pre_init_qkd_key(qkd_key_2).unwrap(); tokio::spawn(async move {server.run::(&qkd_manager).await.unwrap();}); } diff --git a/tests/data/key_status.json b/tests/data/key_status.json index c1d90a9..a8baa58 100644 --- a/tests/data/key_status.json +++ b/tests/data/key_status.json @@ -1,10 +1,10 @@ { "source_KME_ID": "1", - "target_KME_ID": "?? TODO", + "target_KME_ID": "1", "master_SAE_ID": "1", - "slave_SAE_ID": "2", + "slave_SAE_ID": "1", "key_size": 256, - "stored_key_count": 1, + "stored_key_count": 2, "max_key_count": 10, "max_key_per_request": 1, "max_key_size": 256, diff --git a/tests/data/sae_info_me.json b/tests/data/sae_info_me.json index d89eb96..a0eb620 100644 --- a/tests/data/sae_info_me.json +++ b/tests/data/sae_info_me.json @@ -1,3 +1,4 @@ { - "SAE_ID": 1 + "SAE_ID": 1, + "KME_ID": 1 } \ No newline at end of file diff --git a/tests/dec_keys.rs b/tests/dec_keys.rs index d864d80..6dbaff4 100644 --- a/tests/dec_keys.rs +++ b/tests/dec_keys.rs @@ -10,10 +10,17 @@ async fn post_dec_keys() { const EXPECTED_BODY: &'static str = include_str!("data/dec_keys.json"); const SENT_BODY: &'static str = include_str!("data/dec_keys_post_req_body.json"); const REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/1/dec_keys"); + const INIT_POST_KEY_REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/1/enc_keys"); common::setup(); let reqwest_client = common::setup_cert_auth_reqwest_client(); + + let post_key_response = reqwest_client.post(INIT_POST_KEY_REQUEST_URL).send().await; + assert!(post_key_response.is_ok()); + let post_key_response = post_key_response.unwrap(); + assert_eq!(post_key_response.status(), 200); + let response = reqwest_client.post(REQUEST_URL).header(CONTENT_TYPE, "application/json").body(SENT_BODY).send().await; assert!(response.is_ok()); let response = response.unwrap(); @@ -22,6 +29,25 @@ async fn post_dec_keys() { assert_eq!(response_body, EXPECTED_BODY); } + +#[tokio::test] +#[serial] +async fn post_dec_keys_not_init() { + const EXPECTED_BODY: &'static str = include_str!("data/not_found_body.json"); + const SENT_BODY: &'static str = include_str!("data/dec_keys_post_req_body.json"); + const REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/1/dec_keys"); + + common::setup(); + let reqwest_client = common::setup_cert_auth_reqwest_client(); + + let response = reqwest_client.post(REQUEST_URL).header(CONTENT_TYPE, "application/json").body(SENT_BODY).send().await; + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.status(), 404); + let response_body = response.text().await.unwrap(); + assert_eq!(response_body, EXPECTED_BODY); +} + #[tokio::test] #[serial] async fn post_dec_keys_no_body() { diff --git a/tests/enc_keys.rs b/tests/enc_keys.rs index 8294d08..a74ee9f 100644 --- a/tests/enc_keys.rs +++ b/tests/enc_keys.rs @@ -7,7 +7,7 @@ mod common; #[serial] async fn post_enc_keys() { const EXPECTED_BODY: &'static str = include_str!("data/enc_keys.json"); - const REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/2/enc_keys"); + const REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/1/enc_keys"); common::setup(); let reqwest_client = common::setup_cert_auth_reqwest_client(); @@ -18,4 +18,21 @@ async fn post_enc_keys() { assert_eq!(response.status(), 200); let response_body = response.text().await.unwrap(); assert_eq!(response_body, EXPECTED_BODY); +} + +#[tokio::test] +#[serial] +async fn post_enc_keys_sae_not_found() { + const EXPECTED_BODY: &'static str = include_str!("data/not_found_body.json"); + const REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/2/enc_keys"); + + common::setup(); + let reqwest_client = common::setup_cert_auth_reqwest_client(); + + let response = reqwest_client.post(REQUEST_URL).send().await; + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.status(), 404); + let response_body = response.text().await.unwrap(); + assert_eq!(response_body, EXPECTED_BODY); } \ No newline at end of file diff --git a/tests/key_status.rs b/tests/key_status.rs index c9de77f..160def2 100644 --- a/tests/key_status.rs +++ b/tests/key_status.rs @@ -7,7 +7,7 @@ mod common; #[serial] async fn get_key_status() { const EXPECTED_BODY: &'static str = include_str!("data/key_status.json"); - const REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/2/status"); + const REQUEST_URL: &'static str = concatcp!("https://", common::HOST_PORT ,"/api/v1/keys/1/status"); common::setup(); let reqwest_client = common::setup_cert_auth_reqwest_client();