diff --git a/src/SSLClient.cpp b/src/SSLClient.cpp index 6d50e70..149c2d2 100644 --- a/src/SSLClient.cpp +++ b/src/SSLClient.cpp @@ -23,77 +23,95 @@ #undef write #undef read - -SSLClient::SSLClient() -{ - _connected = false; - - sslclient = new sslclient_context; - ssl_init(sslclient, nullptr); - sslclient->handshake_timeout = 120000; - _CA_cert = NULL; - _cert = NULL; - _private_key = NULL; - _pskIdent = NULL; - _psKey = NULL; +/** + * @brief Construct a new SSLClient::SSLClient object using the default constructor. + */ +SSLClient::SSLClient() { + _connected = false; + sslclient = new sslclient_context; + ssl_init(sslclient, nullptr); + sslclient->handshake_timeout = 120000; + _CA_cert = NULL; + _cert = NULL; + _private_key = NULL; + _pskIdent = NULL; + _psKey = NULL; } -SSLClient::SSLClient(Client* client) -{ - _connected = false; - - sslclient = new sslclient_context; - ssl_init(sslclient, client); - sslclient->handshake_timeout = 120000; - _CA_cert = NULL; - _cert = NULL; - _private_key = NULL; - _pskIdent = NULL; - _psKey = NULL; - +/** + * @brief Construct a new SSLClient::SSLClient object using the pointer to the specified client. + * + * @param client + */ +SSLClient::SSLClient(Client* client) { + _connected = false; + sslclient = new sslclient_context; + ssl_init(sslclient, client); + sslclient->handshake_timeout = 120000; + _CA_cert = NULL; + _cert = NULL; + _private_key = NULL; + _pskIdent = NULL; + _psKey = NULL; } -SSLClient::~SSLClient() -{ - stop(); - delete sslclient; +/** + * @brief Destroy the SSLClient::SSLClient object. + */ +SSLClient::~SSLClient() { + stop(); + delete sslclient; } +/** + * @brief Stops the SSL client. + */ void SSLClient::stop() { - if (sslclient->client != nullptr) { - if (sslclient->client >= 0) { - log_v("Stopping ssl client"); - stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key); - } else { - log_v("stop() not called because client is < 0"); - } + if (sslclient->client != nullptr) { + if (sslclient->client >= 0) { + log_v("Stopping ssl client"); + stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key); } else { - log_v("stop() not called because client is nullptr"); + log_v("stop() not called because client is < 0"); } - _connected = false; - _peek = -1; + } else { + log_v("stop() not called because client is nullptr"); + } + _connected = false; + _peek = -1; } -int SSLClient::connect(IPAddress ip, uint16_t port) -{ - if (_pskIdent && _psKey) - return connect(ip, port, _pskIdent, _psKey); - return connect(ip, port, _CA_cert, _cert, _private_key); +/** + * @brief + * + * @param ip + * @param port + * @return int + */ +int SSLClient::connect(IPAddress ip, uint16_t port) { + if (_pskIdent && _psKey) { + log_v("connect with PSK"); + return connect(ip, port, _pskIdent, _psKey); + } + log_v("connect with CA"); + return connect(ip, port, _CA_cert, _cert, _private_key); } -int SSLClient::connect(IPAddress ip, uint16_t port, int32_t timeout){ - _timeout = timeout; - return connect(ip, port); +int SSLClient::connect(IPAddress ip, uint16_t port, int32_t timeout) { + _timeout = timeout; + return connect(ip, port); } -int SSLClient::connect(const char *host, uint16_t port) -{ - if (_pskIdent && _psKey) - return connect(host, port, _pskIdent, _psKey); - return connect(host, port, _CA_cert, _cert, _private_key); +int SSLClient::connect(const char *host, uint16_t port) { + if (_pskIdent && _psKey) { + log_v("connect with PSK"); + return connect(host, port, _pskIdent, _psKey); + } + log_v("connect with CA"); + return connect(host, port, _CA_cert, _cert, _private_key); } -int SSLClient::connect(const char *host, uint16_t port, int32_t timeout){ +int SSLClient::connect(const char *host, uint16_t port, int32_t timeout) { _timeout = timeout; return connect(host, port); } @@ -178,48 +196,74 @@ size_t SSLClient::write(const uint8_t *buf, size_t size) return res; } -int SSLClient::read(uint8_t *buf, size_t size) -{ - int peeked = 0; - int avail = available(); - if ((!buf && size) || avail <= 0) { - return -1; - } - if(!size){ - return 0; - } - if(_peek >= 0){ - buf[0] = _peek; - _peek = -1; - size--; - avail--; - if(!size || !avail){ - return 1; - } - buf++; - peeked = 1; - } +/** + * \brief Reads data from the sslclient. If there is a byte peeked, it returns that byte. + * + * \param buf Buffer to read into. + * \param size Size of the buffer. + * \return int 1 if a byte has been peeked and the client is not connected. + * \return int < 1 if client is connected and there is an error from get_ssl_receive(). + * \return int > 1 if res + peeked. + */ +int SSLClient::read(uint8_t *buf, size_t size) { + log_v("This is the iClient->read() implementation"); + int peeked = 0; + int avail = available(); + + if ((!buf && size) || avail <= 0) { + return -1; // return error if no buffer or nothing to read. + } - int res = get_ssl_receive(sslclient, buf, size); - if (res < 0) { - stop(); - return peeked?peeked:res; + if (!size) { + return 0; // return 0 if no bytes requested. + } + + if (_peek >= 0) { + buf[0] = _peek; // Places this peeked byte at the start of the buffer. + _peek = -1; // Resets _peek to -1 to indicate no bytes are currently peeked. + size--; // Decreases the available size (size) by 1. + avail--; // Decreases the available bytes (avail) by 1. + if (!size || !avail) { // If there's no space left in the buffer (size) or no data left to read (avail) + return 1; // Return 1 to indicate one byte has been read. } - return res + peeked; + buf++; // Increment the buffer pointer. + peeked = 1; // set peeked to 1 to indicate one byte has been read from the peeked value. + } + + int res = get_ssl_receive(sslclient, buf, size); + + if (res < 0) { + stop(); + return peeked?peeked:res; // If peeked is true return peeked, otherwise return res, i.e. data_to_read error. + } + + return res + peeked; // Return the number of bytes read + the number of bytes peeked. } -int SSLClient::available() -{ - int peeked = (_peek >= 0); - if (!_connected) { - return peeked; - } - int res = data_to_read(sslclient); - if (res < 0) { - stop(); - return peeked?peeked:res; - } - return res+peeked; +/** + * \brief Returns how many bytes of data are available to be read from the sslclient. + * It takes into account both directly readable bytes and a potentially "peeked" byte. + * If there's an error or the client is not connected, it handles these scenarios appropriately. + * + * \return int 1 if a byte has been peeked and the client is not connected. + * \return int < 1 if client is connected and there is an error from data_to_read(). + * \return int > 1 if res + peeked. + */ +int SSLClient::available() { + int peeked = (_peek >= 0); // 1 if a byte has been peeked (available to read without advancing the read pointer) + + if (!_connected) { + return peeked; + } + + int res = data_to_read(sslclient); // how many bytes available to read. + + if (res < 0) { + stop(); + return peeked?peeked:res; // If peeked is true return peeked, otherwise return res, i.e. data_to_read error. + } + + return res+peeked; } uint8_t SSLClient::connected() diff --git a/src/SSLClient.h b/src/SSLClient.h index 7c54de6..164702f 100644 --- a/src/SSLClient.h +++ b/src/SSLClient.h @@ -25,78 +25,75 @@ class SSLClient : public Client { protected: - sslclient_context *sslclient; + sslclient_context *sslclient; - int _lastError = 0; + int _lastError = 0; int _peek = -1; - int _timeout = 0; - const char *_CA_cert; - const char *_cert; - const char *_private_key; - const char *_pskIdent; // identity for PSK cipher suites - const char *_psKey; // key in hex for PSK cipher suites + int _timeout = 0; + const char *_CA_cert; + const char *_cert; + const char *_private_key; + const char *_pskIdent; // identity for PSK cipher suites + const char *_psKey; // key in hex for PSK cipher suites - bool _connected = false; + bool _connected = false; - Client* _client = nullptr; + Client* _client = nullptr; public: - SSLClient(); - SSLClient(Client* client); - ~SSLClient(); + SSLClient(); + SSLClient(Client* client); + ~SSLClient(); - int connect(IPAddress ip, uint16_t port); - int connect(IPAddress ip, uint16_t port, int32_t timeout); - int connect(const char *host, uint16_t port); - int connect(const char *host, uint16_t port, int32_t timeout); - int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); - int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); - int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey); - int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey); + int connect(IPAddress ip, uint16_t port); + int connect(IPAddress ip, uint16_t port, int32_t timeout); + int connect(const char *host, uint16_t port); + int connect(const char *host, uint16_t port, int32_t timeout); + int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); + int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); + int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey); + int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey); int peek(); - size_t write(uint8_t data); - size_t write(const uint8_t *buf, size_t size); - int available(); - int read(); - int read(uint8_t *buf, size_t size); - void flush() {} - void stop(); - uint8_t connected(); - int lastError(char *buf, const size_t size); + size_t write(uint8_t data); + size_t write(const uint8_t *buf, size_t size); + int available(); + int read(); + int read(uint8_t *buf, size_t size); + void flush() {} + void stop(); + uint8_t connected(); + int lastError(char *buf, const size_t size); - void setPreSharedKey(const char *pskIdent, const char *psKey); // psKey in Hex - void setCACert(const char *rootCA); - void setCertificate(const char *client_ca); - void setPrivateKey (const char *private_key); - bool loadCACert(Stream& stream, size_t size); - bool loadCertificate(Stream& stream, size_t size); - bool loadPrivateKey(Stream& stream, size_t size); - bool verify(const char* fingerprint, const char* domain_name); - void setHandshakeTimeout(unsigned long handshake_timeout); - void setClient(Client* client); + void setPreSharedKey(const char *pskIdent, const char *psKey); // psKey in Hex + void setCACert(const char *rootCA); + void setCertificate(const char *client_ca); + void setPrivateKey (const char *private_key); + bool loadCACert(Stream& stream, size_t size); + bool loadCertificate(Stream& stream, size_t size); + bool loadPrivateKey(Stream& stream, size_t size); + bool verify(const char* fingerprint, const char* domain_name); + void setHandshakeTimeout(unsigned long handshake_timeout); + void setClient(Client* client); + int setTimeout(uint32_t seconds){ return 0; } + operator bool() { + return connected(); + } - int setTimeout(uint32_t seconds){ return 0; } + bool operator==(const bool value) { + return bool() == value; + } - operator bool() - { - return connected(); - } - bool operator==(const bool value) - { - return bool() == value; - } - bool operator!=(const bool value) - { - return bool() != value; - } + bool operator!=(const bool value) { + return bool() != value; + } private: - char *_streamLoad(Stream& stream, size_t size); + char *_streamLoad(Stream& stream, size_t size); - //friend class GprsServer; - using Print::write; + //friend class GprsServer; + using Print::write; }; #endif /* SSLClient_H */ diff --git a/src/ssl_client.cpp b/src/ssl_client.cpp index 7afdf23..73c4a7a 100644 --- a/src/ssl_client.cpp +++ b/src/ssl_client.cpp @@ -201,7 +201,7 @@ void ssl_init(sslclient_context *ssl_client, Client *client) * \param ssl_client sslclient_context* - The ssl client context. * \param host const char* - The host to connect to. * \param port uint32_t - The port to connect to. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful. */ int initialize_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port) { log_v("Connecting to %s:%d", host, port); @@ -214,8 +214,6 @@ int initialize_ssl_client(sslclient_context *ssl_client, const char *host, uint3 if (ssl_client->client == nullptr) { log_w("ssl_client->client is not initialized!"); return -1; - } else { - log_i("ssl_client->client is initialized"); } Client *pClient = ssl_client->client; @@ -230,17 +228,17 @@ int initialize_ssl_client(sslclient_context *ssl_client, const char *host, uint3 return -1; } - return 1; + return 0; } /** * \brief Seed the random number generator. * * \param ssl_client sslclient_context* - The ssl client context. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful. */ int seed_rng(sslclient_context *ssl_client) { - int ret; + int ret = -1; log_v("Seeding the random number generator"); @@ -250,26 +248,26 @@ int seed_rng(sslclient_context *ssl_client) { // Seed the random number generator ret = mbedtls_ctr_drbg_seed(&ssl_client->drbg_ctx, mbedtls_entropy_func, &ssl_client->entropy_ctx, (const unsigned char *) pers, strlen(pers)); - if (ret < 0) { - return handle_error(ret); // You might need to adjust handle_error() to make it more specific to seeding RNG errors + if (ret != 0) { + return handle_error(ret); } - return 1; + return ret; } /** * \brief Configure the SSL/TLS structure. This function is used when no CA certificate is defined. * * \param ssl_client sslclient_context* - The ssl client context. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful. */ static int configure_default_ssl(sslclient_context *ssl_client) { log_v("No cert provided. Using default cert verification"); - int ret = mbedtls_ssl_config_defaults(&ssl_client->ssl_conf, + + return mbedtls_ssl_config_defaults(&ssl_client->ssl_conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); - return ret; } /** @@ -277,7 +275,7 @@ static int configure_default_ssl(sslclient_context *ssl_client) { * * \param ssl_client sslclient_context* - The ssl client context. * \param rootCABuff const char* - The root CA certificate. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful, +int num certs failed, -int X509 or PEM error code. */ static int configure_ca_cert(sslclient_context *ssl_client, const char *rootCABuff) { log_v("Loading CA cert"); @@ -285,7 +283,10 @@ static int configure_ca_cert(sslclient_context *ssl_client, const char *rootCABu mbedtls_x509_crt_init(&ssl_client->ca_cert); mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED); int ret = mbedtls_x509_crt_parse(&ssl_client->ca_cert, (const unsigned char *)rootCABuff, strlen(rootCABuff) + 1); - mbedtls_ssl_conf_ca_chain(&ssl_client->ssl_conf, &ssl_client->ca_cert, NULL); + + if (ret == 0) { + mbedtls_ssl_conf_ca_chain(&ssl_client->ssl_conf, &ssl_client->ca_cert, NULL); + } return ret; } @@ -296,18 +297,34 @@ static int configure_ca_cert(sslclient_context *ssl_client, const char *rootCABu * \param ssl_client sslclient_context* - The ssl client context. * \param pskIdent const char* - The PSK identity. * \param psKey const char* - The PSK key. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful. */ static int configure_psk(sslclient_context *ssl_client, const char *pskIdent, const char *psKey) { + int ret = -1; log_v("Setting up PSK"); unsigned char psk[MBEDTLS_PSK_MAX_LEN]; size_t psk_len = strlen(psKey) / 2; - // [ ... Convert PSK from hex to binary logic ... ] - - int ret = mbedtls_ssl_conf_psk(&ssl_client->ssl_conf, psk, psk_len, - (const unsigned char *)pskIdent, strlen(pskIdent)); + for (int j=0; j= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] = c<<4; + c = psKey[j+1]; + if (c >= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] |= c; + } + ret = mbedtls_ssl_conf_psk(&ssl_client->ssl_conf, psk, psk_len, (const unsigned char *)pskIdent, strlen(pskIdent)); + if (ret != 0) { + log_e("mbedtls_ssl_conf_psk returned %d", ret); + return handle_error(ret); + } return ret; } @@ -317,54 +334,60 @@ static int configure_psk(sslclient_context *ssl_client, const char *pskIdent, co * \param ssl_client sslclient_context* - The ssl client context. * \param cli_cert const char* - The client certificate. * \param cli_key const char* - The client key. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful, +int num certs failed, -int X509 or PEM error code. */ static int configure_client_cert_key(sslclient_context *ssl_client, const char *cli_cert, const char *cli_key) { + int ret = -1; mbedtls_x509_crt_init(&ssl_client->client_cert); mbedtls_pk_init(&ssl_client->client_key); log_v("Loading CRT cert"); - int ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1); - if (ret < 0) { - return ret; + ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1); + if (ret != 0) { + return handle_error(ret); } log_v("Loading private key"); ret = mbedtls_pk_parse_key(&ssl_client->client_key, (const unsigned char *)cli_key, strlen(cli_key) + 1, NULL, 0); if (ret != 0) { mbedtls_x509_crt_free(&ssl_client->client_cert); - return ret; + return handle_error(ret); } - mbedtls_ssl_conf_own_cert(&ssl_client->ssl_conf, &ssl_client->client_cert, &ssl_client->client_key); + log_v("Setting own certification chain"); + ret = mbedtls_ssl_conf_own_cert(&ssl_client->ssl_conf, &ssl_client->client_cert, &ssl_client->client_key); + if (ret != 0) { + return ret; + } return ret; } /** - * \brief Set the up ssl configuration object. + * \brief Set the up ssl configuration object. * - * \param ssl_client sslclient_context* - The ssl client context. - * \param rootCABuff const char* - The root CA certificate. - * \param cli_cert const char* - The client certificate. - * \param cli_key const char* - The client key. - * \param pskIdent const char* - The PSK identity. - * \param psKey const char* - The PSK key. - * \return int 1 if successful, -1 if failed. + * \param ssl_client sslclient_context* - The ssl client context. + * \param rootCABuff const char* - The root CA certificate. + * \param cli_cert const char* - The client certificate. + * \param cli_key const char* - The client key. + * \param pskIdent const char* - The PSK identity. + * \param psKey const char* - The PSK key. + * \return int 0 if successful, +int num certs failed, -int X509 or PEM error code. */ int setup_ssl_configuration(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey) { - int ret; + int ret = -1; log_v("Setting up the SSL/TLS structure..."); ret = configure_default_ssl(ssl_client); + if (ret != 0) { return handle_error(ret); } if (rootCABuff != NULL) { ret = configure_ca_cert(ssl_client, rootCABuff); - if (ret < 0) { + if (ret != 0) { return handle_error(ret); } } else if (pskIdent != NULL && psKey != NULL) { @@ -380,12 +403,12 @@ int setup_ssl_configuration(sslclient_context *ssl_client, const char *rootCABuf if (cli_cert != NULL && cli_key != NULL) { ret = configure_client_cert_key(ssl_client, cli_cert, cli_key); - if (ret < 0) { + if (ret != 0) { return handle_error(ret); } } - return 1; + return ret; } /** @@ -395,10 +418,10 @@ int setup_ssl_configuration(sslclient_context *ssl_client, const char *rootCABuf * \param rootCABuff const char* - The root CA certificate. * \param cli_cert const char* - The client certificate. * \param cli_key const char* - The client key. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful, +int num certs failed, -int X509 or PEM error code. */ int load_certificates_and_keys(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) { - int ret; + int ret = -1; // Load the CA root certificate if (rootCABuff) { @@ -420,23 +443,24 @@ int load_certificates_and_keys(sslclient_context *ssl_client, const char *rootCA ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1); if (ret != 0) { log_e("Failed to load client certificate. mbedtls_x509_crt_parse returned -0x%x", -ret); - return ret; + return handle_error(ret); } } // Load the client private key if (cli_key) { - log_v("Loading client key"); + log_v("Loading private key"); mbedtls_pk_init(&ssl_client->client_key); ret = mbedtls_pk_parse_key(&ssl_client->client_key, (const unsigned char *)cli_key, strlen(cli_key) + 1, NULL, 0); if (ret != 0) { + mbedtls_x509_crt_free(&ssl_client->client_cert); log_e("Failed to load client private key. mbedtls_pk_parse_key returned -0x%x", -ret); - return ret; + return handle_error(ret); } } - return 0; // 0 means success + return ret; } /** @@ -444,10 +468,27 @@ int load_certificates_and_keys(sslclient_context *ssl_client, const char *rootCA * * \param ssl_client sslclient_context* - The ssl client context. * \param timeout int - The timeout in milliseconds. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful. */ -int perform_handshake(sslclient_context *ssl_client, int timeout) { - int ret; +int perform_handshake(sslclient_context *ssl_client, const char *host, int timeout) { + int ret = -1; + log_v("Setting hostname for TLS session..."); + + // Hostname set here should match CN in server certificate + if((ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, host)) != 0){ + return handle_error(ret); + } + + mbedtls_ssl_conf_rng(&ssl_client->ssl_conf, mbedtls_ctr_drbg_random, &ssl_client->drbg_ctx); + + if ((ret = mbedtls_ssl_setup(&ssl_client->ssl_ctx, &ssl_client->ssl_conf)) != 0) { + return handle_error(ret); + } + + log_v("Setting up IO callbacks..."); + mbedtls_ssl_set_bio(&ssl_client->ssl_ctx, ssl_client->client, + client_net_send, NULL, client_net_recv_timeout); + unsigned long handshake_start_time = millis(); log_v("Performing the SSL/TLS handshake..."); @@ -467,7 +508,27 @@ int perform_handshake(sslclient_context *ssl_client, int timeout) { } log_v("SSL/TLS handshake completed successfully."); - return 0; + + return ret; +} + +/** + * \brief Confirm the protocols and ciphersuite used for the connection. + * + * \param ssl_client sslclient_context* - The ssl client context. + * \param cli_cert const char* - The client certificate. + * \param cli_key const char* - The client key. + */ +void confirm_protocols(sslclient_context* ssl_client, const char* cli_cert, const char* cli_key) { + int ret = 0; + if (cli_cert != NULL && cli_key != NULL) { + log_v("Protocol is %s Ciphersuite is %s", mbedtls_ssl_get_version(&ssl_client->ssl_ctx), mbedtls_ssl_get_ciphersuite(&ssl_client->ssl_ctx)); + if ((ret = mbedtls_ssl_get_record_expansion(&ssl_client->ssl_ctx)) >= 0) { + log_v("Record expansion is %d", ret); + } else { + log_w("Record expansion is unknown (compression)"); + } + } } /** @@ -477,7 +538,7 @@ int perform_handshake(sslclient_context *ssl_client, int timeout) { * \param rootCABuff const char* - The root CA certificate. * \param cli_cert const char* - The client certificate. * \param cli_key const char* - The client key. - * \return int 1 if successful, -1 if failed. + * \return int 0 if successful. */ int verify_peer_certificate(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key) { char buf[512]; @@ -497,7 +558,7 @@ int verify_peer_certificate(sslclient_context *ssl_client, const char *rootCABuf return handle_error(-1); // Using -1 as a general error code here } else { log_v("Certificate verified."); - return 1; + return flags; } } @@ -540,35 +601,219 @@ void clean_up_resources(sslclient_context *ssl_client, const char *rootCABuff, c * \param cli_key const char*- The client key. * \param pskIdent const char* - The PSK identity. * \param psKey const char* - The PSK key. - * \return int 1 if successful, -1 if failed. + * \return int 1 if successful. */ -int start_ssl_client( sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey) { +int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey) +{ + char buf[512]; + int ret, flags; + //int enable = 1; log_v("Free internal heap before TLS %u", ESP.getFreeHeap()); - if (initialize_ssl_client(ssl_client, host, port) != 1) { + log_d("Connecting to %s:%d", host, port); + + Client *pClient = ssl_client->client; + + if (!pClient) { + log_e("Client provider not initialised"); return -1; } - if (seed_rng(ssl_client) != 1) { + + if (!pClient->connect(host, port)) { + log_e("Connect to Server failed!"); return -2; } - if (setup_ssl_configuration(ssl_client, rootCABuff, cli_cert, cli_key, pskIdent, psKey) != 1) { - return -3; + + log_v("Seeding the random number generator"); + mbedtls_entropy_init(&ssl_client->entropy_ctx); + + ret = mbedtls_ctr_drbg_seed(&ssl_client->drbg_ctx, mbedtls_entropy_func, + &ssl_client->entropy_ctx, (const unsigned char *) pers, strlen(pers)); + if (ret < 0) { + return handle_error(ret); } - if (load_certificates_and_keys(ssl_client, rootCABuff, cli_cert, cli_key) != 1) { - return -4; + + log_v("Setting up the SSL/TLS structure..."); + + if ((ret = mbedtls_ssl_config_defaults(&ssl_client->ssl_conf, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)) != 0) { + return handle_error(ret); + } + + // MBEDTLS_SSL_VERIFY_REQUIRED if a CA certificate is defined on Arduino IDE and + // MBEDTLS_SSL_VERIFY_NONE if not. + + if (rootCABuff != NULL) { + log_v("Loading CA cert"); + mbedtls_x509_crt_init(&ssl_client->ca_cert); + mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_REQUIRED); + ret = mbedtls_x509_crt_parse(&ssl_client->ca_cert, (const unsigned char *)rootCABuff, strlen(rootCABuff) + 1); + mbedtls_ssl_conf_ca_chain(&ssl_client->ssl_conf, &ssl_client->ca_cert, NULL); + //mbedtls_ssl_conf_verify(&ssl_client->ssl_ctx, my_verify, NULL ); + if (ret < 0) { + return handle_error(ret); + } + } else if (pskIdent != NULL && psKey != NULL) { + log_v("Setting up PSK"); + // convert PSK from hex to binary + if ((strlen(psKey) & 1) != 0 || strlen(psKey) > 2*MBEDTLS_PSK_MAX_LEN) { + log_e("pre-shared key not valid hex or too long"); + return -1; + } + unsigned char psk[MBEDTLS_PSK_MAX_LEN]; + size_t psk_len = strlen(psKey)/2; + for (int j=0; j= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] = c<<4; + c = psKey[j+1]; + if (c >= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] |= c; + } + // set mbedtls config + ret = mbedtls_ssl_conf_psk(&ssl_client->ssl_conf, psk, psk_len, + (const unsigned char *)pskIdent, strlen(pskIdent)); + if (ret != 0) { + log_e("mbedtls_ssl_conf_psk returned %d", ret); + return handle_error(ret); + } + } else { + mbedtls_ssl_conf_authmode(&ssl_client->ssl_conf, MBEDTLS_SSL_VERIFY_NONE); + log_i("WARNING: Use certificates for a more secure communication!"); + } + + if (cli_cert != NULL && cli_key != NULL) { + mbedtls_x509_crt_init(&ssl_client->client_cert); + mbedtls_pk_init(&ssl_client->client_key); + + log_v("Loading CRT cert"); + + ret = mbedtls_x509_crt_parse(&ssl_client->client_cert, (const unsigned char *)cli_cert, strlen(cli_cert) + 1); + if (ret < 0) { + return handle_error(ret); + } + + log_v("Loading private key"); + ret = mbedtls_pk_parse_key(&ssl_client->client_key, (const unsigned char *)cli_key, strlen(cli_key) + 1, NULL, 0); + + if (ret != 0) { + mbedtls_x509_crt_free(&ssl_client->client_cert); // cert+key are free'd in pair + return handle_error(ret); + } + + mbedtls_ssl_conf_own_cert(&ssl_client->ssl_conf, &ssl_client->client_cert, &ssl_client->client_key); + } + + log_v("Setting hostname for TLS session..."); + + // Hostname set here should match CN in server certificate + if ((ret = mbedtls_ssl_set_hostname(&ssl_client->ssl_ctx, host)) != 0) { + return handle_error(ret); + } + + mbedtls_ssl_conf_rng(&ssl_client->ssl_conf, mbedtls_ctr_drbg_random, &ssl_client->drbg_ctx); + + if ((ret = mbedtls_ssl_setup(&ssl_client->ssl_ctx, &ssl_client->ssl_conf)) != 0) { + return handle_error(ret); + } + + log_v("Setting up IO callbacks..."); + mbedtls_ssl_set_bio(&ssl_client->ssl_ctx, ssl_client->client, + client_net_send, NULL, client_net_recv_timeout ); + + log_v("Performing the SSL/TLS handshake..."); + unsigned long handshake_start_time=millis(); + while ((ret = mbedtls_ssl_handshake(&ssl_client->ssl_ctx)) != 0) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + return handle_error(ret); + } + if((millis()-handshake_start_time)>ssl_client->handshake_timeout) { + return -1; + } + vTaskDelay(10 / portTICK_PERIOD_MS); + } + + + if (cli_cert != NULL && cli_key != NULL) { + log_d("Protocol is %s Ciphersuite is %s", mbedtls_ssl_get_version(&ssl_client->ssl_ctx), mbedtls_ssl_get_ciphersuite(&ssl_client->ssl_ctx)); + if ((ret = mbedtls_ssl_get_record_expansion(&ssl_client->ssl_ctx)) >= 0) { + log_d("Record expansion is %d", ret); + } else { + log_w("Record expansion is unknown (compression)"); + } + } + + log_v("Verifying peer X.509 certificate..."); + + if ((flags = mbedtls_ssl_get_verify_result(&ssl_client->ssl_ctx)) != 0) { + bzero(buf, sizeof(buf)); + mbedtls_x509_crt_verify_info(buf, sizeof(buf), " ! ", flags); + log_e("Failed to verify peer certificate! verification info: %s", buf); + stop_ssl_socket(ssl_client, rootCABuff, cli_cert, cli_key); //It's not safe continue. + return handle_error(ret); + } else { + log_v("Certificate verified."); } - if (perform_handshake(ssl_client) != 1) { - return -5; + + if (rootCABuff != NULL) { + mbedtls_x509_crt_free(&ssl_client->ca_cert); } - if (verify_peer_certificate(ssl_client, rootCABuff, cli_cert, cli_key) != 1) { - return -6; + + if (cli_cert != NULL) { + mbedtls_x509_crt_free(&ssl_client->client_cert); } - clean_up_resources(ssl_client, rootCABuff, cli_cert, cli_key); + if (cli_key != NULL) { + mbedtls_pk_free(&ssl_client->client_key); + } + log_v("Free internal heap after TLS %u", ESP.getFreeHeap()); + //return ssl_client->socket; return 1; } +// int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey) { +// log_v("Free internal heap before TLS %u", ESP.getFreeHeap()); + +// if (initialize_ssl_client(ssl_client, host, port) != 0) { +// return -1; +// } + +// if (seed_rng(ssl_client) != 0) { +// return -2; +// } + +// if (setup_ssl_configuration(ssl_client, rootCABuff, cli_cert, cli_key, pskIdent, psKey) != 0) { +// return -3; +// } + +// if (load_certificates_and_keys(ssl_client, rootCABuff, cli_cert, cli_key) != 0) { +// return -4; +// } + +// if (perform_handshake(ssl_client, host) != 0) { +// return -5; +// } + +// confirm_protocols(ssl_client, cli_cert, cli_key); + +// if (verify_peer_certificate(ssl_client, rootCABuff, cli_cert, cli_key) != 0) { +// return -6; +// } + +// clean_up_resources(ssl_client, rootCABuff, cli_cert, cli_key); +// log_v("Free internal heap after TLS %u", ESP.getFreeHeap()); + +// return 1; +// } /** * \brief Stop the ssl socket. @@ -629,7 +874,7 @@ int data_to_read(sslclient_context *ssl_client) { * \return int The number of bytes sent. */ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len) { - log_d("Writing SSL (%zu bytes)...", len); //for low level debug + log_v("Writing SSL (%zu bytes)...", len); //for low level debug int ret = -1; while ((ret = mbedtls_ssl_write(&ssl_client->ssl_ctx, data, len)) <= 0) { @@ -652,7 +897,7 @@ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len * \return int The number of bytes received. */ int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length) { - log_d( "Reading SSL (%d bytes)", length); //for low level debug + log_v( "Reading SSL (%d bytes)", length); //for low level debug int ret = -1; ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length); @@ -732,12 +977,12 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const ++pos; } if (pos > len - 2) { - log_d("pos:%d len:%d fingerprint too short", pos, len); + log_v("pos:%d len:%d fingerprint too short", pos, len); return false; } uint8_t high, low; if (!parseHexNibble(fp[pos], &high) || !parseHexNibble(fp[pos+1], &low)) { - log_d("pos:%d len:%d invalid hex sequence: %c%c", pos, len, fp[pos], fp[pos+1]); + log_v("pos:%d len:%d invalid hex sequence: %c%c", pos, len, fp[pos], fp[pos+1]); return false; } pos += 2; @@ -748,7 +993,7 @@ bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const const mbedtls_x509_crt* crt = mbedtls_ssl_get_peer_cert(&ssl_client->ssl_ctx); if (!crt) { - log_d("could not fetch peer certificate"); + log_v("could not fetch peer certificate"); return false; } diff --git a/src/ssl_client.h b/src/ssl_client.h index 4dd4ffc..6163690 100644 --- a/src/ssl_client.h +++ b/src/ssl_client.h @@ -48,7 +48,8 @@ int initialize_ssl_client(sslclient_context *ssl_client, const char *host, uint3 int seed_rng(sslclient_context *ssl_client); int setup_ssl_configuration(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey); int load_certificates_and_keys(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); -int perform_handshake(sslclient_context *ssl_client, int timeout=SSL_CLIENT_DEFAULT_HANDSHAKE_TIMEOUT); +int perform_handshake(sslclient_context *ssl_client, const char *host, int timeout=SSL_CLIENT_SLOW_NETWORK_HANDSHAKE_TIMEOUT); +void confirm_protocols(sslclient_context* ssl_client, const char* cli_cert, const char* cli_key); int verify_peer_certificate(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); void clean_up_resources(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); void ssl_init(sslclient_context *ssl_client, Client *client);