Skip to content

Commit

Permalink
fix: update test branch
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertByrnes committed Aug 28, 2023
1 parent 3157867 commit 482fcd2
Show file tree
Hide file tree
Showing 4 changed files with 503 additions and 216 deletions.
228 changes: 136 additions & 92 deletions src/SSLClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,77 +23,95 @@
#undef write
#undef read


SSLClient::SSLClient()
{
_connected = false;

sslclient = new sslclient_context;
ssl_init(sslclient, nullptr);
sslclient->handshake_timeout = 120000;
_CA_cert = NULL;
_cert = NULL;
_private_key = NULL;
_pskIdent = NULL;
_psKey = NULL;
/**
* @brief Construct a new SSLClient::SSLClient object using the default constructor.
*/
SSLClient::SSLClient() {
_connected = false;
sslclient = new sslclient_context;
ssl_init(sslclient, nullptr);
sslclient->handshake_timeout = 120000;
_CA_cert = NULL;
_cert = NULL;
_private_key = NULL;
_pskIdent = NULL;
_psKey = NULL;
}

SSLClient::SSLClient(Client* client)
{
_connected = false;

sslclient = new sslclient_context;
ssl_init(sslclient, client);
sslclient->handshake_timeout = 120000;
_CA_cert = NULL;
_cert = NULL;
_private_key = NULL;
_pskIdent = NULL;
_psKey = NULL;

/**
* @brief Construct a new SSLClient::SSLClient object using the pointer to the specified client.
*
* @param client
*/
SSLClient::SSLClient(Client* client) {
_connected = false;
sslclient = new sslclient_context;
ssl_init(sslclient, client);
sslclient->handshake_timeout = 120000;
_CA_cert = NULL;
_cert = NULL;
_private_key = NULL;
_pskIdent = NULL;
_psKey = NULL;
}

SSLClient::~SSLClient()
{
stop();
delete sslclient;
/**
* @brief Destroy the SSLClient::SSLClient object.
*/
SSLClient::~SSLClient() {
stop();
delete sslclient;
}

/**
* @brief Stops the SSL client.
*/
void SSLClient::stop() {
if (sslclient->client != nullptr) {
if (sslclient->client >= 0) {
log_v("Stopping ssl client");
stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key);
} else {
log_v("stop() not called because client is < 0");
}
if (sslclient->client != nullptr) {
if (sslclient->client >= 0) {
log_v("Stopping ssl client");
stop_ssl_socket(sslclient, _CA_cert, _cert, _private_key);
} else {
log_v("stop() not called because client is nullptr");
log_v("stop() not called because client is < 0");
}
_connected = false;
_peek = -1;
} else {
log_v("stop() not called because client is nullptr");
}
_connected = false;
_peek = -1;
}

int SSLClient::connect(IPAddress ip, uint16_t port)
{
if (_pskIdent && _psKey)
return connect(ip, port, _pskIdent, _psKey);
return connect(ip, port, _CA_cert, _cert, _private_key);
/**
* @brief
*
* @param ip
* @param port
* @return int
*/
int SSLClient::connect(IPAddress ip, uint16_t port) {
if (_pskIdent && _psKey) {
log_v("connect with PSK");
return connect(ip, port, _pskIdent, _psKey);
}
log_v("connect with CA");
return connect(ip, port, _CA_cert, _cert, _private_key);
}

int SSLClient::connect(IPAddress ip, uint16_t port, int32_t timeout){
_timeout = timeout;
return connect(ip, port);
int SSLClient::connect(IPAddress ip, uint16_t port, int32_t timeout) {
_timeout = timeout;
return connect(ip, port);
}

int SSLClient::connect(const char *host, uint16_t port)
{
if (_pskIdent && _psKey)
return connect(host, port, _pskIdent, _psKey);
return connect(host, port, _CA_cert, _cert, _private_key);
int SSLClient::connect(const char *host, uint16_t port) {
if (_pskIdent && _psKey) {
log_v("connect with PSK");
return connect(host, port, _pskIdent, _psKey);
}
log_v("connect with CA");
return connect(host, port, _CA_cert, _cert, _private_key);
}

int SSLClient::connect(const char *host, uint16_t port, int32_t timeout){
int SSLClient::connect(const char *host, uint16_t port, int32_t timeout) {
_timeout = timeout;
return connect(host, port);
}
Expand Down Expand Up @@ -178,48 +196,74 @@ size_t SSLClient::write(const uint8_t *buf, size_t size)
return res;
}

int SSLClient::read(uint8_t *buf, size_t size)
{
int peeked = 0;
int avail = available();
if ((!buf && size) || avail <= 0) {
return -1;
}
if(!size){
return 0;
}
if(_peek >= 0){
buf[0] = _peek;
_peek = -1;
size--;
avail--;
if(!size || !avail){
return 1;
}
buf++;
peeked = 1;
}
/**
* \brief Reads data from the sslclient. If there is a byte peeked, it returns that byte.
*
* \param buf Buffer to read into.
* \param size Size of the buffer.
* \return int 1 if a byte has been peeked and the client is not connected.
* \return int < 1 if client is connected and there is an error from get_ssl_receive().
* \return int > 1 if res + peeked.
*/
int SSLClient::read(uint8_t *buf, size_t size) {
log_v("This is the iClient->read() implementation");
int peeked = 0;
int avail = available();

if ((!buf && size) || avail <= 0) {
return -1; // return error if no buffer or nothing to read.
}

int res = get_ssl_receive(sslclient, buf, size);
if (res < 0) {
stop();
return peeked?peeked:res;
if (!size) {
return 0; // return 0 if no bytes requested.
}

if (_peek >= 0) {
buf[0] = _peek; // Places this peeked byte at the start of the buffer.
_peek = -1; // Resets _peek to -1 to indicate no bytes are currently peeked.
size--; // Decreases the available size (size) by 1.
avail--; // Decreases the available bytes (avail) by 1.
if (!size || !avail) { // If there's no space left in the buffer (size) or no data left to read (avail)
return 1; // Return 1 to indicate one byte has been read.
}
return res + peeked;
buf++; // Increment the buffer pointer.
peeked = 1; // set peeked to 1 to indicate one byte has been read from the peeked value.
}

int res = get_ssl_receive(sslclient, buf, size);

if (res < 0) {
stop();
return peeked?peeked:res; // If peeked is true return peeked, otherwise return res, i.e. data_to_read error.
}

return res + peeked; // Return the number of bytes read + the number of bytes peeked.
}

int SSLClient::available()
{
int peeked = (_peek >= 0);
if (!_connected) {
return peeked;
}
int res = data_to_read(sslclient);
if (res < 0) {
stop();
return peeked?peeked:res;
}
return res+peeked;
/**
* \brief Returns how many bytes of data are available to be read from the sslclient.
* It takes into account both directly readable bytes and a potentially "peeked" byte.
* If there's an error or the client is not connected, it handles these scenarios appropriately.
*
* \return int 1 if a byte has been peeked and the client is not connected.
* \return int < 1 if client is connected and there is an error from data_to_read().
* \return int > 1 if res + peeked.
*/
int SSLClient::available() {
int peeked = (_peek >= 0); // 1 if a byte has been peeked (available to read without advancing the read pointer)

if (!_connected) {
return peeked;
}

int res = data_to_read(sslclient); // how many bytes available to read.

if (res < 0) {
stop();
return peeked?peeked:res; // If peeked is true return peeked, otherwise return res, i.e. data_to_read error.
}

return res+peeked;
}

uint8_t SSLClient::connected()
Expand Down
109 changes: 53 additions & 56 deletions src/SSLClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,78 +25,75 @@
class SSLClient : public Client
{
protected:
sslclient_context *sslclient;
sslclient_context *sslclient;

int _lastError = 0;
int _lastError = 0;
int _peek = -1;
int _timeout = 0;
const char *_CA_cert;
const char *_cert;
const char *_private_key;
const char *_pskIdent; // identity for PSK cipher suites
const char *_psKey; // key in hex for PSK cipher suites
int _timeout = 0;
const char *_CA_cert;
const char *_cert;
const char *_private_key;
const char *_pskIdent; // identity for PSK cipher suites
const char *_psKey; // key in hex for PSK cipher suites

bool _connected = false;
bool _connected = false;

Client* _client = nullptr;
Client* _client = nullptr;

public:
SSLClient();
SSLClient(Client* client);
~SSLClient();
SSLClient();
SSLClient(Client* client);
~SSLClient();

int connect(IPAddress ip, uint16_t port);
int connect(IPAddress ip, uint16_t port, int32_t timeout);
int connect(const char *host, uint16_t port);
int connect(const char *host, uint16_t port, int32_t timeout);
int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey);
int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey);
int connect(IPAddress ip, uint16_t port);
int connect(IPAddress ip, uint16_t port, int32_t timeout);
int connect(const char *host, uint16_t port);
int connect(const char *host, uint16_t port, int32_t timeout);
int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey);
int connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey);

int peek();
size_t write(uint8_t data);
size_t write(const uint8_t *buf, size_t size);
int available();
int read();
int read(uint8_t *buf, size_t size);
void flush() {}
void stop();
uint8_t connected();
int lastError(char *buf, const size_t size);
size_t write(uint8_t data);
size_t write(const uint8_t *buf, size_t size);
int available();
int read();
int read(uint8_t *buf, size_t size);
void flush() {}
void stop();
uint8_t connected();
int lastError(char *buf, const size_t size);

void setPreSharedKey(const char *pskIdent, const char *psKey); // psKey in Hex
void setCACert(const char *rootCA);
void setCertificate(const char *client_ca);
void setPrivateKey (const char *private_key);
bool loadCACert(Stream& stream, size_t size);
bool loadCertificate(Stream& stream, size_t size);
bool loadPrivateKey(Stream& stream, size_t size);
bool verify(const char* fingerprint, const char* domain_name);
void setHandshakeTimeout(unsigned long handshake_timeout);
void setClient(Client* client);
void setPreSharedKey(const char *pskIdent, const char *psKey); // psKey in Hex
void setCACert(const char *rootCA);
void setCertificate(const char *client_ca);
void setPrivateKey (const char *private_key);
bool loadCACert(Stream& stream, size_t size);
bool loadCertificate(Stream& stream, size_t size);
bool loadPrivateKey(Stream& stream, size_t size);
bool verify(const char* fingerprint, const char* domain_name);
void setHandshakeTimeout(unsigned long handshake_timeout);
void setClient(Client* client);
int setTimeout(uint32_t seconds){ return 0; }

operator bool() {
return connected();
}

int setTimeout(uint32_t seconds){ return 0; }
bool operator==(const bool value) {
return bool() == value;
}

operator bool()
{
return connected();
}
bool operator==(const bool value)
{
return bool() == value;
}
bool operator!=(const bool value)
{
return bool() != value;
}
bool operator!=(const bool value) {
return bool() != value;
}

private:
char *_streamLoad(Stream& stream, size_t size);
char *_streamLoad(Stream& stream, size_t size);

//friend class GprsServer;
using Print::write;
//friend class GprsServer;
using Print::write;
};

#endif /* SSLClient_H */
Loading

0 comments on commit 482fcd2

Please sign in to comment.