Skip to content

Commit

Permalink
Add packet compression
Browse files Browse the repository at this point in the history
  • Loading branch information
nekiro committed Dec 9, 2024
1 parent 7573e41 commit 2127690
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/outputmessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class OutputMessage : public NetworkMessage

void writeMessageLength() { add_header(info.length); }

void addCryptoHeader(checksumMode_t mode, uint32_t& sequence)
void addCryptoHeader(checksumMode_t mode)
{
if (mode == CHECKSUM_ADLER) {
add_header(adlerChecksum(&buffer[outputBufferStart], info.length));
} else if (mode == CHECKSUM_SEQUENCE) {
add_header(sequence++);
add_header(getSequenceId());
}

writeMessageLength();
Expand All @@ -48,6 +48,13 @@ class OutputMessage : public NetworkMessage
info.position += msgLen;
}

void setSequenceId(uint32_t sequence) {
sequenceId = sequence;
}
uint32_t getSequenceId() const {
return sequenceId;
}

private:
template <typename T>
void add_header(T add)
Expand All @@ -60,6 +67,7 @@ class OutputMessage : public NetworkMessage
}

MsgSize_t outputBufferStart = INITIAL_BUFFER_POSITION;
uint32_t sequenceId;
};

namespace tfs::net {
Expand Down
39 changes: 38 additions & 1 deletion src/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,20 @@ bool XTEA_decrypt(NetworkMessage& msg, const xtea::round_keys& key)
void Protocol::onSendMessage(const OutputMessage_ptr& msg)
{
if (!rawMessages) {
if (encryptionEnabled && checksumMode == CHECKSUM_SEQUENCE) {
uint32_t compressionChecksum = 0;
if (msg->getLength() >= 128 && deflateMessage(*msg)) {
compressionChecksum = 0x80000000;
}

msg->setSequenceId(compressionChecksum | getNextSequenceId());
}

msg->writeMessageLength();

if (encryptionEnabled) {
XTEA_encrypt(*msg, key);
msg->addCryptoHeader(checksumMode, sequenceNumber);
msg->addCryptoHeader(checksumMode);
}
}
}
Expand Down Expand Up @@ -86,6 +95,34 @@ bool Protocol::RSA_decrypt(NetworkMessage& msg)
return msg.getByte() == 0;
}

bool Protocol::deflateMessage(OutputMessage& msg)
{
static thread_local std::vector<uint8_t> buffer(NETWORKMESSAGE_MAXSIZE);
zstream.next_in = msg.getOutputBuffer();
zstream.avail_in = msg.getLength();
zstream.next_out = buffer.data();
zstream.avail_out = buffer.size();

const auto result = deflate(&zstream, Z_FINISH);
if (result != Z_OK && result != Z_STREAM_END) {
std::cout << "Error while deflating packet data error: " << (zstream.msg ? zstream.msg : "unknown")
<< std::endl;
return false;
}

const auto size = zstream.total_out;
if (size <= 0) {
std::cout << "Deflated packet data had invalid size: " << size
<< " error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl;
return false;
}

msg.reset();
msg.addBytes(reinterpret_cast<const char*>(buffer.data()), size);

return true;
}

Connection::Address Protocol::getIP() const
{
if (auto connection = getConnection()) {
Expand Down
21 changes: 20 additions & 1 deletion src/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
#ifndef FS_PROTOCOL_H
#define FS_PROTOCOL_H

#include <zlib.h>

#include "connection.h"
#include "xtea.h"

class Protocol : public std::enable_shared_from_this<Protocol>
{
public:
explicit Protocol(Connection_ptr connection) : connection(connection) {}
explicit Protocol(Connection_ptr connection) : connection(connection) {
if (deflateInit2(&zstream, 6, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY) != Z_OK) {
std::cout << "ZLIB initialization error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl;
}
}
virtual ~Protocol() = default;

// non-copyable
Expand Down Expand Up @@ -42,6 +48,15 @@ class Protocol : public std::enable_shared_from_this<Protocol>
}
}

uint32_t getNextSequenceId() {
const auto sequence = ++sequenceNumber;
if (sequenceNumber >= std::numeric_limits<int32_t>::max()) {
sequenceNumber = 0;
}

return sequence;
}

protected:
static constexpr size_t RSA_BUFFER_LENGTH = 128;

Expand All @@ -57,6 +72,8 @@ class Protocol : public std::enable_shared_from_this<Protocol>

static bool RSA_decrypt(NetworkMessage& msg);

bool deflateMessage(OutputMessage& msg);

void setRawMessages(bool value) { rawMessages = value; }

virtual void release() {}
Expand All @@ -72,6 +89,8 @@ class Protocol : public std::enable_shared_from_this<Protocol>
bool encryptionEnabled = false;
checksumMode_t checksumMode = CHECKSUM_ADLER;
bool rawMessages = false;

z_stream zstream{};
};

#endif // FS_PROTOCOL_H

0 comments on commit 2127690

Please sign in to comment.