diff --git a/src/outputmessage.h b/src/outputmessage.h index 402a4dc6f0..c58d5ed671 100644 --- a/src/outputmessage.h +++ b/src/outputmessage.h @@ -21,12 +21,12 @@ class OutputMessage : public NetworkMessage void writeMessageLength() { add_header(info.length); } - void addCryptoHeader(checksumMode_t mode, uint32_t& sequence) + void addCryptoHeader(checksumMode_t mode) { if (mode == CHECKSUM_ADLER) { add_header(adlerChecksum(&buffer[outputBufferStart], info.length)); } else if (mode == CHECKSUM_SEQUENCE) { - add_header(sequence++); + add_header(getSequenceId()); } writeMessageLength(); @@ -48,6 +48,13 @@ class OutputMessage : public NetworkMessage info.position += msgLen; } + void setSequenceId(uint32_t sequence) { + sequenceId = sequence; + } + uint32_t getSequenceId() const { + return sequenceId; + } + private: template void add_header(T add) @@ -60,6 +67,7 @@ class OutputMessage : public NetworkMessage } MsgSize_t outputBufferStart = INITIAL_BUFFER_POSITION; + uint32_t sequenceId; }; namespace tfs::net { diff --git a/src/protocol.cpp b/src/protocol.cpp index 71b7f6c554..b06ff5061b 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -46,11 +46,20 @@ bool XTEA_decrypt(NetworkMessage& msg, const xtea::round_keys& key) void Protocol::onSendMessage(const OutputMessage_ptr& msg) { if (!rawMessages) { + if (encryptionEnabled && checksumMode == CHECKSUM_SEQUENCE) { + uint32_t compressionChecksum = 0; + if (msg->getLength() >= 128 && deflateMessage(*msg)) { + compressionChecksum = 0x80000000; + } + + msg->setSequenceId(compressionChecksum | getNextSequenceId()); + } + msg->writeMessageLength(); if (encryptionEnabled) { XTEA_encrypt(*msg, key); - msg->addCryptoHeader(checksumMode, sequenceNumber); + msg->addCryptoHeader(checksumMode); } } } @@ -86,6 +95,34 @@ bool Protocol::RSA_decrypt(NetworkMessage& msg) return msg.getByte() == 0; } +bool Protocol::deflateMessage(OutputMessage& msg) +{ + static thread_local std::vector buffer(NETWORKMESSAGE_MAXSIZE); + zstream.next_in = msg.getOutputBuffer(); + zstream.avail_in = msg.getLength(); + zstream.next_out = buffer.data(); + zstream.avail_out = buffer.size(); + + const auto result = deflate(&zstream, Z_FINISH); + if (result != Z_OK && result != Z_STREAM_END) { + std::cout << "Error while deflating packet data error: " << (zstream.msg ? zstream.msg : "unknown") + << std::endl; + return false; + } + + const auto size = zstream.total_out; + if (size <= 0) { + std::cout << "Deflated packet data had invalid size: " << size + << " error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl; + return false; + } + + msg.reset(); + msg.addBytes(reinterpret_cast(buffer.data()), size); + + return true; +} + Connection::Address Protocol::getIP() const { if (auto connection = getConnection()) { diff --git a/src/protocol.h b/src/protocol.h index a6ea13dddf..929c5142d3 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -4,13 +4,19 @@ #ifndef FS_PROTOCOL_H #define FS_PROTOCOL_H +#include + #include "connection.h" #include "xtea.h" class Protocol : public std::enable_shared_from_this { public: - explicit Protocol(Connection_ptr connection) : connection(connection) {} + explicit Protocol(Connection_ptr connection) : connection(connection) { + if (deflateInit2(&zstream, 6, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY) != Z_OK) { + std::cout << "ZLIB initialization error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl; + } + } virtual ~Protocol() = default; // non-copyable @@ -42,6 +48,15 @@ class Protocol : public std::enable_shared_from_this } } + uint32_t getNextSequenceId() { + const auto sequence = ++sequenceNumber; + if (sequenceNumber >= std::numeric_limits::max()) { + sequenceNumber = 0; + } + + return sequence; + } + protected: static constexpr size_t RSA_BUFFER_LENGTH = 128; @@ -57,6 +72,8 @@ class Protocol : public std::enable_shared_from_this static bool RSA_decrypt(NetworkMessage& msg); + bool deflateMessage(OutputMessage& msg); + void setRawMessages(bool value) { rawMessages = value; } virtual void release() {} @@ -72,6 +89,8 @@ class Protocol : public std::enable_shared_from_this bool encryptionEnabled = false; checksumMode_t checksumMode = CHECKSUM_ADLER; bool rawMessages = false; + + z_stream zstream{}; }; #endif // FS_PROTOCOL_H