diff --git a/src/network/database_connector.cpp b/src/network/database_connector.cpp new file mode 100644 index 00000000000..30155495b05 --- /dev/null +++ b/src/network/database_connector.cpp @@ -0,0 +1,987 @@ +// +// SuperTuxKart - a fun racing game with go-kart +// Copyright (C) 2024 SuperTuxKart-Team +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of the GNU General Public License +// as published by the Free Software Foundation; either version 3 +// of the License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + +#ifdef ENABLE_SQLITE3 + +#include "network/database_connector.hpp" + +#include "network/network_player_profile.hpp" +#include "network/server_config.hpp" +#include "network/socket_address.hpp" +#include "network/stk_host.hpp" +#include "network/stk_ipv6.hpp" +#include "network/stk_peer.hpp" +#include "utils/log.hpp" + +//----------------------------------------------------------------------------- +/** Prints "?" to the output stream and saves the Binder object to the + * corresponding BinderCollection so that it can produce bind function later + * When we invoke StringUtils::insertValues with a Binder argument, the + * implementation of insertValues ensures that this function is invoked for + * all Binder arguments from left to right. + */ +std::ostream& operator << (std::ostream& os, const Binder& binder) +{ + os << "?"; + binder.m_collection.lock()->m_binders.emplace_back(std::make_shared(binder)); + return os; +} // operator << (Binder) + +//----------------------------------------------------------------------------- +/** Returns a bind function that should be used inside an easySQLQuery. As the + * Binder objects are already ordered in a correct way, the indices just go + * from 1 upwards. Depending on a particular Binder, we can also bind NULL + * instead of a string. + */ +std::function BinderCollection::getBindFunction() const +{ + auto binders = m_binders; + return [binders](sqlite3_stmt* stmt) + { + int idx = 1; + for (std::shared_ptr binder: binders) + { + if (binder) + { + // SQLITE_TRANSIENT to copy string + if (binder->m_use_null_if_empty && binder->m_value.empty()) + { + if (sqlite3_bind_null(stmt, idx) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind NULL for %s.", + binder->m_name.c_str()); + } + } + else + { + if (sqlite3_bind_text(stmt, idx, binder->m_value.c_str(), + -1, SQLITE_TRANSIENT) != SQLITE_OK) + { + Log::error("easySQLQuery", "Failed to bind %s as %s.", + binder->m_value.c_str(), binder->m_name.c_str()); + } + } + } + ++idx; + } + }; +} // BinderCollection::getBindFunction + +//----------------------------------------------------------------------------- +/** Opens the database, sets its busy handler and variables related to it. */ +void DatabaseConnector::initDatabase() +{ + m_last_poll_db_time = StkTime::getMonoTimeMs(); + m_db = NULL; + m_ip_ban_table_exists = false; + m_ipv6_ban_table_exists = false; + m_online_id_ban_table_exists = false; + m_ip_geolocation_table_exists = false; + m_ipv6_geolocation_table_exists = false; + m_player_reports_table_exists = false; + if (!ServerConfig::m_sql_management) + return; + const std::string& path = ServerConfig::getConfigDirectory() + "/" + + ServerConfig::m_database_file.c_str(); + int ret = sqlite3_open_v2(path.c_str(), &m_db, + SQLITE_OPEN_SHAREDCACHE | SQLITE_OPEN_FULLMUTEX | + SQLITE_OPEN_READWRITE, NULL); + if (ret != SQLITE_OK) + { + Log::error("ServerLobby", "Cannot open database: %s.", + sqlite3_errmsg(m_db)); + sqlite3_close(m_db); + m_db = NULL; + return; + } + sqlite3_busy_handler(m_db, [](void* data, int retry) + { + int retry_count = ServerConfig::m_database_timeout / 100; + if (retry < retry_count) + { + sqlite3_sleep(100); + // Return non-zero to let caller retry again + return 1; + } + // Return zero to let caller return SQLITE_BUSY immediately + return 0; + }, NULL); + sqlite3_create_function(m_db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, + &insideIPv6CIDRSQL, NULL, NULL); + sqlite3_create_function(m_db, "upperIPv6", 1, SQLITE_UTF8, NULL, + &upperIPv6SQL, NULL, NULL); + checkTableExists(ServerConfig::m_ip_ban_table, m_ip_ban_table_exists); + checkTableExists(ServerConfig::m_ipv6_ban_table, m_ipv6_ban_table_exists); + checkTableExists(ServerConfig::m_online_id_ban_table, + m_online_id_ban_table_exists); + checkTableExists(ServerConfig::m_player_reports_table, + m_player_reports_table_exists); + checkTableExists(ServerConfig::m_ip_geolocation_table, + m_ip_geolocation_table_exists); + checkTableExists(ServerConfig::m_ipv6_geolocation_table, + m_ipv6_geolocation_table_exists); +} // initDatabase + +//----------------------------------------------------------------------------- +/** Closes the database. */ +void DatabaseConnector::destroyDatabase() +{ + auto peers = STKHost::get()->getPeers(); + for (auto& peer : peers) + writeDisconnectInfoTable(peer.get()); + if (m_db != NULL) + sqlite3_close(m_db); +} // destroyDatabase + +//----------------------------------------------------------------------------- +/** Runs simple query with optional bind function. If output vector pointer is + * not (default) nullptr, then the output is written there. + * \param query The SQL query with '?'-placeholders for values to bind. + * \param output The 2D vector for output rows. If nullptr, the query output + * is ignored. + * \param bind_function The function for binding missing values. + * \return True if no error occurs. + */ +bool DatabaseConnector::easySQLQuery( + const std::string& query, std::vector>* output, + std::function bind_function, + std::string null_value) const +{ + if (!m_db) + return false; + sqlite3_stmt* stmt = NULL; + int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); + if (ret == SQLITE_OK) + { + if (bind_function) + bind_function(stmt); + ret = sqlite3_step(stmt); + if (output) + { + output->clear(); + while (ret == SQLITE_ROW) + { + output->emplace_back(); + int columns = sqlite3_column_count(stmt); + for (int i = 0; i < columns; ++i) + { + const char* value = (char*)sqlite3_column_text(stmt, i); + if (value == nullptr) + output->back().push_back(null_value); + else + output->back().push_back(std::string(value)); + } + ret = sqlite3_step(stmt); + } + } + ret = sqlite3_finalize(stmt); + if (ret != SQLITE_OK) + { + Log::error("DatabaseConnector", + "Error finalize database for easy query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + return false; + } + } + else + { + Log::error("DatabaseConnector", + "Error preparing database for easy query %s: %s", + query.c_str(), sqlite3_errmsg(m_db)); + return false; + } + return true; +} // easySQLQuery + +//----------------------------------------------------------------------------- +/** Performs a query to determine if a certain table exists. + * \param table The searched name. + * \param result The output value. + */ +void DatabaseConnector::checkTableExists(const std::string& table, bool& result) +{ + if (!m_db) + return; + result = false; + if (!table.empty()) + { + std::string query = StringUtils::insertValues( + "SELECT count(type) FROM sqlite_master " + "WHERE type='table' AND name='%s';", table.c_str()); + + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) + { + int number; + if (StringUtils::fromString(output[0][0], number) && number == 1) + { + Log::info("DatabaseConnector", "Table named %s will be used.", + table.c_str()); + result = true; + } + } + } + if (!result && !table.empty()) + { + Log::warn("DatabaseConnector", "Table named %s not found in database.", + table.c_str()); + } +} // checkTableExists + +//----------------------------------------------------------------------------- +/** Queries the database's IP mapping to determine the country code for an + * address. + * \param addr Queried address. + * \return A country code string if the address is found in the mapping, + * and an empty string otherwise. + */ +std::string DatabaseConnector::ip2Country(const SocketAddress& addr) const +{ + if (!m_db || !m_ip_geolocation_table_exists || addr.isLAN()) + return ""; + + std::string cc_code; + std::string query = StringUtils::insertValues( + "SELECT country_code FROM %s " + "WHERE `ip_start` <= %d AND `ip_end` >= %d " + "ORDER BY `ip_start` DESC LIMIT 1;", + ServerConfig::m_ip_geolocation_table.c_str(), addr.getIP(), + addr.getIP()); + + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) + { + cc_code = output[0][0]; + } + return cc_code; +} // ip2Country + +//----------------------------------------------------------------------------- +/** Queries the database's IPv6 mapping to determine the country code for an + * address. + * \param addr Queried address. + * \return A country code string if the address is found in the mapping, + * and an empty string otherwise. + */ +std::string DatabaseConnector::ipv62Country(const SocketAddress& addr) const +{ + if (!m_db || !m_ipv6_geolocation_table_exists) + return ""; + + std::string cc_code; + const std::string& ipv6 = addr.toString(false/*show_port*/); + std::string query = StringUtils::insertValues( + "SELECT country_code FROM %s " + "WHERE `ip_start` <= upperIPv6(\"%s\") AND `ip_end` >= upperIPv6(\"%s\") " + "ORDER BY `ip_start` DESC LIMIT 1;", + ServerConfig::m_ipv6_geolocation_table.c_str(), ipv6.c_str(), + ipv6.c_str()); + + std::vector> output; + if (easySQLQuery(query, &output) && !output.empty()) + { + cc_code = output[0][0]; + } + return cc_code; +} // ipv62Country + +// ---------------------------------------------------------------------------- +/** A function invoked within SQLite */ +void DatabaseConnector::upperIPv6SQL(sqlite3_context* context, int argc, + sqlite3_value** argv) +{ + if (argc != 1) + { + sqlite3_result_int64(context, 0); + return; + } + + char* ipv6 = (char*)sqlite3_value_text(argv[0]); + if (ipv6 == NULL) + { + sqlite3_result_int64(context, 0); + return; + } + sqlite3_result_int64(context, upperIPv6(ipv6)); +} + +// ---------------------------------------------------------------------------- +/** A function that checks within SQLite whether an IPv6 address (argv[1]) + * is located within a specified block (argv[0]) of IPv6 addresses. + */ +void DatabaseConnector::insideIPv6CIDRSQL(sqlite3_context* context, int argc, + sqlite3_value** argv) +{ + if (argc != 2) + { + sqlite3_result_int(context, 0); + return; + } + + char* ipv6_cidr = (char*)sqlite3_value_text(argv[0]); + char* ipv6_in = (char*)sqlite3_value_text(argv[1]); + if (ipv6_cidr == NULL || ipv6_in == NULL) + { + sqlite3_result_int(context, 0); + return; + } + sqlite3_result_int(context, insideIPv6CIDR(ipv6_cidr, ipv6_in)); +} // insideIPv6CIDRSQL + +// ---------------------------------------------------------------------------- +/* +Copy below code so it can be use as loadable extension to be used in sqlite3 +command interface (together with andIPv6 and insideIPv6CIDR from stk_ipv6) + +#include "sqlite3ext.h" +SQLITE_EXTENSION_INIT1 +// ---------------------------------------------------------------------------- +sqlite3_extension_init(sqlite3* db, char** pzErrMsg, + const sqlite3_api_routines* pApi) +{ + SQLITE_EXTENSION_INIT2(pApi) + sqlite3_create_function(db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, + insideIPv6CIDRSQL, NULL, NULL); + sqlite3_create_function(db, "upperIPv6", 1, SQLITE_UTF8, 0, upperIPv6SQL, + 0, 0); + return 0; +} // sqlite3_extension_init +*/ + +//----------------------------------------------------------------------------- +/** When a peer disconnects from the server, this function saves to the + * database peer's disconnection time and statistics (ping and packet loss). + * \param peer Disconnecting peer. + */ +void DatabaseConnector::writeDisconnectInfoTable(STKPeer* peer) +{ + if (m_server_stats_table.empty()) + return; + std::string query = StringUtils::insertValues( + "UPDATE %s SET disconnected_time = datetime('now'), " + "ping = %d, packet_loss = %d " + "WHERE host_id = %u;", m_server_stats_table.c_str(), + peer->getAveragePing(), peer->getPacketLoss(), + peer->getHostId()); + easySQLQuery(query); +} // writeDisconnectInfoTable + +//----------------------------------------------------------------------------- +/** Creates necessary tables and views if they don't exist yet in the database. + * As the function is invoked during the server launch, it also updates rows + * related to players whose disconnection time wasn't written, and loads + * last used host id. + */ +void DatabaseConnector::initServerStatsTable() +{ + if (!ServerConfig::m_sql_management || !m_db) + return; + std::string table_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_stats"; + + std::ostringstream oss; + oss << "CREATE TABLE IF NOT EXISTS " << table_name << " (\n" + " host_id INTEGER UNSIGNED NOT NULL PRIMARY KEY, -- Unique host id in STKHost of each connection session for a STKPeer\n" + " ip INTEGER UNSIGNED NOT NULL, -- IP decimal of host\n"; + if (ServerConfig::m_ipv6_connection) + oss << " ipv6 TEXT NOT NULL DEFAULT '', -- IPv6 (if exists) in string of host\n"; + oss << " port INTEGER UNSIGNED NOT NULL, -- Port of host\n" + " online_id INTEGER UNSIGNED NOT NULL, -- Online if of the host (0 for offline account)\n" + " username TEXT NOT NULL, -- First player name in the host (if the host has splitscreen player)\n" + " player_num INTEGER UNSIGNED NOT NULL, -- Number of player(s) from the host, more than 1 if it has splitscreen player\n" + " country_code TEXT NULL DEFAULT NULL, -- 2-letter country code of the host\n" + " version TEXT NOT NULL, -- SuperTuxKart version of the host\n" + " os TEXT NOT NULL, -- Operating system of the host\n" + " connected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when connected\n" + " disconnected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when disconnected (saved when disconnected)\n" + " ping INTEGER UNSIGNED NOT NULL DEFAULT 0, -- Ping of the host\n" + " packet_loss INTEGER NOT NULL DEFAULT 0 -- Mean packet loss count from ENet (saved when disconnected)\n" + ") WITHOUT ROWID;"; + std::string query = oss.str(); + + if (easySQLQuery(query)) + m_server_stats_table = table_name; + + if (m_server_stats_table.empty()) + return; + + // Extra default table _countries: + // Server owner need to initialise this table himself, check NETWORKING.md + std::string country_table_name = std::string("v") + StringUtils::toString( + ServerConfig::m_server_db_version) + "_countries"; + query = StringUtils::insertValues( + "CREATE TABLE IF NOT EXISTS %s (\n" + " country_code TEXT NOT NULL PRIMARY KEY UNIQUE, -- Unique 2-letter country code\n" + " country_flag TEXT NOT NULL, -- Unicode country flag representation of 2-letter country code\n" + " country_name TEXT NOT NULL -- Readable name of this country\n" + ") WITHOUT ROWID;", country_table_name.c_str()); + easySQLQuery(query); + + // Default views: + // _full_stats + // Full stats with ip in human readable format and time played of each + // players in minutes + std::string full_stats_view_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_full_stats"; + oss.str(""); + oss << "CREATE VIEW IF NOT EXISTS " << full_stats_view_name << " AS\n" + << " SELECT host_id, ip,\n" + << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; + if (ServerConfig::m_ipv6_connection) + oss << " ipv6,"; + oss << " port, online_id, username, player_num,\n" + << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" + << " ROUND((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" + << " connected_time, disconnected_time, ping, packet_loss FROM " << m_server_stats_table << "\n" + << " LEFT JOIN " << country_table_name << " ON " + << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" + << " ORDER BY connected_time DESC;"; + query = oss.str(); + easySQLQuery(query); + + // _current_players + // Current players in server with ip in human readable format and time + // played of each players in minutes + std::string current_players_view_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_current_players"; + oss.str(""); + oss.clear(); + oss << "CREATE VIEW IF NOT EXISTS " << current_players_view_name << " AS\n" + << " SELECT host_id, ip,\n" + << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; + if (ServerConfig::m_ipv6_connection) + oss << " ipv6,"; + oss << " port, online_id, username, player_num,\n" + << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" + << " ROUND((STRFTIME(\"%s\", 'now') - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" + << " connected_time, ping FROM " << m_server_stats_table << "\n" + << " LEFT JOIN " << country_table_name << " ON " + << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" + << " WHERE connected_time = disconnected_time;"; + query = oss.str(); + easySQLQuery(query); + + // _player_stats + // All players with online id and username with their time played stats + // in this server since creation of this database + // If sqlite supports window functions (since 3.25), it will include last session player info (ip, country, ping...) + std::string player_stats_view_name = std::string("v") + + StringUtils::toString(ServerConfig::m_server_db_version) + "_" + + ServerConfig::m_server_uid + "_player_stats"; + oss.str(""); + oss.clear(); + if (sqlite3_libversion_number() < 3025000) + { + oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" + << " SELECT online_id, username, COUNT(online_id) AS num_connections,\n" + << " MIN(connected_time) AS first_connected_time,\n" + << " MAX(connected_time) AS last_connected_time,\n" + << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" + << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" + << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" + << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" + << " FROM " << m_server_stats_table << "\n" + << " WHERE online_id != 0 GROUP BY online_id ORDER BY num_connections DESC;"; + } + else + { + oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" + << " SELECT a.online_id, a.username, a.ip, a.ip_readable,\n"; + if (ServerConfig::m_ipv6_connection) + oss << " a.ipv6,"; + oss << " a.port, a.player_num,\n" + << " a.country_code, a.country_flag, a.country_name, a.version, a.os, a.ping, a.packet_loss,\n" + << " b.num_connections, b.first_connected_time, b.first_disconnected_time,\n" + << " a.connected_time AS last_connected_time, a.disconnected_time AS last_disconnected_time,\n" + << " a.time_played AS last_time_played, b.total_time_played, b.average_time_played,\n" + << " b.min_time_played, b.max_time_played\n" + << " FROM\n" + << " (\n" + << " SELECT *,\n" + << " ROW_NUMBER() OVER\n" + << " (\n" + << " PARTITION BY online_id\n" + << " ORDER BY connected_time DESC\n" + << " ) RowNum\n" + << " FROM " << full_stats_view_name << " where online_id != 0\n" + << " ) as a\n" + << " JOIN\n" + << " (\n" + << " SELECT online_id, COUNT(online_id) AS num_connections,\n" + << " MIN(connected_time) AS first_connected_time,\n" + << " MIN(disconnected_time) AS first_disconnected_time,\n" + << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" + << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" + << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" + << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" + << " FROM " << m_server_stats_table << " WHERE online_id != 0 GROUP BY online_id\n" + << " ) AS b\n" + << " ON b.online_id = a.online_id\n" + << " WHERE RowNum = 1 ORDER BY num_connections DESC;\n"; + } + query = oss.str(); + easySQLQuery(query); + + uint32_t last_host_id = 0; + query = StringUtils::insertValues("SELECT MAX(host_id) FROM %s;", + m_server_stats_table.c_str()); + + std::vector> output; + if (easySQLQuery(query, &output)) + { + if (!output.empty() && !output[0].empty() + && StringUtils::fromString(output[0][0], last_host_id)) + { + Log::info("DatabaseConnector", "%u was last server session max host id.", + last_host_id); + } + } + else + { + m_server_stats_table = ""; + } + + STKHost::get()->setNextHostId(last_host_id); + + // Update disconnected time (if stk crashed it will not be written) + query = StringUtils::insertValues( + "UPDATE %s SET disconnected_time = datetime('now') " + "WHERE connected_time = disconnected_time;", + m_server_stats_table.c_str()); + easySQLQuery(query); +} // initServerStatsTable + +//----------------------------------------------------------------------------- +/** Writes a report of one player about another player. + * \param reporter Peer that sends the report. + * \param reporter_npp Player profile that sends the report. + * \param reporting Peer that is reported. + * \param reporting_npp Player profile that is reported. + * \param info The report message. + * \return True if the database query succeeded. + */ +bool DatabaseConnector::writeReport( + STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info) +{ + std::string query; + + std::shared_ptr coll = std::make_shared(); + if (ServerConfig::m_ipv6_connection) + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(server_uid, reporter_ip, reporter_ipv6, reporter_online_id, reporter_username, " + "info, reporting_ip, reporting_ipv6, reporting_online_id, reporting_username) " + "VALUES (%s, %u, \"%s\", %u, %s, %s, %u, \"%s\", %u, %s);", + ServerConfig::m_player_reports_table.c_str(), + Binder(coll, ServerConfig::m_server_uid, "server_uid"), + !reporter->getAddress().isIPv6() ? reporter->getAddress().getIP() : 0, + reporter->getAddress().isIPv6() ? reporter->getAddress().toString(false) : "", + reporter_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"), + Binder(coll, StringUtils::wideToUtf8(info), "info"), + !reporting->getAddress().isIPv6() ? reporting->getAddress().getIP() : 0, + reporting->getAddress().isIPv6() ? reporting->getAddress().toString(false) : "", + reporting_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name") + ); + } + else + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(server_uid, reporter_ip, reporter_online_id, reporter_username, " + "info, reporting_ip, reporting_online_id, reporting_username) " + "VALUES (%s, %u, %u, %s, %s, %u, %u, %s);", + ServerConfig::m_player_reports_table.c_str(), + Binder(coll, ServerConfig::m_server_uid, "server_uid"), + reporter->getAddress().getIP(), + reporter_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporter_npp->getName()), "reporter_name"), + Binder(coll, StringUtils::wideToUtf8(info), "info"), + reporting->getAddress().getIP(), + reporting_npp->getOnlineId(), + Binder(coll, StringUtils::wideToUtf8(reporting_npp->getName()), "reporting_name") + ); + } + return easySQLQuery(query, nullptr, coll->getBindFunction()); +} // writeReport + +//----------------------------------------------------------------------------- +/** Gets the rows from IPv4 ban table, either all of them (for polling + * purposes), or those describing a certain address (if only one peer has to + * be checked). + * \param ip The IP address to check the database for. If zero, all rows + * will be given. + * \return A vector of rows in the form of IpBanTableData structures. + */ +std::vector +DatabaseConnector::getIpBanTableData(uint32_t ip) const +{ + std::vector result; + if (!m_ip_ban_table_exists) + { + return result; + } + bool single_ip = (ip != 0); + std::ostringstream oss; + oss << "SELECT rowid, ip_start, ip_end, reason, description FROM "; + oss << (std::string)ServerConfig::m_ip_ban_table << " WHERE "; + if (single_ip) + oss << "ip_start <= " << ip << " AND ip_end >= " << ip << " AND "; + oss << "datetime('now') > datetime(starting_time) AND " + "(expired_days is NULL OR datetime" + "(starting_time, '+'||expired_days||' days') > datetime('now'))"; + if (single_ip) + oss << " LIMIT 1"; + oss << ";"; + std::string query = oss.str(); + + std::vector> output; + easySQLQuery(query, &output); + + for (std::vector& row: output) + { + IpBanTableData element; + if (!StringUtils::fromString(row[0], element.row_id)) + continue; + if (!StringUtils::fromString(row[1], element.ip_start)) + continue; + if (!StringUtils::fromString(row[2], element.ip_end)) + continue; + element.reason = row[3]; + element.description = row[4]; + result.push_back(element); + } + return result; +} // getIpBanTableData + +//----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by IPv4, this function increases + * the trigger count. + * \param ip_start Start of IP ban range corresponding to peer. + * \param ip_end End of IP ban range corresponding to peer. + */ +void DatabaseConnector::increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip_end) const +{ + std::string query = StringUtils::insertValues( + "UPDATE %s SET trigger_count = trigger_count + 1, " + "last_trigger = datetime('now') " + "WHERE ip_start = %u AND ip_end = %u;", + ServerConfig::m_ip_ban_table.c_str(), ip_start, ip_end); + easySQLQuery(query); +} // getIpBanTableData + +//----------------------------------------------------------------------------- +/** Gets the rows from IPv6 ban table, either all of them (for polling + * purposes), or those describing a certain address (if only one peer has to + * be checked). + * \param ip The IPv6 address to check the database for. If empty, all rows + * will be given. + * \return A vector of rows in the form of Ipv6BanTableData structures. + */ +std::vector +DatabaseConnector::getIpv6BanTableData(std::string ipv6) const +{ + std::vector result; + if (!m_ipv6_ban_table_exists) + { + return result; + } + bool single_ip = !ipv6.empty(); + std::string query; + std::shared_ptr coll = std::make_shared(); + + query = StringUtils::insertValues( + "SELECT rowid, ipv6_cidr, reason, description FROM %s WHERE ", + ServerConfig::m_ipv6_ban_table.c_str() + ); + if (single_ip) + query += StringUtils::insertValues( + "insideIPv6CIDR(ipv6_cidr, %s) = 1 AND ", + Binder(coll, ipv6, "ipv6") + ); + + query += "datetime('now') > datetime(starting_time) AND " + "(expired_days is NULL OR datetime" + "(starting_time, '+'||expired_days||' days') > datetime('now'))"; + + if (single_ip) + query += " LIMIT 1;"; + + std::vector> output; + easySQLQuery(query, &output, coll->getBindFunction()); + + for (std::vector& row: output) + { + Ipv6BanTableData element; + if (!StringUtils::fromString(row[0], element.row_id)) + continue; + element.ipv6_cidr = row[1]; + element.reason = row[2]; + element.description = row[3]; + result.push_back(element); + } + return result; +} // getIpv6BanTableData + +//----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by IPv6, this function increases + * the trigger count. + * \param ipv6_cidr Block of IPv6 addresses corresponding to the peer. + */ +void DatabaseConnector::increaseIpv6BanTriggerCount(const std::string& ipv6_cidr) const +{ + std::shared_ptr coll = std::make_shared(); + std::string query = StringUtils::insertValues( + "UPDATE %s SET trigger_count = trigger_count + 1, " + "last_trigger = datetime('now') " + "WHERE ipv6_cidr = %s;", + ServerConfig::m_ipv6_ban_table.c_str(), + Binder(coll, ipv6_cidr, "ipv6_cidr") + ); + easySQLQuery(query, nullptr, coll->getBindFunction()); +} // increaseIpv6BanTriggerCount + +//----------------------------------------------------------------------------- +/** Gets the rows from online id ban table, either all of them (for polling + * purposes), or those describing a certain online id (if only one peer has + * to be checked). + * \param online_id The online id to check the database for. If empty, all + * rows will be given. + * \return A vector of rows in the form of OnlineIdBanTableData structures. + */ +std::vector +DatabaseConnector::getOnlineIdBanTableData(uint32_t online_id) const +{ + std::vector result; + if (!m_online_id_ban_table_exists) + { + return result; + } + bool single_id = (online_id != 0); + std::ostringstream oss; + oss << "SELECT rowid, online_id, reason, description FROM "; + oss << (std::string)ServerConfig::m_online_id_ban_table; + oss << " WHERE "; + if (single_id) + oss << "online_id = " << online_id << " AND "; + oss << "datetime('now') > datetime(starting_time) AND " + "(expired_days is NULL OR datetime" + "(starting_time, '+'||expired_days||' days') > datetime('now'))"; + if (single_id) + oss << " LIMIT 1"; + oss << ";"; + std::string query = oss.str(); + sqlite3_exec(m_db, query.c_str(), + [](void* ptr, int count, char** data, char** columns) + { + std::vector* vec = (std::vector*)ptr; + OnlineIdBanTableData element; + if (!StringUtils::fromString(data[0], element.row_id)) + return 0; + if (!StringUtils::fromString(data[1], element.online_id)) + return 0; + element.reason = std::string(data[2]); + element.description = std::string(data[3]); + vec->push_back(element); + return 0; + }, &result, NULL); + return result; +} // getOnlineIdBanTableData + +//----------------------------------------------------------------------------- +/** For a peer that turned out to be banned by online id, this function + * increases the trigger count. + * \param online_id Online id of the peer. + */ +void DatabaseConnector::increaseOnlineIdBanTriggerCount(uint32_t online_id) const +{ + std::string query = StringUtils::insertValues( + "UPDATE %s SET trigger_count = trigger_count + 1, " + "last_trigger = datetime('now') " + "WHERE online_id = %u;", + ServerConfig::m_online_id_ban_table.c_str(), online_id); + easySQLQuery(query); +} // increaseOnlineIdBanTriggerCount + +//----------------------------------------------------------------------------- +/** Clears reports that are older than a certain number of days + * (specified in the server config). + */ +void DatabaseConnector::clearOldReports() +{ + if (m_player_reports_table_exists && + ServerConfig::m_player_reports_expired_days != 0.0f) + { + std::string query = StringUtils::insertValues( + "DELETE FROM %s " + "WHERE datetime" + "(reported_time, '+%f days') < datetime('now');", + ServerConfig::m_player_reports_table.c_str(), + ServerConfig::m_player_reports_expired_days); + easySQLQuery(query); + } +} // clearOldReports + +//----------------------------------------------------------------------------- +/** Sets disconnection times for those peers that already left the server, but + * whose disconnection times wasn't set yet. + * \param present_hosts List of online ids of present peers. + */ +void DatabaseConnector::setDisconnectionTimes(std::vector& present_hosts) +{ + if (!hasServerStatsTable()) + return; + std::ostringstream oss; + oss << "UPDATE " << m_server_stats_table + << " SET disconnected_time = datetime('now')" + << " WHERE connected_time = disconnected_time"; + if (present_hosts.empty()) + { + oss << ";"; + } + else + { + oss << " AND host_id NOT IN ("; + for (unsigned i = 0; i < present_hosts.size(); i++) + { + if (i > 0) + oss << ","; + oss << present_hosts[i]; + } + oss << ");"; + } + std::string query = oss.str(); + easySQLQuery(query); +} // setDisconnectionTimes + +//----------------------------------------------------------------------------- +/** Adds a specified IP address to the IPv4 ban table. Usually invoked from + * network console. + * \param addr Address to ban. + */ +void DatabaseConnector::saveAddressToIpBanTable(const SocketAddress& addr) +{ + if (addr.isIPv6() || !m_db || !m_ip_ban_table_exists) + return; + + std::string query = StringUtils::insertValues( + "INSERT INTO %s (ip_start, ip_end) " + "VALUES (%u, %u);", + ServerConfig::m_ip_ban_table.c_str(), addr.getIP(), addr.getIP()); + easySQLQuery(query); +} // saveAddressToIpBanTable + +//----------------------------------------------------------------------------- +/** Called when the player joins the server, inserts player info into database. + * \param peer The peer that joins. + * \param online_id Player's online id. + * \param player_count Number of players joining using a single peer. + * \param country_code Country code deduced by global or local IP mapping. + */ +void DatabaseConnector::onPlayerJoinQueries(std::shared_ptr peer, + uint32_t online_id, unsigned player_count, const std::string& country_code) +{ + if (m_server_stats_table.empty() || peer->isAIPeer()) + return; + std::string query; + std::shared_ptr coll = std::make_shared(); + auto version_os = StringUtils::extractVersionOS(peer->getUserVersion()); + if (ServerConfig::m_ipv6_connection && peer->getAddress().isIPv6()) + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(host_id, ip, ipv6, port, online_id, username, player_num, " + "country_code, version, os, ping) " + "VALUES (%u, 0, \"%s\", %u, %u, %s, %u, %s, %s, %s, %u);", + m_server_stats_table.c_str(), + peer->getHostId(), + peer->getAddress().toString(false), + peer->getAddress().getPort(), + online_id, + Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"), + player_count, + Binder(coll, country_code, "country_code", true), + Binder(coll, version_os.first, "version"), + Binder(coll, version_os.second, "os"), + peer->getAveragePing() + ); + } + else + { + query = StringUtils::insertValues( + "INSERT INTO %s " + "(host_id, ip, port, online_id, username, player_num, " + "country_code, version, os, ping) " + "VALUES (%u, %u, %u, %u, %s, %u, %s, %s, %s, %u);", + m_server_stats_table.c_str(), + peer->getHostId(), + peer->getAddress().getIP(), + peer->getAddress().getPort(), + online_id, + Binder(coll, StringUtils::wideToUtf8(peer->getPlayerProfiles()[0]->getName()), "player_name"), + player_count, + Binder(coll, country_code, "country_code", true), + Binder(coll, version_os.first, "version"), + Binder(coll, version_os.second, "os"), + peer->getAveragePing() + ); + } + easySQLQuery(query, nullptr, coll->getBindFunction()); +} // onPlayerJoinQueries + +//----------------------------------------------------------------------------- +/** Prints all rows of the IPv4 ban table. Called from the network console. */ +void DatabaseConnector::listBanTable() +{ + if (!m_db) + return; + auto printer = [](void* data, int argc, char** argv, char** name) + { + for (int i = 0; i < argc; i++) + { + std::cout << name[i] << " = " << (argv[i] ? argv[i] : "NULL") + << "\n"; + } + std::cout << "\n"; + return 0; + }; + if (m_ip_ban_table_exists) + { + std::string query = "SELECT * FROM "; + query += ServerConfig::m_ip_ban_table; + query += ";"; + std::cout << "IP ban list:\n"; + sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); + } + if (m_online_id_ban_table_exists) + { + std::string query = "SELECT * FROM "; + query += ServerConfig::m_online_id_ban_table; + query += ";"; + std::cout << "Online Id ban list:\n"; + sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); + } +} // listBanTable +#endif // ENABLE_SQLITE3 diff --git a/src/network/database_connector.hpp b/src/network/database_connector.hpp new file mode 100644 index 00000000000..d22e588b8bd --- /dev/null +++ b/src/network/database_connector.hpp @@ -0,0 +1,195 @@ +// +// SuperTuxKart - a fun racing game with go-kart +// Copyright (C) 2024 SuperTuxKart-Team +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of the GNU General Public License +// as published by the Free Software Foundation; either version 3 +// of the License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + +#ifdef ENABLE_SQLITE3 + +#ifndef DATABASE_CONNECTOR_HPP +#define DATABASE_CONNECTOR_HPP + +#include "utils/string_utils.hpp" +#include "utils/time.hpp" + +#include +#include +#include +#include +#include +#include + +class SocketAddress; +class STKPeer; +class NetworkPlayerProfile; + +/** The purpose of Binder and BinderCollection structures is to allow + * putting values to bind inside StringUtils::insertValues, which is commonly + * used for values that don't require binding (such as integers). + * Unlike previously used approach with separate query formation and binding, + * the arguments are written in the code in the order of query appearance + * (even though real binding still happens later). It also avoids repeated + * binding code. + * + * Syntax looks as follows: + * std::shared_ptr coll = std::make_shared...; + * std::string query_string = StringUtils::insertValues( + * "query contents with wildcards of type %d, %s, %u, ..." + * "where %s is put for values that will be bound later", + * values to insert, ..., Binder(coll, other parameters), ...); + * Then the bind function (e.g. for usage in easySQLQuery) should be + * coll->getBindFunction(). + */ + +struct Binder; + +/** BinderCollection is a structure that collects Binder objects used in an + * SQL query formed with insertValues() (see above). For a single query, a + * single instance of BinderCollection should be used. After a query is + * formed, BinderCollection can produce bind function to use with sqlite3. + */ +struct BinderCollection +{ + std::vector> m_binders; + + std::function getBindFunction() const; +}; + +/** Binder is a wrapper for a string to be bound into an SQL query. See above + * for its usage in insertValues(). When it's printed to an output stream + * (in particular, this is done in insertValues implementation), this Binder + * is added to the query's BinderCollection, and the '?'-placeholder is added + * to the query string instead of %s. + * + * When using Binder, make sure that: + * - operator << is invoked on it exactly once; + * - operator << is invoked on several Binders in the order in which they go + * in the query; + * - before calling insertValues, there is a %-wildcard corresponding to the + * Binder in the query string (and not '?'). + * For example, when the query formed inside of a function depends on its + * arguments, it should be formed part by part, from left to right. + * Of course, you can choose the "default" way, binding values separately from + * insertValues() call. + */ +struct Binder +{ + std::weak_ptr m_collection; + std::string m_value; + std::string m_name; + bool m_use_null_if_empty; + + Binder(std::shared_ptr collection, std::string value, + std::string name = "", bool use_null_if_empty = false): + m_collection(collection), m_value(value), + m_name(name), m_use_null_if_empty(use_null_if_empty) {} +}; + +std::ostream& operator << (std::ostream& os, const Binder& binder); + +/** A class that manages the database operations needed for the server to work. + * The SQL queries are intended to be placed only within the implementation + * of this class, while the logic corresponding to those queries should not + * belong here. + */ +class DatabaseConnector +{ +private: + sqlite3* m_db; + std::string m_server_stats_table; + bool m_ip_ban_table_exists; + bool m_ipv6_ban_table_exists; + bool m_online_id_ban_table_exists; + bool m_ip_geolocation_table_exists; + bool m_ipv6_geolocation_table_exists; + bool m_player_reports_table_exists; + uint64_t m_last_poll_db_time; + +public: + /** Corresponds to the row of IPv4 ban table. */ + struct IpBanTableData + { + int row_id; + uint32_t ip_start; + uint32_t ip_end; + std::string reason; + std::string description; + }; + /** Corresponds to the row of IPv6 ban table. */ + struct Ipv6BanTableData + { + int row_id; + std::string ipv6_cidr; + std::string reason; + std::string description; + }; + /** Corresponds to the row of online id ban table. */ + struct OnlineIdBanTableData + { + int row_id; + uint32_t online_id; + std::string reason; + std::string description; + }; + void initDatabase(); + void destroyDatabase(); + + bool easySQLQuery(const std::string& query, + std::vector>* output = nullptr, + std::function bind_function = nullptr, + std::string null_value = "") const; + + void checkTableExists(const std::string& table, bool& result); + + std::string ip2Country(const SocketAddress& addr) const; + + std::string ipv62Country(const SocketAddress& addr) const; + + static void upperIPv6SQL(sqlite3_context* context, int argc, + sqlite3_value** argv); + static void insideIPv6CIDRSQL(sqlite3_context* context, int argc, + sqlite3_value** argv); + void writeDisconnectInfoTable(STKPeer* peer); + void initServerStatsTable(); + bool writeReport( + STKPeer* reporter, std::shared_ptr reporter_npp, + STKPeer* reporting, std::shared_ptr reporting_npp, + irr::core::stringw& info); + bool hasDatabase() const { return m_db != nullptr; } + bool hasServerStatsTable() const { return !m_server_stats_table.empty(); } + bool hasPlayerReportsTable() const + { return m_player_reports_table_exists; } + bool hasIpBanTable() const { return m_ip_ban_table_exists; } + bool hasIpv6BanTable() const { return m_ipv6_ban_table_exists; } + bool hasOnlineIdBanTable() const { return m_online_id_ban_table_exists; } + bool isTimeToPoll() const + { return StkTime::getMonoTimeMs() >= m_last_poll_db_time + 60000; } + void updatePollTime() { m_last_poll_db_time = StkTime::getMonoTimeMs(); } + std::vector getIpBanTableData(uint32_t ip = 0) const; + std::vector getIpv6BanTableData(std::string ipv6 = "") const; + std::vector getOnlineIdBanTableData(uint32_t online_id = 0) const; + void increaseIpBanTriggerCount(uint32_t ip_start, uint32_t ip_end) const; + void increaseIpv6BanTriggerCount(const std::string& ipv6_cidr) const; + void increaseOnlineIdBanTriggerCount(uint32_t online_id) const; + void clearOldReports(); + void setDisconnectionTimes(std::vector& present_hosts); + void saveAddressToIpBanTable(const SocketAddress& addr); + void onPlayerJoinQueries(std::shared_ptr peer, uint32_t online_id, + unsigned player_count, const std::string& country_code); + void listBanTable(); +}; + +#endif // ifndef DATABASE_CONNECTOR_HPP +#endif // ifdef ENABLE_SQLITE3 diff --git a/src/network/protocols/server_lobby.cpp b/src/network/protocols/server_lobby.cpp index 70283270d3d..12092553663 100644 --- a/src/network/protocols/server_lobby.cpp +++ b/src/network/protocols/server_lobby.cpp @@ -30,6 +30,7 @@ #include "modes/capture_the_flag.hpp" #include "modes/linear_world.hpp" #include "network/crypto.hpp" +#include "network/database_connector.hpp" #include "network/event.hpp" #include "network/game_setup.hpp" #include "network/network.hpp" @@ -102,69 +103,6 @@ class SubmitRankingRequest : public Online::XMLRequest // We use max priority for all server requests to avoid downloading of addons // icons blocking the poll request in all-in-one graphical client server -#ifdef ENABLE_SQLITE3 - -// ---------------------------------------------------------------------------- -static void upperIPv6SQL(sqlite3_context* context, int argc, - sqlite3_value** argv) -{ - if (argc != 1) - { - sqlite3_result_int64(context, 0); - return; - } - - char* ipv6 = (char*)sqlite3_value_text(argv[0]); - if (ipv6 == NULL) - { - sqlite3_result_int64(context, 0); - return; - } - sqlite3_result_int64(context, upperIPv6(ipv6)); -} - -// ---------------------------------------------------------------------------- -void insideIPv6CIDRSQL(sqlite3_context* context, int argc, - sqlite3_value** argv) -{ - if (argc != 2) - { - sqlite3_result_int(context, 0); - return; - } - - char* ipv6_cidr = (char*)sqlite3_value_text(argv[0]); - char* ipv6_in = (char*)sqlite3_value_text(argv[1]); - if (ipv6_cidr == NULL || ipv6_in == NULL) - { - sqlite3_result_int(context, 0); - return; - } - sqlite3_result_int(context, insideIPv6CIDR(ipv6_cidr, ipv6_in)); -} // insideIPv6CIDRSQL - -// ---------------------------------------------------------------------------- -/* -Copy below code so it can be use as loadable extension to be used in sqlite3 -command interface (together with andIPv6 and insideIPv6CIDR from stk_ipv6) - -#include "sqlite3ext.h" -SQLITE_EXTENSION_INIT1 -// ---------------------------------------------------------------------------- -sqlite3_extension_init(sqlite3* db, char** pzErrMsg, - const sqlite3_api_routines* pApi) -{ - SQLITE_EXTENSION_INIT2(pApi) - sqlite3_create_function(db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, - insideIPv6CIDRSQL, NULL, NULL); - sqlite3_create_function(db, "upperIPv6", 1, SQLITE_UTF8, 0, upperIPv6SQL, - 0, 0); - return 0; -} // sqlite3_extension_init -*/ - -#endif - /** This is the central game setup protocol running in the server. It is * mostly a finite state machine. Note that all nodes in ellipses and light * grey background are actual states; nodes in boxes and white background @@ -248,8 +186,11 @@ ServerLobby::ServerLobby() : LobbyProtocol() m_difficulty.store(ServerConfig::m_server_difficulty); m_game_mode.store(ServerConfig::m_server_mode); m_default_vote = new PeerVote(); - m_player_reports_table_exists = false; - initDatabase(); + +#ifdef ENABLE_SQLITE3 + m_db_connector = new DatabaseConnector(); + m_db_connector->initDatabase(); +#endif } // ServerLobby //----------------------------------------------------------------------------- @@ -267,302 +208,22 @@ ServerLobby::~ServerLobby() if (m_save_server_config) ServerConfig::writeServerConfigToDisk(); delete m_default_vote; - destroyDatabase(); -} // ~ServerLobby -//----------------------------------------------------------------------------- -void ServerLobby::initDatabase() -{ #ifdef ENABLE_SQLITE3 - m_last_poll_db_time = StkTime::getMonoTimeMs(); - m_db = NULL; - m_ip_ban_table_exists = false; - m_ipv6_ban_table_exists = false; - m_online_id_ban_table_exists = false; - m_ip_geolocation_table_exists = false; - m_ipv6_geolocation_table_exists = false; - if (!ServerConfig::m_sql_management) - return; - const std::string& path = ServerConfig::getConfigDirectory() + "/" + - ServerConfig::m_database_file.c_str(); - int ret = sqlite3_open_v2(path.c_str(), &m_db, - SQLITE_OPEN_SHAREDCACHE | SQLITE_OPEN_FULLMUTEX | - SQLITE_OPEN_READWRITE, NULL); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", "Cannot open database: %s.", - sqlite3_errmsg(m_db)); - sqlite3_close(m_db); - m_db = NULL; - return; - } - sqlite3_busy_handler(m_db, [](void* data, int retry) - { - int retry_count = ServerConfig::m_database_timeout / 100; - if (retry < retry_count) - { - sqlite3_sleep(100); - // Return non-zero to let caller retry again - return 1; - } - // Return zero to let caller return SQLITE_BUSY immediately - return 0; - }, NULL); - sqlite3_create_function(m_db, "insideIPv6CIDR", 2, SQLITE_UTF8, NULL, - &insideIPv6CIDRSQL, NULL, NULL); - sqlite3_create_function(m_db, "upperIPv6", 1, SQLITE_UTF8, NULL, - &upperIPv6SQL, NULL, NULL); - checkTableExists(ServerConfig::m_ip_ban_table, m_ip_ban_table_exists); - checkTableExists(ServerConfig::m_ipv6_ban_table, m_ipv6_ban_table_exists); - checkTableExists(ServerConfig::m_online_id_ban_table, - m_online_id_ban_table_exists); - checkTableExists(ServerConfig::m_player_reports_table, - m_player_reports_table_exists); - checkTableExists(ServerConfig::m_ip_geolocation_table, - m_ip_geolocation_table_exists); - checkTableExists(ServerConfig::m_ipv6_geolocation_table, - m_ipv6_geolocation_table_exists); + m_db_connector->destroyDatabase(); + delete m_db_connector; #endif -} // initDatabase +} // ~ServerLobby //----------------------------------------------------------------------------- + void ServerLobby::initServerStatsTable() { #ifdef ENABLE_SQLITE3 - if (!ServerConfig::m_sql_management || !m_db) - return; - std::string table_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_stats"; - - std::ostringstream oss; - oss << "CREATE TABLE IF NOT EXISTS " << table_name << " (\n" - " host_id INTEGER UNSIGNED NOT NULL PRIMARY KEY, -- Unique host id in STKHost of each connection session for a STKPeer\n" - " ip INTEGER UNSIGNED NOT NULL, -- IP decimal of host\n"; - if (ServerConfig::m_ipv6_connection) - oss << " ipv6 TEXT NOT NULL DEFAULT '', -- IPv6 (if exists) in string of host\n"; - oss << " port INTEGER UNSIGNED NOT NULL, -- Port of host\n" - " online_id INTEGER UNSIGNED NOT NULL, -- Online if of the host (0 for offline account)\n" - " username TEXT NOT NULL, -- First player name in the host (if the host has splitscreen player)\n" - " player_num INTEGER UNSIGNED NOT NULL, -- Number of player(s) from the host, more than 1 if it has splitscreen player\n" - " country_code TEXT NULL DEFAULT NULL, -- 2-letter country code of the host\n" - " version TEXT NOT NULL, -- SuperTuxKart version of the host\n" - " os TEXT NOT NULL, -- Operating system of the host\n" - " connected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when connected\n" - " disconnected_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Time when disconnected (saved when disconnected)\n" - " ping INTEGER UNSIGNED NOT NULL DEFAULT 0, -- Ping of the host\n" - " packet_loss INTEGER NOT NULL DEFAULT 0 -- Mean packet loss count from ENet (saved when disconnected)\n" - ") WITHOUT ROWID;"; - std::string query = oss.str(); - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - ret = sqlite3_finalize(stmt); - if (ret == SQLITE_OK) - m_server_stats_table = table_name; - else - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - if (m_server_stats_table.empty()) - return; - - // Extra default table _countries: - // Server owner need to initialise this table himself, check NETWORKING.md - std::string country_table_name = std::string("v") + StringUtils::toString( - ServerConfig::m_server_db_version) + "_countries"; - query = StringUtils::insertValues( - "CREATE TABLE IF NOT EXISTS %s (\n" - " country_code TEXT NOT NULL PRIMARY KEY UNIQUE, -- Unique 2-letter country code\n" - " country_flag TEXT NOT NULL, -- Unicode country flag representation of 2-letter country code\n" - " country_name TEXT NOT NULL -- Readable name of this country\n" - ") WITHOUT ROWID;", country_table_name.c_str()); - easySQLQuery(query); - - // Default views: - // _full_stats - // Full stats with ip in human readable format and time played of each - // players in minutes - std::string full_stats_view_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_full_stats"; - oss.str(""); - oss << "CREATE VIEW IF NOT EXISTS " << full_stats_view_name << " AS\n" - << " SELECT host_id, ip,\n" - << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; - if (ServerConfig::m_ipv6_connection) - oss << " ipv6,"; - oss << " port, online_id, username, player_num,\n" - << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" - << " ROUND((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" - << " connected_time, disconnected_time, ping, packet_loss FROM " << m_server_stats_table << "\n" - << " LEFT JOIN " << country_table_name << " ON " - << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" - << " ORDER BY connected_time DESC;"; - query = oss.str(); - easySQLQuery(query); - - // _current_players - // Current players in server with ip in human readable format and time - // played of each players in minutes - std::string current_players_view_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_current_players"; - oss.str(""); - oss.clear(); - oss << "CREATE VIEW IF NOT EXISTS " << current_players_view_name << " AS\n" - << " SELECT host_id, ip,\n" - << " ((ip >> 24) & 255) ||'.'|| ((ip >> 16) & 255) ||'.'|| ((ip >> 8) & 255) ||'.'|| ((ip ) & 255) AS ip_readable,\n"; - if (ServerConfig::m_ipv6_connection) - oss << " ipv6,"; - oss << " port, online_id, username, player_num,\n" - << " " << m_server_stats_table << ".country_code AS country_code, country_flag, country_name, version, os,\n" - << " ROUND((STRFTIME(\"%s\", 'now') - STRFTIME(\"%s\", connected_time)) / 60.0, 2) AS time_played,\n" - << " connected_time, ping FROM " << m_server_stats_table << "\n" - << " LEFT JOIN " << country_table_name << " ON " - << country_table_name << ".country_code = " << m_server_stats_table << ".country_code\n" - << " WHERE connected_time = disconnected_time;"; - query = oss.str(); - easySQLQuery(query); - - // _player_stats - // All players with online id and username with their time played stats - // in this server since creation of this database - // If sqlite supports window functions (since 3.25), it will include last session player info (ip, country, ping...) - std::string player_stats_view_name = std::string("v") + - StringUtils::toString(ServerConfig::m_server_db_version) + "_" + - ServerConfig::m_server_uid + "_player_stats"; - oss.str(""); - oss.clear(); - if (sqlite3_libversion_number() < 3025000) - { - oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" - << " SELECT online_id, username, COUNT(online_id) AS num_connections,\n" - << " MIN(connected_time) AS first_connected_time,\n" - << " MAX(connected_time) AS last_connected_time,\n" - << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" - << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" - << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" - << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" - << " FROM " << m_server_stats_table << "\n" - << " WHERE online_id != 0 GROUP BY online_id ORDER BY num_connections DESC;"; - } - else - { - oss << "CREATE VIEW IF NOT EXISTS " << player_stats_view_name << " AS\n" - << " SELECT a.online_id, a.username, a.ip, a.ip_readable,\n"; - if (ServerConfig::m_ipv6_connection) - oss << " a.ipv6,"; - oss << " a.port, a.player_num,\n" - << " a.country_code, a.country_flag, a.country_name, a.version, a.os, a.ping, a.packet_loss,\n" - << " b.num_connections, b.first_connected_time, b.first_disconnected_time,\n" - << " a.connected_time AS last_connected_time, a.disconnected_time AS last_disconnected_time,\n" - << " a.time_played AS last_time_played, b.total_time_played, b.average_time_played,\n" - << " b.min_time_played, b.max_time_played\n" - << " FROM\n" - << " (\n" - << " SELECT *,\n" - << " ROW_NUMBER() OVER\n" - << " (\n" - << " PARTITION BY online_id\n" - << " ORDER BY connected_time DESC\n" - << " ) RowNum\n" - << " FROM " << full_stats_view_name << " where online_id != 0\n" - << " ) as a\n" - << " JOIN\n" - << " (\n" - << " SELECT online_id, COUNT(online_id) AS num_connections,\n" - << " MIN(connected_time) AS first_connected_time,\n" - << " MIN(disconnected_time) AS first_disconnected_time,\n" - << " ROUND(SUM((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS total_time_played,\n" - << " ROUND(AVG((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS average_time_played,\n" - << " ROUND(MIN((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS min_time_played,\n" - << " ROUND(MAX((STRFTIME(\"%s\", disconnected_time) - STRFTIME(\"%s\", connected_time)) / 60.0), 2) AS max_time_played\n" - << " FROM " << m_server_stats_table << " WHERE online_id != 0 GROUP BY online_id\n" - << " ) AS b\n" - << " ON b.online_id = a.online_id\n" - << " WHERE RowNum = 1 ORDER BY num_connections DESC;\n"; - } - query = oss.str(); - easySQLQuery(query); - - uint32_t last_host_id = 0; - query = StringUtils::insertValues("SELECT MAX(host_id) FROM %s;", - m_server_stats_table.c_str()); - ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW && sqlite3_column_type(stmt, 0) != SQLITE_NULL) - { - last_host_id = (unsigned)sqlite3_column_int64(stmt, 0); - Log::info("ServerLobby", "%u was last server session max host id.", - last_host_id); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - m_server_stats_table = ""; - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - m_server_stats_table = ""; - } - STKHost::get()->setNextHostId(last_host_id); - - // Update disconnected time (if stk crashed it will not be written) - query = StringUtils::insertValues( - "UPDATE %s SET disconnected_time = datetime('now') " - "WHERE connected_time = disconnected_time;", - m_server_stats_table.c_str()); - easySQLQuery(query); + m_db_connector->initServerStatsTable(); #endif } // initServerStatsTable -//----------------------------------------------------------------------------- -void ServerLobby::destroyDatabase() -{ -#ifdef ENABLE_SQLITE3 - auto peers = STKHost::get()->getPeers(); - for (auto& peer : peers) - writeDisconnectInfoTable(peer.get()); - if (m_db != NULL) - sqlite3_close(m_db); -#endif -} // destroyDatabase - -//----------------------------------------------------------------------------- -void ServerLobby::writeDisconnectInfoTable(STKPeer* peer) -{ -#ifdef ENABLE_SQLITE3 - if (m_server_stats_table.empty()) - return; - std::string query = StringUtils::insertValues( - "UPDATE %s SET disconnected_time = datetime('now'), " - "ping = %d, packet_loss = %d " - "WHERE host_id = %u;", m_server_stats_table.c_str(), - peer->getAveragePing(), peer->getPacketLoss(), - peer->getHostId()); - easySQLQuery(query); -#endif -} // writeDisconnectInfoTable - //----------------------------------------------------------------------------- void ServerLobby::updateAddons() { @@ -964,346 +625,104 @@ bool ServerLobby::notifyEventAsynchronous(Event* event) */ void ServerLobby::pollDatabase() { - if (!ServerConfig::m_sql_management || !m_db) + if (!ServerConfig::m_sql_management || !m_db_connector->hasDatabase()) return; - if (StkTime::getMonoTimeMs() < m_last_poll_db_time + 60000) + if (!m_db_connector->isTimeToPoll()) return; - m_last_poll_db_time = StkTime::getMonoTimeMs(); + m_db_connector->updatePollTime(); - if (m_ip_ban_table_exists) + std::vector ip_ban_list = + m_db_connector->getIpBanTableData(); + std::vector ipv6_ban_list = + m_db_connector->getIpv6BanTableData(); + std::vector online_id_ban_list = + m_db_connector->getOnlineIdBanTableData(); + + for (std::shared_ptr& p : STKHost::get()->getPeers()) { - std::string query = - "SELECT ip_start, ip_end, reason, description FROM "; - query += ServerConfig::m_ip_ban_table; - query += " WHERE datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now'));"; - auto peers = STKHost::get()->getPeers(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) + if (p->isAIPeer()) + continue; + bool is_kicked = false; + std::string address = ""; + std::string reason = ""; + std::string description = ""; + + if (p->getAddress().isIPv6()) + { + address = p->getAddress().toString(false); + if (address.empty()) + continue; + for (auto& item: ipv6_ban_list) { - std::vector >* peers_ptr = - (std::vector >*)ptr; - for (std::shared_ptr& p : *peers_ptr) + if (insideIPv6CIDR(item.ipv6_cidr.c_str(), address.c_str()) == 1) { - // IPv4 ban list atm - if (p->isAIPeer() || p->getAddress().isIPv6()) - continue; - - uint32_t ip_start = 0; - uint32_t ip_end = 0; - if (!StringUtils::fromString(data[0], ip_start)) - continue; - if (!StringUtils::fromString(data[1], ip_end)) - continue; - uint32_t peer_addr = p->getAddress().getIP(); - if (ip_start <= peer_addr && ip_end >= peer_addr) - { - Log::info("ServerLobby", - "Kick %s, reason: %s, description: %s", - p->getAddress().toString().c_str(), - data[2], data[3]); - p->kick(); - } + is_kicked = true; + reason = item.reason; + description = item.description; + break; } - return 0; - }, &peers, NULL); - } - - if (m_ipv6_ban_table_exists) - { - std::string query = - "SELECT ipv6_cidr, reason, description FROM "; - query += ServerConfig::m_ipv6_ban_table; - query += " WHERE datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now'));"; - auto peers = STKHost::get()->getPeers(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) + } + } + else + { + uint32_t peer_addr = p->getAddress().getIP(); + address = p->getAddress().toString(); + for (auto& item: ip_ban_list) { - std::vector >* peers_ptr = - (std::vector >*)ptr; - for (std::shared_ptr& p : *peers_ptr) + if (item.ip_start <= peer_addr && item.ip_end >= peer_addr) { - std::string ipv6; - if (p->getAddress().isIPv6()) - ipv6 = p->getAddress().toString(false); - // IPv6 ban list atm - if (p->isAIPeer() || ipv6.empty()) - continue; - - char* ipv6_cidr = data[0]; - if (insideIPv6CIDR(ipv6_cidr, ipv6.c_str()) == 1) - { - Log::info("ServerLobby", - "Kick %s, reason: %s, description: %s", - ipv6.c_str(), data[1], data[2]); - p->kick(); - } + is_kicked = true; + reason = item.reason; + description = item.description; + break; } - return 0; - }, &peers, NULL); - } - - if (m_online_id_ban_table_exists) - { - std::string query = "SELECT online_id, reason, description FROM "; - query += ServerConfig::m_online_id_ban_table; - query += " WHERE datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now'));"; - auto peers = STKHost::get()->getPeers(); - sqlite3_exec(m_db, query.c_str(), - [](void* ptr, int count, char** data, char** columns) + } + } + if (!is_kicked && !p->getPlayerProfiles().empty()) + { + uint32_t online_id = p->getPlayerProfiles()[0]->getOnlineId(); + for (auto& item: online_id_ban_list) { - std::vector >* peers_ptr = - (std::vector >*)ptr; - for (std::shared_ptr& p : *peers_ptr) + if (item.online_id == online_id) { - if (p->isAIPeer() - || p->getPlayerProfiles().empty()) - continue; - - uint32_t online_id = 0; - if (!StringUtils::fromString(data[0], online_id)) - continue; - if (online_id == p->getPlayerProfiles()[0]->getOnlineId()) - { - Log::info("ServerLobby", - "Kick %s, reason: %s, description: %s", - p->getAddress().toString().c_str(), - data[1], data[2]); - p->kick(); - } + is_kicked = true; + reason = item.reason; + description = item.description; + break; } - return 0; - }, &peers, NULL); - } + } + } + if (is_kicked) + { + Log::info("ServerLobby", "Kick %s, reason: %s, description: %s", + address.c_str(), reason.c_str(), description.c_str()); + p->kick(); + } + } // for p in peers - if (m_player_reports_table_exists && - ServerConfig::m_player_reports_expired_days != 0.0f) - { - std::string query = StringUtils::insertValues( - "DELETE FROM %s " - "WHERE datetime" - "(reported_time, '+%f days') < datetime('now');", - ServerConfig::m_player_reports_table.c_str(), - ServerConfig::m_player_reports_expired_days); - easySQLQuery(query); - } - if (m_server_stats_table.empty()) - return; + m_db_connector->clearOldReports(); - std::string query; auto peers = STKHost::get()->getPeers(); - std::vector exist_hosts; + std::vector hosts; if (!peers.empty()) { for (auto& peer : peers) { if (!peer->isValidated()) continue; - exist_hosts.push_back(peer->getHostId()); + hosts.push_back(peer->getHostId()); } } - if (peers.empty() || exist_hosts.empty()) - { - query = StringUtils::insertValues( - "UPDATE %s SET disconnected_time = datetime('now') " - "WHERE connected_time = disconnected_time;", - m_server_stats_table.c_str()); - } - else - { - std::ostringstream oss; - oss << "UPDATE " << m_server_stats_table - << " SET disconnected_time = datetime('now')" - << " WHERE connected_time = disconnected_time AND" - << " host_id NOT IN ("; - for (unsigned i = 0; i < exist_hosts.size(); i++) - { - oss << exist_hosts[i]; - if (i != (exist_hosts.size() - 1)) - oss << ","; - } - oss << ");"; - query = oss.str(); - } - easySQLQuery(query); + m_db_connector->setDisconnectionTimes(hosts); } // pollDatabase - -//----------------------------------------------------------------------------- -/** Run simple query with write lock waiting and optional function, this - * function has no callback for the return (if any) by the query. - * Return true if no error occurs - */ -bool ServerLobby::easySQLQuery(const std::string& query, - std::function bind_function) const -{ - if (!m_db) - return false; - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - if (bind_function) - bind_function(stmt); - ret = sqlite3_step(stmt); - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for easy query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return false; - } - } - else - { - Log::error("ServerLobby", - "Error preparing database for easy query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return false; - } - return true; -} // easySQLQuery - -//----------------------------------------------------------------------------- -/* Write true to result if table name exists in database. */ -void ServerLobby::checkTableExists(const std::string& table, bool& result) -{ - if (!m_db) - return; - sqlite3_stmt* stmt = NULL; - if (!table.empty()) - { - std::string query = StringUtils::insertValues( - "SELECT count(type) FROM sqlite_master " - "WHERE type='table' AND name='%s';", table.c_str()); - - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - int number = sqlite3_column_int(stmt, 0); - if (number == 1) - { - Log::info("ServerLobby", "Table named %s will used.", - table.c_str()); - result = true; - } - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - } - if (!result && !table.empty()) - { - Log::warn("ServerLobby", "Table named %s not found in database.", - table.c_str()); - } -} // checkTableExists - -//----------------------------------------------------------------------------- -std::string ServerLobby::ip2Country(const SocketAddress& addr) const -{ - if (!m_db || !m_ip_geolocation_table_exists || addr.isLAN()) - return ""; - - std::string cc_code; - std::string query = StringUtils::insertValues( - "SELECT country_code FROM %s " - "WHERE `ip_start` <= %d AND `ip_end` >= %d " - "ORDER BY `ip_start` DESC LIMIT 1;", - ServerConfig::m_ip_geolocation_table.c_str(), addr.getIP(), - addr.getIP()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; - } - return cc_code; -} // ip2Country - -//----------------------------------------------------------------------------- -std::string ServerLobby::ipv62Country(const SocketAddress& addr) const -{ - if (!m_db || !m_ipv6_geolocation_table_exists) - return ""; - - std::string cc_code; - const std::string& ipv6 = addr.toString(false/*show_port*/); - std::string query = StringUtils::insertValues( - "SELECT country_code FROM %s " - "WHERE `ip_start` <= upperIPv6(\"%s\") AND `ip_end` >= upperIPv6(\"%s\") " - "ORDER BY `ip_start` DESC LIMIT 1;", - ServerConfig::m_ipv6_geolocation_table.c_str(), ipv6.c_str(), - ipv6.c_str()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - const char* country_code = (char*)sqlite3_column_text(stmt, 0); - cc_code = country_code; - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return ""; - } - return cc_code; -} // ipv62Country - #endif - //----------------------------------------------------------------------------- void ServerLobby::writePlayerReport(Event* event) { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_player_reports_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasPlayerReportsTable()) return; STKPeer* reporter = event->getPeer(); if (!reporter->hasPlayerProfiles()) @@ -1321,65 +740,8 @@ void ServerLobby::writePlayerReport(Event* event) return; auto reporting_npp = reporting_peer->getPlayerProfiles()[0]; - std::string query; - if (ServerConfig::m_ipv6_connection) - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(server_uid, reporter_ip, reporter_ipv6, reporter_online_id, reporter_username, " - "info, reporting_ip, reporting_ipv6, reporting_online_id, reporting_username) " - "VALUES (?, %u, \"%s\", %u, ?, ?, %u, \"%s\", %u, ?);", - ServerConfig::m_player_reports_table.c_str(), - !reporter->getAddress().isIPv6() ? reporter->getAddress().getIP() : 0, - reporter->getAddress().isIPv6() ? reporter->getAddress().toString(false) : "", - reporter_npp->getOnlineId(), - !reporting_peer->getAddress().isIPv6() ? reporting_peer->getAddress().getIP() : 0, - reporting_peer->getAddress().isIPv6() ? reporting_peer->getAddress().toString(false) : "", - reporting_npp->getOnlineId()); - } - else - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(server_uid, reporter_ip, reporter_online_id, reporter_username, " - "info, reporting_ip, reporting_online_id, reporting_username) " - "VALUES (?, %u, %u, ?, ?, %u, %u, ?);", - ServerConfig::m_player_reports_table.c_str(), - reporter->getAddress().getIP(), reporter_npp->getOnlineId(), - reporting_peer->getAddress().getIP(), reporting_npp->getOnlineId()); - } - bool written = easySQLQuery(query, - [reporter_npp, reporting_npp, info](sqlite3_stmt* stmt) - { - // SQLITE_TRANSIENT to copy string - if (sqlite3_bind_text(stmt, 1, ServerConfig::m_server_uid.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ServerConfig::m_server_uid.c_str()); - } - if (sqlite3_bind_text(stmt, 2, - StringUtils::wideToUtf8(reporter_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporter_npp->getName()).c_str()); - } - if (sqlite3_bind_text(stmt, 3, - StringUtils::wideToUtf8(info).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(info).c_str()); - } - if (sqlite3_bind_text(stmt, 4, - StringUtils::wideToUtf8(reporting_npp->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8(reporting_npp->getName()).c_str()); - } - }); + bool written = m_db_connector->writeReport(reporter, reporter_npp, + reporting_peer.get(), reporting_npp, info); if (written) { NetworkString* success = getNetworkString(); @@ -3359,7 +2721,9 @@ void ServerLobby::clientDisconnected(Event* event) updatePlayerList(); delete msg; - writeDisconnectInfoTable(event->getPeer()); +#ifdef ENABLE_SQLITE3 + m_db_connector->writeDisconnectInfoTable(event->getPeer()); +#endif } // clientDisconnected //----------------------------------------------------------------------------- @@ -3402,14 +2766,7 @@ void ServerLobby::kickPlayerWithReason(STKPeer* peer, const char* reason) const void ServerLobby::saveIPBanTable(const SocketAddress& addr) { #ifdef ENABLE_SQLITE3 - if (addr.isIPv6() || !m_db || !m_ip_ban_table_exists) - return; - - std::string query = StringUtils::insertValues( - "INSERT INTO %s (ip_start, ip_end) " - "VALUES (%u, %u);", - ServerConfig::m_ip_ban_table.c_str(), addr.getIP(), addr.getIP()); - easySQLQuery(query); + m_db_connector->saveAddressToIpBanTable(addr); #endif } // saveIPBanTable @@ -3754,9 +3111,9 @@ void ServerLobby::handleUnencryptedConnection(std::shared_ptr peer, #ifdef ENABLE_SQLITE3 if (country_code.empty() && !peer->getAddress().isIPv6()) - country_code = ip2Country(peer->getAddress()); + country_code = m_db_connector->ip2Country(peer->getAddress()); if (country_code.empty() && peer->getAddress().isIPv6()) - country_code = ipv62Country(peer->getAddress()); + country_code = m_db_connector->ipv62Country(peer->getAddress()); #endif auto red_blue = STKHost::get()->getAllPlayersTeamInfo(); @@ -3828,7 +3185,7 @@ void ServerLobby::handleUnencryptedConnection(std::shared_ptr peer, message_ack->addFloat(auto_start_timer) .addUInt32(ServerConfig::m_state_frequency) .addUInt8(ServerConfig::m_chat ? 1 : 0) - .addUInt8(m_player_reports_table_exists ? 1 : 0); + .addUInt8(playerReportsTableExists() ? 1 : 0); peer->setSpectator(false); @@ -3889,74 +3246,7 @@ void ServerLobby::handleUnencryptedConnection(std::shared_ptr peer, } #ifdef ENABLE_SQLITE3 - if (m_server_stats_table.empty() || peer->isAIPeer()) - return; - std::string query; - if (ServerConfig::m_ipv6_connection && peer->getAddress().isIPv6()) - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(host_id, ip, ipv6 ,port, online_id, username, player_num, " - "country_code, version, os, ping) " - "VALUES (%u, 0, \"%s\" ,%u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().toString(false), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); - } - else - { - query = StringUtils::insertValues( - "INSERT INTO %s " - "(host_id, ip, port, online_id, username, player_num, " - "country_code, version, os, ping) " - "VALUES (%u, %u, %u, %u, ?, %u, ?, ?, ?, %u);", - m_server_stats_table.c_str(), peer->getHostId(), - peer->getAddress().getIP(), peer->getAddress().getPort(), - online_id, player_count, peer->getAveragePing()); - } - easySQLQuery(query, [peer, country_code](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - StringUtils::wideToUtf8( - peer->getPlayerProfiles()[0]->getName()).c_str()); - } - if (country_code.empty()) - { - if (sqlite3_bind_null(stmt, 2) != SQLITE_OK) - { - Log::error("easySQLQuery", - "Failed to bind NULL for country code."); - } - } - else - { - if (sqlite3_bind_text(stmt, 2, country_code.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind country: %s.", - country_code.c_str()); - } - } - auto version_os = - StringUtils::extractVersionOS(peer->getUserVersion()); - if (sqlite3_bind_text(stmt, 3, version_os.first.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.first.c_str()); - } - if (sqlite3_bind_text(stmt, 4, version_os.second.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - version_os.second.c_str()); - } - } - ); + m_db_connector->onPlayerJoinQueries(peer, online_id, player_count, country_code); #endif } // handleUnencryptedConnection @@ -4796,66 +4086,34 @@ void ServerLobby::resetServer() void ServerLobby::testBannedForIP(STKPeer* peer) const { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_ip_ban_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasIpBanTable()) return; // Test for IPv4 if (peer->getAddress().isIPv6()) return; - int row_id = -1; - unsigned ip_start = 0; - unsigned ip_end = 0; - std::string query = StringUtils::insertValues( - "SELECT rowid, ip_start, ip_end, reason, description FROM %s " - "WHERE ip_start <= %u AND ip_end >= %u " - "AND datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now')) " - "LIMIT 1;", - ServerConfig::m_ip_ban_table.c_str(), - peer->getAddress().getIP(), peer->getAddress().getIP()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - row_id = sqlite3_column_int(stmt, 0); - ip_start = (unsigned)sqlite3_column_int64(stmt, 1); - ip_end = (unsigned)sqlite3_column_int64(stmt, 2); - const char* reason = (char*)sqlite3_column_text(stmt, 3); - const char* desc = (char*)sqlite3_column_text(stmt, 4); - Log::info("ServerLobby", "%s banned by IP: %s " + bool is_banned = false; + uint32_t ip_start = 0; + uint32_t ip_end = 0; + + std::vector ip_ban_list = + m_db_connector->getIpBanTableData(peer->getAddress().getIP()); + if (!ip_ban_list.empty()) + { + is_banned = true; + ip_start = ip_ban_list[0].ip_start; + ip_end = ip_ban_list[0].ip_end; + int row_id = ip_ban_list[0].row_id; + std::string reason = ip_ban_list[0].reason; + std::string description = ip_ban_list[0].description; + Log::info("ServerLobby", "%s banned by IP: %s " "(rowid: %d, description: %s).", - peer->getAddress().toString().c_str(), reason, row_id, desc); - kickPlayerWithReason(peer, reason); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return; - } - if (row_id != -1) - { - query = StringUtils::insertValues( - "UPDATE %s SET trigger_count = trigger_count + 1, " - "last_trigger = datetime('now') " - "WHERE ip_start = %u AND ip_end = %u;", - ServerConfig::m_ip_ban_table.c_str(), ip_start, ip_end); - easySQLQuery(query); + peer->getAddress().toString().c_str(), reason.c_str(), row_id, description.c_str()); + kickPlayerWithReason(peer, reason.c_str()); } + if (is_banned) + m_db_connector->increaseIpBanTriggerCount(ip_start, ip_end); #endif } // testBannedForIP @@ -4863,78 +4121,33 @@ void ServerLobby::testBannedForIP(STKPeer* peer) const void ServerLobby::testBannedForIPv6(STKPeer* peer) const { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_ipv6_ban_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasIpv6BanTable()) return; // Test for IPv6 if (!peer->getAddress().isIPv6()) return; - int row_id = -1; - std::string ipv6_cidr; - std::string query = StringUtils::insertValues( - "SELECT rowid, ipv6_cidr, reason, description FROM %s " - "WHERE insideIPv6CIDR(ipv6_cidr, ?) = 1 " - "AND datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now')) " - "LIMIT 1;", - ServerConfig::m_ipv6_ban_table.c_str()); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - if (sqlite3_bind_text(stmt, 1, - peer->getAddress().toString(false).c_str(), -1, SQLITE_TRANSIENT) - != SQLITE_OK) - { - Log::error("ServerLobby", "Error binding ipv6 addr for query: %s", - sqlite3_errmsg(m_db)); - } + bool is_banned = false; + std::string ipv6_cidr = ""; - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - row_id = sqlite3_column_int(stmt, 0); - ipv6_cidr = (char*)sqlite3_column_text(stmt, 1); - const char* reason = (char*)sqlite3_column_text(stmt, 2); - const char* desc = (char*)sqlite3_column_text(stmt, 3); - Log::info("ServerLobby", "%s banned by IP: %s " - "(rowid: %d, description: %s).", - peer->getAddress().toString().c_str(), reason, row_id, desc); - kickPlayerWithReason(peer, reason); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", - "Error finalize database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database for query %s: %s", - query.c_str(), sqlite3_errmsg(m_db)); - return; - } - if (row_id != -1) + std::vector ipv6_ban_list = + m_db_connector->getIpv6BanTableData(peer->getAddress().toString(false)); + + if (!ipv6_ban_list.empty()) { - query = StringUtils::insertValues( - "UPDATE %s SET trigger_count = trigger_count + 1, " - "last_trigger = datetime('now') " - "WHERE ipv6_cidr = ?;", ServerConfig::m_ipv6_ban_table.c_str()); - easySQLQuery(query, [ipv6_cidr](sqlite3_stmt* stmt) - { - if (sqlite3_bind_text(stmt, 1, ipv6_cidr.c_str(), - -1, SQLITE_TRANSIENT) != SQLITE_OK) - { - Log::error("easySQLQuery", "Failed to bind %s.", - ipv6_cidr.c_str()); - } - }); + is_banned = true; + ipv6_cidr = ipv6_ban_list[0].ipv6_cidr; + int row_id = ipv6_ban_list[0].row_id; + std::string reason = ipv6_ban_list[0].reason; + std::string description = ipv6_ban_list[0].description; + Log::info("ServerLobby", "%s banned by IPv6: %s " + "(rowid: %d, description: %s).", + peer->getAddress().toString(false).c_str(), reason.c_str(), row_id, description.c_str()); + kickPlayerWithReason(peer, reason.c_str()); } + if (is_banned) + m_db_connector->increaseIpv6BanTriggerCount(ipv6_cidr); #endif } // testBannedForIPv6 @@ -4943,57 +4156,26 @@ void ServerLobby::testBannedForOnlineId(STKPeer* peer, uint32_t online_id) const { #ifdef ENABLE_SQLITE3 - if (!m_db || !m_online_id_ban_table_exists) + if (!m_db_connector->hasDatabase() || !m_db_connector->hasOnlineIdBanTable()) return; - int row_id = -1; - std::string query = StringUtils::insertValues( - "SELECT rowid, reason, description FROM %s " - "WHERE online_id = %u " - "AND datetime('now') > datetime(starting_time) AND " - "(expired_days is NULL OR datetime" - "(starting_time, '+'||expired_days||' days') > datetime('now')) " - "LIMIT 1;", - ServerConfig::m_online_id_ban_table.c_str(), online_id); - - sqlite3_stmt* stmt = NULL; - int ret = sqlite3_prepare_v2(m_db, query.c_str(), -1, &stmt, 0); - if (ret == SQLITE_OK) - { - ret = sqlite3_step(stmt); - if (ret == SQLITE_ROW) - { - row_id = sqlite3_column_int(stmt, 0); - const char* reason = (char*)sqlite3_column_text(stmt, 1); - const char* desc = (char*)sqlite3_column_text(stmt, 2); - Log::info("ServerLobby", "%s banned by online id: %s " - "(online id: %u rowid: %d, description: %s).", - peer->getAddress().toString().c_str(), reason, online_id, - row_id, desc); - kickPlayerWithReason(peer, reason); - } - ret = sqlite3_finalize(stmt); - if (ret != SQLITE_OK) - { - Log::error("ServerLobby", "Error finalize database: %s", - sqlite3_errmsg(m_db)); - } - } - else - { - Log::error("ServerLobby", "Error preparing database: %s", - sqlite3_errmsg(m_db)); - return; - } - if (row_id != -1) + bool is_banned = false; + std::vector online_id_ban_list = + m_db_connector->getOnlineIdBanTableData(online_id); + + if (!online_id_ban_list.empty()) { - query = StringUtils::insertValues( - "UPDATE %s SET trigger_count = trigger_count + 1, " - "last_trigger = datetime('now') " - "WHERE online_id = %u;", - ServerConfig::m_online_id_ban_table.c_str(), online_id); - easySQLQuery(query); + is_banned = true; + int row_id = online_id_ban_list[0].row_id; + std::string reason = online_id_ban_list[0].reason; + std::string description = online_id_ban_list[0].description; + Log::info("ServerLobby", "%s banned by online id: %s " + "(online id: %u, rowid: %d, description: %s).", + peer->getAddress().toString().c_str(), reason.c_str(), online_id, row_id, description.c_str()); + kickPlayerWithReason(peer, reason.c_str()); } + if (is_banned) + m_db_connector->increaseOnlineIdBanTriggerCount(online_id); #endif } // testBannedForOnlineId @@ -5001,34 +4183,7 @@ void ServerLobby::testBannedForOnlineId(STKPeer* peer, void ServerLobby::listBanTable() { #ifdef ENABLE_SQLITE3 - if (!m_db) - return; - auto printer = [](void* data, int argc, char** argv, char** name) - { - for (int i = 0; i < argc; i++) - { - std::cout << name[i] << " = " << (argv[i] ? argv[i] : "NULL") - << "\n"; - } - std::cout << "\n"; - return 0; - }; - if (m_ip_ban_table_exists) - { - std::string query = "SELECT * FROM "; - query += ServerConfig::m_ip_ban_table; - query += ";"; - std::cout << "IP ban list:\n"; - sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); - } - if (m_online_id_ban_table_exists) - { - std::string query = "SELECT * FROM "; - query += ServerConfig::m_online_id_ban_table; - query += ";"; - std::cout << "Online Id ban list:\n"; - sqlite3_exec(m_db, query.c_str(), printer, NULL, NULL); - } + m_db_connector->listBanTable(); #endif } // listBanTable @@ -6052,3 +5207,13 @@ void ServerLobby::handleServerCommand(Event* event, delete chat; } } // handleServerCommand + +//----------------------------------------------------------------------------- +bool ServerLobby::playerReportsTableExists() const +{ +#ifdef ENABLE_SQLITE3 + return m_db_connector->hasPlayerReportsTable(); +#else + return false; +#endif +} diff --git a/src/network/protocols/server_lobby.hpp b/src/network/protocols/server_lobby.hpp index a54f9f901bb..fc18cfeca43 100644 --- a/src/network/protocols/server_lobby.hpp +++ b/src/network/protocols/server_lobby.hpp @@ -34,11 +34,8 @@ #include #include -#ifdef ENABLE_SQLITE3 -#include -#endif - class BareNetworkString; +class DatabaseConnector; class NetworkItemManager; class NetworkString; class NetworkPlayerProfile; @@ -78,39 +75,12 @@ class ServerLobby : public LobbyProtocol std::string m_country_code; bool m_tried = false; }; - bool m_player_reports_table_exists; #ifdef ENABLE_SQLITE3 - sqlite3* m_db; - - std::string m_server_stats_table; - - bool m_ip_ban_table_exists; - - bool m_ipv6_ban_table_exists; - - bool m_online_id_ban_table_exists; - - bool m_ip_geolocation_table_exists; - - bool m_ipv6_geolocation_table_exists; - - uint64_t m_last_poll_db_time; + DatabaseConnector* m_db_connector; void pollDatabase(); - - bool easySQLQuery(const std::string& query, - std::function bind_function = nullptr) const; - - void checkTableExists(const std::string& table, bool& result); - - std::string ip2Country(const SocketAddress& addr) const; - - std::string ipv62Country(const SocketAddress& addr) const; #endif - void initDatabase(); - - void destroyDatabase(); std::atomic m_state; @@ -378,7 +348,6 @@ class ServerLobby : public LobbyProtocol void testBannedForIP(STKPeer* peer) const; void testBannedForIPv6(STKPeer* peer) const; void testBannedForOnlineId(STKPeer* peer, uint32_t online_id) const; - void writeDisconnectInfoTable(STKPeer* peer); void writePlayerReport(Event* event); bool supportsAI(); void updateAddons(); @@ -419,6 +388,7 @@ class ServerLobby : public LobbyProtocol uint32_t getServerIdOnline() const { return m_server_id_online; } void setClientServerHostId(uint32_t id) { m_client_server_host_id = id; } static int m_fixed_laps; + bool playerReportsTableExists() const; }; // class ServerLobby #endif // SERVER_LOBBY_HPP