diff options
Diffstat (limited to 'src/server/shared')
51 files changed, 2578 insertions, 387 deletions
diff --git a/src/server/shared/CMakeLists.txt b/src/server/shared/CMakeLists.txt index fb86ac7db18..abc019863b1 100644 --- a/src/server/shared/CMakeLists.txt +++ b/src/server/shared/CMakeLists.txt @@ -21,6 +21,7 @@ file(GLOB_RECURSE sources_Logging Logging/*.cpp Logging/*.h) file(GLOB_RECURSE sources_Networking Networking/*.cpp Networking/*.h) file(GLOB_RECURSE sources_Packets Packets/*.cpp Packets/*.h) file(GLOB_RECURSE sources_Threading Threading/*.cpp Threading/*.h) +file(GLOB_RECURSE sources_Updater Updater/*.cpp Updater/*.h) file(GLOB_RECURSE sources_Utilities Utilities/*.cpp Utilities/*.h) file(GLOB sources_localdir *.cpp *.h) @@ -51,6 +52,7 @@ set(shared_STAT_SRCS ${sources_Networking} ${sources_Packets} ${sources_Threading} + ${sources_Updater} ${sources_Utilities} ${sources_localdir} ) @@ -59,8 +61,9 @@ include_directories( ${CMAKE_BINARY_DIR} ${CMAKE_SOURCE_DIR}/dep/recastnavigation/Detour ${CMAKE_SOURCE_DIR}/dep/SFMT - ${CMAKE_SOURCE_DIR}/dep/sockets/include + ${CMAKE_SOURCE_DIR}/dep/cppformat ${CMAKE_SOURCE_DIR}/dep/utf8cpp + ${CMAKE_SOURCE_DIR}/dep/process ${CMAKE_SOURCE_DIR}/src/server ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/Configuration @@ -74,12 +77,15 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/Packets ${CMAKE_CURRENT_SOURCE_DIR}/Threading ${CMAKE_CURRENT_SOURCE_DIR}/Utilities + ${CMAKE_CURRENT_SOURCE_DIR}/Updater ${CMAKE_SOURCE_DIR}/src/server/game/Entities/Object ${MYSQL_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR} ${VALGRIND_INCLUDE_DIR} ) +GroupSources(${CMAKE_CURRENT_SOURCE_DIR}) + add_library(shared STATIC ${shared_STAT_SRCS} ${shared_STAT_PCH_SRC} diff --git a/src/server/shared/Configuration/Config.cpp b/src/server/shared/Configuration/Config.cpp index 1e1f8c7c3c6..ea426a5d33e 100644 --- a/src/server/shared/Configuration/Config.cpp +++ b/src/server/shared/Configuration/Config.cpp @@ -21,7 +21,6 @@ #include <boost/property_tree/ptree.hpp> #include <boost/property_tree/ini_parser.hpp> #include "Config.h" -#include "Errors.h" using namespace boost::property_tree; diff --git a/src/server/shared/Cryptography/ARC4.cpp b/src/server/shared/Cryptography/ARC4.cpp index 28c3da8d228..d1082b39347 100644 --- a/src/server/shared/Cryptography/ARC4.cpp +++ b/src/server/shared/Cryptography/ARC4.cpp @@ -17,7 +17,6 @@ */ #include "ARC4.h" -#include <openssl/sha.h> ARC4::ARC4(uint8 len) : m_ctx() { diff --git a/src/server/shared/Cryptography/BigNumber.cpp b/src/server/shared/Cryptography/BigNumber.cpp index 434268097fe..720e8e30441 100644 --- a/src/server/shared/Cryptography/BigNumber.cpp +++ b/src/server/shared/Cryptography/BigNumber.cpp @@ -18,7 +18,6 @@ #include "Cryptography/BigNumber.h" #include <openssl/bn.h> -#include <openssl/crypto.h> #include <cstring> #include <algorithm> #include <memory> @@ -171,19 +170,20 @@ bool BigNumber::isZero() const std::unique_ptr<uint8[]> BigNumber::AsByteArray(int32 minSize, bool littleEndian) { - int length = (minSize >= GetNumBytes()) ? minSize : GetNumBytes(); + int numBytes = GetNumBytes(); + int length = (minSize >= numBytes) ? minSize : numBytes; uint8* array = new uint8[length]; // If we need more bytes than length of BigNumber set the rest to 0 - if (length > GetNumBytes()) + if (length > numBytes) memset((void*)array, 0, length); BN_bn2bin(_bn, (unsigned char *)array); // openssl's BN stores data internally in big endian format, reverse if little endian desired if (littleEndian) - std::reverse(array, array + length); + std::reverse(array, array + numBytes); std::unique_ptr<uint8[]> ret(array); return ret; diff --git a/src/server/shared/Cryptography/SHA1.h b/src/server/shared/Cryptography/SHA1.h index ebd9f721d4a..f59bdc25556 100644 --- a/src/server/shared/Cryptography/SHA1.h +++ b/src/server/shared/Cryptography/SHA1.h @@ -39,8 +39,8 @@ class SHA1Hash void Initialize(); void Finalize(); - uint8 *GetDigest(void) { return mDigest; }; - int GetLength(void) const { return SHA_DIGEST_LENGTH; }; + uint8 *GetDigest(void) { return mDigest; } + int GetLength(void) const { return SHA_DIGEST_LENGTH; } private: SHA_CTX mC; diff --git a/src/server/shared/Database/AdhocStatement.h b/src/server/shared/Database/AdhocStatement.h index 8d9365900d9..8195d9add98 100644 --- a/src/server/shared/Database/AdhocStatement.h +++ b/src/server/shared/Database/AdhocStatement.h @@ -40,4 +40,4 @@ class BasicStatementTask : public SQLOperation QueryResultPromise* m_result; }; -#endif
\ No newline at end of file +#endif diff --git a/src/server/shared/Database/DatabaseLoader.cpp b/src/server/shared/Database/DatabaseLoader.cpp new file mode 100644 index 00000000000..36ee4b12c83 --- /dev/null +++ b/src/server/shared/Database/DatabaseLoader.cpp @@ -0,0 +1,191 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "DatabaseLoader.h" +#include "DBUpdater.h" +#include "Config.h" + +#include <mysqld_error.h> + +DatabaseLoader::DatabaseLoader(std::string const& logger, uint32 const defaultUpdateMask) + : _logger(logger), _autoSetup(sConfigMgr->GetBoolDefault("Updates.AutoSetup", true)), + _updateFlags(sConfigMgr->GetIntDefault("Updates.EnableDatabases", defaultUpdateMask)) +{ +} + +template <class T> +DatabaseLoader& DatabaseLoader::AddDatabase(DatabaseWorkerPool<T>& pool, std::string const& name) +{ + bool const updatesEnabledForThis = DBUpdater<T>::IsEnabled(_updateFlags); + + _open.push(std::make_pair([this, name, updatesEnabledForThis, &pool]() -> bool + { + std::string const dbString = sConfigMgr->GetStringDefault(name + "DatabaseInfo", ""); + if (dbString.empty()) + { + TC_LOG_ERROR(_logger.c_str(), "Database %s not specified in configuration file!", name.c_str()); + return false; + } + + uint8 const asyncThreads = uint8(sConfigMgr->GetIntDefault(name + "Database.WorkerThreads", 1)); + if (asyncThreads < 1 || asyncThreads > 32) + { + TC_LOG_ERROR(_logger.c_str(), "%s database: invalid number of worker threads specified. " + "Please pick a value between 1 and 32.", name.c_str()); + return false; + } + + uint8 const synchThreads = uint8(sConfigMgr->GetIntDefault(name + "Database.SynchThreads", 1)); + + pool.SetConnectionInfo(dbString, asyncThreads, synchThreads); + if (uint32 error = pool.Open()) + { + // Database does not exist + if ((error == ER_BAD_DB_ERROR) && updatesEnabledForThis && _autoSetup) + { + // Try to create the database and connect again if auto setup is enabled + if (DBUpdater<T>::Create(pool) && (!pool.Open())) + error = 0; + } + + // If the error wasn't handled quit + if (error) + { + TC_LOG_ERROR("sql.driver", "\nDatabasePool %s NOT opened. There were errors opening the MySQL connections. Check your SQLDriverLogFile " + "for specific errors. Read wiki at http://collab.kpsn.org/display/tc/TrinityCore+Home", name.c_str()); + + return false; + } + } + return true; + }, + [&pool]() + { + pool.Close(); + })); + + // Populate and update only if updates are enabled for this pool + if (updatesEnabledForThis) + { + _populate.push([this, name, &pool]() -> bool + { + if (!DBUpdater<T>::Populate(pool)) + { + TC_LOG_ERROR(_logger.c_str(), "Could not populate the %s database, see log for details.", name.c_str()); + return false; + } + return true; + }); + + _update.push([this, name, &pool]() -> bool + { + if (!DBUpdater<T>::Update(pool)) + { + TC_LOG_ERROR(_logger.c_str(), "Could not update the %s database, see log for details.", name.c_str()); + return false; + } + return true; + }); + } + + _prepare.push([this, name, &pool]() -> bool + { + if (!pool.PrepareStatements()) + { + TC_LOG_ERROR(_logger.c_str(), "Could not prepare statements of the %s database, see log for details.", name.c_str()); + return false; + } + return true; + }); + + return *this; +} + +bool DatabaseLoader::Load() +{ + if (!OpenDatabases()) + return false; + + if (!PopulateDatabases()) + return false; + + if (!UpdateDatabases()) + return false; + + if (!PrepareStatements()) + return false; + + return true; +} + +bool DatabaseLoader::OpenDatabases() +{ + while (!_open.empty()) + { + std::pair<Predicate, std::function<void()>> const load = _open.top(); + if (load.first()) + _close.push(load.second); + else + { + // Close all loaded databases + while (!_close.empty()) + { + _close.top()(); + _close.pop(); + } + return false; + } + + _open.pop(); + } + return true; +} + +// Processes the elements of the given stack until a predicate returned false. +bool DatabaseLoader::Process(std::stack<Predicate>& stack) +{ + while (!stack.empty()) + { + if (!stack.top()()) + return false; + + stack.pop(); + } + return true; +} + +bool DatabaseLoader::PopulateDatabases() +{ + return Process(_populate); +} + +bool DatabaseLoader::UpdateDatabases() +{ + return Process(_update); +} + +bool DatabaseLoader::PrepareStatements() +{ + return Process(_prepare); +} + +template +DatabaseLoader& DatabaseLoader::AddDatabase<LoginDatabaseConnection>(DatabaseWorkerPool<LoginDatabaseConnection>& pool, std::string const& name); +template +DatabaseLoader& DatabaseLoader::AddDatabase<WorldDatabaseConnection>(DatabaseWorkerPool<WorldDatabaseConnection>& pool, std::string const& name); +template +DatabaseLoader& DatabaseLoader::AddDatabase<CharacterDatabaseConnection>(DatabaseWorkerPool<CharacterDatabaseConnection>& pool, std::string const& name); diff --git a/src/server/shared/Database/DatabaseLoader.h b/src/server/shared/Database/DatabaseLoader.h new file mode 100644 index 00000000000..d35597ba807 --- /dev/null +++ b/src/server/shared/Database/DatabaseLoader.h @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef DatabaseLoader_h__ +#define DatabaseLoader_h__ + +#include "DatabaseWorkerPool.h" +#include "DatabaseEnv.h" + +#include <stack> +#include <functional> + +// A helper class to initiate all database worker pools, +// handles updating, delays preparing of statements and cleans up on failure. +class DatabaseLoader +{ +public: + DatabaseLoader(std::string const& logger, uint32 const defaultUpdateMask); + + // Register a database to the loader (lazy implemented) + template <class T> + DatabaseLoader& AddDatabase(DatabaseWorkerPool<T>& pool, std::string const& name); + + // Load all databases + bool Load(); + + enum DatabaseTypeFlags + { + DATABASE_NONE = 0, + + DATABASE_LOGIN = 1, + DATABASE_CHARACTER = 2, + DATABASE_WORLD = 4, + + DATABASE_MASK_ALL = DATABASE_LOGIN | DATABASE_CHARACTER | DATABASE_WORLD + }; + +private: + bool OpenDatabases(); + bool PopulateDatabases(); + bool UpdateDatabases(); + bool PrepareStatements(); + + using Predicate = std::function<bool()>; + + static bool Process(std::stack<Predicate>& stack); + + std::string const _logger; + bool const _autoSetup; + uint32 const _updateFlags; + + std::stack<std::pair<Predicate, std::function<void()>>> _open; + std::stack<std::function<void()>> _close; + std::stack<Predicate> _populate, _update, _prepare; +}; + +#endif // DatabaseLoader_h__ diff --git a/src/server/shared/Database/DatabaseWorker.cpp b/src/server/shared/Database/DatabaseWorker.cpp index 1fe638552a0..56757ce12a0 100644 --- a/src/server/shared/Database/DatabaseWorker.cpp +++ b/src/server/shared/Database/DatabaseWorker.cpp @@ -18,8 +18,6 @@ #include "DatabaseEnv.h" #include "DatabaseWorker.h" #include "SQLOperation.h" -#include "MySQLConnection.h" -#include "MySQLThreading.h" #include "ProducerConsumerQueue.h" DatabaseWorker::DatabaseWorker(ProducerConsumerQueue<SQLOperation*>* newQueue, MySQLConnection* connection) diff --git a/src/server/shared/Database/DatabaseWorkerPool.h b/src/server/shared/Database/DatabaseWorkerPool.h index f0ddbe91ad8..6bc204fbf76 100644 --- a/src/server/shared/Database/DatabaseWorkerPool.h +++ b/src/server/shared/Database/DatabaseWorkerPool.h @@ -28,8 +28,10 @@ #include "QueryResult.h" #include "QueryHolder.h" #include "AdhocStatement.h" +#include "StringFormat.h" #include <mysqld_error.h> +#include <memory> #define MIN_MYSQL_SERVER_VERSION 50100u #define MIN_MYSQL_CLIENT_VERSION 50100u @@ -57,9 +59,9 @@ class DatabaseWorkerPool public: /* Activity state */ - DatabaseWorkerPool() : _connectionInfo(NULL) + DatabaseWorkerPool() : _queue(new ProducerConsumerQueue<SQLOperation*>()), + _async_threads(0), _synch_threads(0) { - _queue = new ProducerConsumerQueue<SQLOperation*>(); memset(_connectionCount, 0, sizeof(_connectionCount)); _connections.resize(IDX_SIZE); @@ -70,31 +72,37 @@ class DatabaseWorkerPool ~DatabaseWorkerPool() { _queue->Cancel(); + } - delete _queue; + void SetConnectionInfo(std::string const& infoString, uint8 const asyncThreads, uint8 const synchThreads) + { + _connectionInfo.reset(new MySQLConnectionInfo(infoString)); - delete _connectionInfo; + _async_threads = asyncThreads; + _synch_threads = synchThreads; } - bool Open(const std::string& infoString, uint8 async_threads, uint8 synch_threads) + uint32 Open() { - _connectionInfo = new MySQLConnectionInfo(infoString); + WPFatal(_connectionInfo.get(), "Connection info was not set!"); TC_LOG_INFO("sql.driver", "Opening DatabasePool '%s'. Asynchronous connections: %u, synchronous connections: %u.", - GetDatabaseName(), async_threads, synch_threads); + GetDatabaseName(), _async_threads, _synch_threads); - bool res = OpenConnections(IDX_ASYNC, async_threads); + uint32 error = OpenConnections(IDX_ASYNC, _async_threads); - if (!res) - return res; + if (error) + return error; - res = OpenConnections(IDX_SYNCH, synch_threads); + error = OpenConnections(IDX_SYNCH, _synch_threads); - if (res) + if (!error) + { TC_LOG_INFO("sql.driver", "DatabasePool '%s' opened successfully. %u total connections running.", GetDatabaseName(), (_connectionCount[IDX_SYNCH] + _connectionCount[IDX_ASYNC])); + } - return res; + return error; } void Close() @@ -120,6 +128,32 @@ class DatabaseWorkerPool TC_LOG_INFO("sql.driver", "All connections on DatabasePool '%s' closed.", GetDatabaseName()); } + //! Prepares all prepared statements + bool PrepareStatements() + { + for (uint8 i = 0; i < IDX_SIZE; ++i) + for (uint32 c = 0; c < _connectionCount[i]; ++c) + { + T* t = _connections[i][c]; + t->LockIfReady(); + if (!t->PrepareStatements()) + { + t->Unlock(); + Close(); + return false; + } + else + t->Unlock(); + } + + return true; + } + + inline MySQLConnectionInfo const* GetConnectionInfo() const + { + return _connectionInfo.get(); + } + /** Delayed one-way statement methods. */ @@ -137,18 +171,13 @@ class DatabaseWorkerPool //! Enqueues a one-way SQL operation in string format -with variable args- that will be executed asynchronously. //! This method should only be used for queries that are only executed once, e.g during startup. - void PExecute(const char* sql, ...) + template<typename... Args> + void PExecute(const char* sql, Args const&... args) { if (!sql) return; - va_list ap; - char szQuery[MAX_QUERY_LEN]; - va_start(ap, sql); - vsnprintf(szQuery, MAX_QUERY_LEN, sql, ap); - va_end(ap); - - Execute(szQuery); + Execute(Trinity::StringFormat(sql, args...).c_str()); } //! Enqueues a one-way SQL operation in prepared statement format that will be executed asynchronously. @@ -177,18 +206,13 @@ class DatabaseWorkerPool //! Directly executes a one-way SQL operation in string format -with variable args-, that will block the calling thread until finished. //! This method should only be used for queries that are only executed once, e.g during startup. - void DirectPExecute(const char* sql, ...) + template<typename... Args> + void DirectPExecute(const char* sql, Args const&... args) { if (!sql) return; - va_list ap; - char szQuery[MAX_QUERY_LEN]; - va_start(ap, sql); - vsnprintf(szQuery, MAX_QUERY_LEN, sql, ap); - va_end(ap); - - return DirectExecute(szQuery); + DirectExecute(Trinity::StringFormat(sql, args...).c_str()); } //! Directly executes a one-way SQL operation in prepared statement format, that will block the calling thread until finished. @@ -227,34 +251,24 @@ class DatabaseWorkerPool //! Directly executes an SQL query in string format -with variable args- that will block the calling thread until finished. //! Returns reference counted auto pointer, no need for manual memory management in upper level code. - QueryResult PQuery(const char* sql, T* conn, ...) + template<typename... Args> + QueryResult PQuery(const char* sql, T* conn, Args const&... args) { if (!sql) return QueryResult(NULL); - va_list ap; - char szQuery[MAX_QUERY_LEN]; - va_start(ap, conn); - vsnprintf(szQuery, MAX_QUERY_LEN, sql, ap); - va_end(ap); - - return Query(szQuery, conn); + return Query(Trinity::StringFormat(sql, args...).c_str(), conn); } //! Directly executes an SQL query in string format -with variable args- that will block the calling thread until finished. //! Returns reference counted auto pointer, no need for manual memory management in upper level code. - QueryResult PQuery(const char* sql, ...) + template<typename... Args> + QueryResult PQuery(const char* sql, Args const&... args) { if (!sql) return QueryResult(NULL); - va_list ap; - char szQuery[MAX_QUERY_LEN]; - va_start(ap, sql); - vsnprintf(szQuery, MAX_QUERY_LEN, sql, ap); - va_end(ap); - - return Query(szQuery); + return Query(Trinity::StringFormat(sql, args...).c_str()); } //! Directly executes an SQL query in prepared format that will block the calling thread until finished. @@ -295,15 +309,10 @@ class DatabaseWorkerPool //! Enqueues a query in string format -with variable args- that will set the value of the QueryResultFuture return object as soon as the query is executed. //! The return value is then processed in ProcessQueryCallback methods. - QueryResultFuture AsyncPQuery(const char* sql, ...) + template<typename... Args> + QueryResultFuture AsyncPQuery(const char* sql, Args const&... args) { - va_list ap; - char szQuery[MAX_QUERY_LEN]; - va_start(ap, sql); - vsnprintf(szQuery, MAX_QUERY_LEN, sql, ap); - va_end(ap); - - return AsyncQuery(szQuery); + return AsyncQuery(Trinity::StringFormat(sql, args...).c_str()); } //! Enqueues a query in prepared format that will set the value of the PreparedQueryResultFuture return object as soon as the query is executed. @@ -461,7 +470,7 @@ class DatabaseWorkerPool } private: - bool OpenConnections(InternalIndex type, uint8 numConnections) + uint32 OpenConnections(InternalIndex type, uint8 numConnections) { _connections[type].resize(numConnections); for (uint8 i = 0; i < numConnections; ++i) @@ -469,7 +478,7 @@ class DatabaseWorkerPool T* t; if (type == IDX_ASYNC) - t = new T(_queue, *_connectionInfo); + t = new T(_queue.get(), *_connectionInfo); else if (type == IDX_SYNCH) t = new T(*_connectionInfo); else @@ -478,35 +487,32 @@ class DatabaseWorkerPool _connections[type][i] = t; ++_connectionCount[type]; - bool res = t->Open(); + uint32 error = t->Open(); - if (res) + if (!error) { if (mysql_get_server_version(t->GetHandle()) < MIN_MYSQL_SERVER_VERSION) { TC_LOG_ERROR("sql.driver", "TrinityCore does not support MySQL versions below 5.1"); - res = false; + error = 1; } } // Failed to open a connection or invalid version, abort and cleanup - if (!res) + if (error) { - TC_LOG_ERROR("sql.driver", "DatabasePool %s NOT opened. There were errors opening the MySQL connections. Check your SQLDriverLogFile " - "for specific errors. Read wiki at http://collab.kpsn.org/display/tc/TrinityCore+Home", GetDatabaseName()); - while (_connectionCount[type] != 0) { T* t = _connections[type][i--]; delete t; --_connectionCount[type]; } - - return false; + return error; } } - return true; + // Everything is fine + return 0; } unsigned long EscapeString(char *to, const char *from, unsigned long length) @@ -546,10 +552,13 @@ class DatabaseWorkerPool return _connectionInfo->database.c_str(); } - ProducerConsumerQueue<SQLOperation*>* _queue; //! Queue shared by async worker threads. - std::vector< std::vector<T*> > _connections; - uint32 _connectionCount[2]; //! Counter of MySQL connections; - MySQLConnectionInfo* _connectionInfo; + //! Queue shared by async worker threads. + std::unique_ptr<ProducerConsumerQueue<SQLOperation*>> _queue; + std::vector<std::vector<T*>> _connections; + //! Counter of MySQL connections; + uint32 _connectionCount[IDX_SIZE]; + std::unique_ptr<MySQLConnectionInfo> _connectionInfo; + uint8 _async_threads, _synch_threads; }; #endif diff --git a/src/server/shared/Database/Field.h b/src/server/shared/Database/Field.h index b88fb56cb06..1bbd264482f 100644 --- a/src/server/shared/Database/Field.h +++ b/src/server/shared/Database/Field.h @@ -255,11 +255,7 @@ class Field Field(); ~Field(); - #if defined(__GNUC__) - #pragma pack(1) - #else #pragma pack(push, 1) - #endif struct { uint32 length; // Length (prepared strings only) @@ -267,11 +263,7 @@ class Field enum_field_types type; // Field type bool raw; // Raw bytes? (Prepared statement or ad hoc) } data; - #if defined(__GNUC__) - #pragma pack() - #else #pragma pack(pop) - #endif void SetByteValue(void const* newValue, size_t const newSize, enum_field_types newType, uint32 length); void SetStructuredValue(char* newValue, enum_field_types newType); diff --git a/src/server/shared/Database/Implementation/CharacterDatabase.cpp b/src/server/shared/Database/Implementation/CharacterDatabase.cpp index 265b0578841..1ca01501d01 100644 --- a/src/server/shared/Database/Implementation/CharacterDatabase.cpp +++ b/src/server/shared/Database/Implementation/CharacterDatabase.cpp @@ -60,7 +60,8 @@ void CharacterDatabaseConnection::DoPrepareStatements() PrepareStatement(CHAR_SEL_CHAR_POSITION_XYZ, "SELECT map, position_x, position_y, position_z FROM characters WHERE guid = ?", CONNECTION_SYNCH); PrepareStatement(CHAR_SEL_CHAR_POSITION, "SELECT position_x, position_y, position_z, orientation, map, taxi_path FROM characters WHERE guid = ?", CONNECTION_SYNCH); - PrepareStatement(CHAR_DEL_BATTLEGROUND_RANDOM, "DELETE FROM character_battleground_random", CONNECTION_ASYNC); + PrepareStatement(CHAR_DEL_BATTLEGROUND_RANDOM_ALL, "DELETE FROM character_battleground_random", CONNECTION_ASYNC); + PrepareStatement(CHAR_DEL_BATTLEGROUND_RANDOM, "DELETE FROM character_battleground_random WHERE guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_INS_BATTLEGROUND_RANDOM, "INSERT INTO character_battleground_random (guid) VALUES (?)", CONNECTION_ASYNC); PrepareStatement(CHAR_SEL_CHARACTER, "SELECT guid, account, name, race, class, gender, level, xp, money, playerBytes, playerBytes2, playerFlags, " @@ -104,7 +105,7 @@ void CharacterDatabaseConnection::DoPrepareStatements() PrepareStatement(CHAR_SEL_MAIL_COUNT, "SELECT COUNT(*) FROM mail WHERE receiver = ?", CONNECTION_SYNCH); PrepareStatement(CHAR_SEL_CHARACTER_SOCIALLIST, "SELECT friend, flags, note FROM character_social JOIN characters ON characters.guid = character_social.friend WHERE character_social.guid = ? AND deleteinfos_name IS NULL LIMIT 255", CONNECTION_ASYNC); PrepareStatement(CHAR_SEL_CHARACTER_HOMEBIND, "SELECT mapId, zoneId, posX, posY, posZ FROM character_homebind WHERE guid = ?", CONNECTION_ASYNC); - PrepareStatement(CHAR_SEL_CHARACTER_SPELLCOOLDOWNS, "SELECT spell, item, time FROM character_spell_cooldown WHERE guid = ?", CONNECTION_ASYNC); + PrepareStatement(CHAR_SEL_CHARACTER_SPELLCOOLDOWNS, "SELECT spell, item, time FROM character_spell_cooldown WHERE guid = ? AND time > UNIX_TIMESTAMP()", CONNECTION_ASYNC); PrepareStatement(CHAR_SEL_CHARACTER_DECLINEDNAMES, "SELECT genitive, dative, accusative, instrumental, prepositional FROM character_declinedname WHERE guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_SEL_GUILD_MEMBER, "SELECT guildid, rank FROM guild_member WHERE guid = ?", CONNECTION_BOTH); PrepareStatement(CHAR_SEL_GUILD_MEMBER_EXTENDED, "SELECT g.guildid, g.name, gr.rname, gr.rid, gm.pnote, gm.offnote " @@ -300,6 +301,7 @@ void CharacterDatabaseConnection::DoPrepareStatements() PrepareStatement(CHAR_DEL_ARENA_TEAM_MEMBER, "DELETE FROM arena_team_member WHERE arenaTeamId = ? AND guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_UPD_ARENA_TEAM_STATS, "UPDATE arena_team SET rating = ?, weekGames = ?, weekWins = ?, seasonGames = ?, seasonWins = ?, rank = ? WHERE arenaTeamId = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_UPD_ARENA_TEAM_MEMBER, "UPDATE arena_team_member SET personalRating = ?, weekGames = ?, weekWins = ?, seasonGames = ?, seasonWins = ? WHERE arenaTeamId = ? AND guid = ?", CONNECTION_ASYNC); + PrepareStatement(CHAR_DEL_CHARACTER_ARENA_STATS, "DELETE FROM character_arena_stats WHERE guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_REP_CHARACTER_ARENA_STATS, "REPLACE INTO character_arena_stats (guid, slot, matchMakerRating) VALUES (?, ?, ?)", CONNECTION_ASYNC); PrepareStatement(CHAR_SEL_PLAYER_ARENA_TEAMS, "SELECT arena_team_member.arenaTeamId FROM arena_team_member JOIN arena_team ON arena_team_member.arenaTeamId = arena_team.arenaTeamId WHERE guid = ?", CONNECTION_SYNCH); PrepareStatement(CHAR_UPD_ARENA_TEAM_NAME, "UPDATE arena_team SET name = ? WHERE arenaTeamId = ?", CONNECTION_ASYNC); @@ -495,7 +497,8 @@ void CharacterDatabaseConnection::DoPrepareStatements() PrepareStatement(CHAR_UPD_CHAR_REP_FACTION_CHANGE, "UPDATE character_reputation SET faction = ?, standing = ? WHERE faction = ? AND guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_UPD_CHAR_TITLES_FACTION_CHANGE, "UPDATE characters SET knownTitles = ? WHERE guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_RES_CHAR_TITLES_FACTION_CHANGE, "UPDATE characters SET chosenTitle = 0 WHERE guid = ?", CONNECTION_ASYNC); - PrepareStatement(CHAR_DEL_CHAR_SPELL_COOLDOWN, "DELETE FROM character_spell_cooldown WHERE guid = ?", CONNECTION_ASYNC); + PrepareStatement(CHAR_DEL_CHAR_SPELL_COOLDOWNS, "DELETE FROM character_spell_cooldown WHERE guid = ?", CONNECTION_ASYNC); + PrepareStatement(CHAR_INS_CHAR_SPELL_COOLDOWN, "INSERT INTO character_spell_cooldown (guid, spell, item, time) VALUES (?, ?, ?, ?)", CONNECTION_ASYNC); PrepareStatement(CHAR_DEL_CHARACTER, "DELETE FROM characters WHERE guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_DEL_CHAR_ACTION, "DELETE FROM character_action WHERE guid = ?", CONNECTION_ASYNC); PrepareStatement(CHAR_DEL_CHAR_AURA, "DELETE FROM character_aura WHERE guid = ?", CONNECTION_ASYNC); @@ -575,7 +578,7 @@ void CharacterDatabaseConnection::DoPrepareStatements() PrepareStatement(CHAR_INS_CHAR_PET_DECLINEDNAME, "INSERT INTO character_pet_declinedname (id, owner, genitive, dative, accusative, instrumental, prepositional) VALUES (?, ?, ?, ?, ?, ?, ?)", CONNECTION_ASYNC); PrepareStatement(CHAR_SEL_PET_AURA, "SELECT caster_guid, spell, effect_mask, recalculate_mask, stackcount, amount0, amount1, amount2, base_amount0, base_amount1, base_amount2, maxduration, remaintime, remaincharges FROM pet_aura WHERE guid = ?", CONNECTION_SYNCH); PrepareStatement(CHAR_SEL_PET_SPELL, "SELECT spell, active FROM pet_spell WHERE guid = ?", CONNECTION_SYNCH); - PrepareStatement(CHAR_SEL_PET_SPELL_COOLDOWN, "SELECT spell, time FROM pet_spell_cooldown WHERE guid = ?", CONNECTION_SYNCH); + PrepareStatement(CHAR_SEL_PET_SPELL_COOLDOWN, "SELECT spell, time FROM pet_spell_cooldown WHERE guid = ? AND time > UNIX_TIMESTAMP()", CONNECTION_SYNCH); PrepareStatement(CHAR_SEL_PET_DECLINED_NAME, "SELECT genitive, dative, accusative, instrumental, prepositional FROM character_pet_declinedname WHERE owner = ? AND id = ?", CONNECTION_SYNCH); PrepareStatement(CHAR_DEL_PET_AURAS, "DELETE FROM pet_aura WHERE guid = ?", CONNECTION_BOTH); PrepareStatement(CHAR_DEL_PET_SPELLS, "DELETE FROM pet_spell WHERE guid = ?", CONNECTION_ASYNC); diff --git a/src/server/shared/Database/Implementation/CharacterDatabase.h b/src/server/shared/Database/Implementation/CharacterDatabase.h index f7ff5b9186e..f88a912e022 100644 --- a/src/server/shared/Database/Implementation/CharacterDatabase.h +++ b/src/server/shared/Database/Implementation/CharacterDatabase.h @@ -72,6 +72,7 @@ enum CharacterDatabaseStatements CHAR_SEL_CHAR_POSITION_XYZ, CHAR_SEL_CHAR_POSITION, + CHAR_DEL_BATTLEGROUND_RANDOM_ALL, CHAR_DEL_BATTLEGROUND_RANDOM, CHAR_INS_BATTLEGROUND_RANDOM, @@ -253,6 +254,7 @@ enum CharacterDatabaseStatements CHAR_DEL_ARENA_TEAM_MEMBER, CHAR_UPD_ARENA_TEAM_STATS, CHAR_UPD_ARENA_TEAM_MEMBER, + CHAR_DEL_CHARACTER_ARENA_STATS, CHAR_REP_CHARACTER_ARENA_STATS, CHAR_SEL_PLAYER_ARENA_TEAMS, CHAR_UPD_ARENA_TEAM_NAME, @@ -434,7 +436,8 @@ enum CharacterDatabaseStatements CHAR_UPD_CHAR_REP_FACTION_CHANGE, CHAR_UPD_CHAR_TITLES_FACTION_CHANGE, CHAR_RES_CHAR_TITLES_FACTION_CHANGE, - CHAR_DEL_CHAR_SPELL_COOLDOWN, + CHAR_DEL_CHAR_SPELL_COOLDOWNS, + CHAR_INS_CHAR_SPELL_COOLDOWN, CHAR_DEL_CHARACTER, CHAR_DEL_CHAR_ACTION, CHAR_DEL_CHAR_AURA, diff --git a/src/server/shared/Database/MySQLConnection.cpp b/src/server/shared/Database/MySQLConnection.cpp index 1a9f973d47b..10f4a7baa18 100644 --- a/src/server/shared/Database/MySQLConnection.cpp +++ b/src/server/shared/Database/MySQLConnection.cpp @@ -22,11 +22,9 @@ #include <winsock2.h> #endif #include <mysql.h> -#include <mysqld_error.h> #include <errmsg.h> #include "MySQLConnection.h" -#include "MySQLThreading.h" #include "QueryResult.h" #include "SQLOperation.h" #include "PreparedStatement.h" @@ -72,14 +70,14 @@ void MySQLConnection::Close() delete this; } -bool MySQLConnection::Open() +uint32 MySQLConnection::Open() { MYSQL *mysqlInit; mysqlInit = mysql_init(NULL); if (!mysqlInit) { TC_LOG_ERROR("sql.sql", "Could not initialize Mysql connection to database `%s`", m_connectionInfo.database.c_str()); - return false; + return CR_UNKNOWN_ERROR; } int port; @@ -137,13 +135,13 @@ bool MySQLConnection::Open() // set connection properties to UTF8 to properly handle locales for different // server configs - core sends data in UTF8, so MySQL must expect UTF8 too mysql_set_character_set(m_Mysql, "utf8"); - return PrepareStatements(); + return 0; } else { - TC_LOG_ERROR("sql.sql", "Could not connect to MySQL database at %s: %s\n", m_connectionInfo.host.c_str(), mysql_error(mysqlInit)); + TC_LOG_ERROR("sql.sql", "Could not connect to MySQL database at %s: %s", m_connectionInfo.host.c_str(), mysql_error(mysqlInit)); mysql_close(mysqlInit); - return false; + return mysql_errno(mysqlInit); } } @@ -491,8 +489,18 @@ bool MySQLConnection::_HandleMySQLErrno(uint32 errNo) m_reconnecting = true; uint64 oldThreadId = mysql_thread_id(GetHandle()); mysql_close(GetHandle()); - if (this->Open()) // Don't remove 'this' pointer unless you want to skip loading all prepared statements.... + + uint32 const lErrno = Open(); + if (!lErrno) { + // Don't remove 'this' pointer unless you want to skip loading all prepared statements... + if (!this->PrepareStatements()) + { + TC_LOG_ERROR("sql.sql", "Could not re-prepare statements!"); + Close(); + return false; + } + TC_LOG_INFO("sql.sql", "Connection to the MySQL server is active."); if (oldThreadId != mysql_thread_id(GetHandle())) TC_LOG_INFO("sql.sql", "Successfully reconnected to %s @%s:%s (%s).", @@ -503,7 +511,7 @@ bool MySQLConnection::_HandleMySQLErrno(uint32 errNo) return true; } - uint32 lErrno = mysql_errno(GetHandle()); // It's possible this attempted reconnect throws 2006 at us. To prevent crazy recursive calls, sleep here. + // It's possible this attempted reconnect throws 2006 at us. To prevent crazy recursive calls, sleep here. std::this_thread::sleep_for(std::chrono::seconds(3)); // Sleep 3 seconds return _HandleMySQLErrno(lErrno); // Call self (recursive) } diff --git a/src/server/shared/Database/MySQLConnection.h b/src/server/shared/Database/MySQLConnection.h index d486f5b4679..78d8d2fb5dd 100644 --- a/src/server/shared/Database/MySQLConnection.h +++ b/src/server/shared/Database/MySQLConnection.h @@ -72,9 +72,11 @@ class MySQLConnection MySQLConnection(ProducerConsumerQueue<SQLOperation*>* queue, MySQLConnectionInfo& connInfo); //! Constructor for asynchronous connections. virtual ~MySQLConnection(); - virtual bool Open(); + virtual uint32 Open(); void Close(); + bool PrepareStatements(); + public: bool Execute(const char* sql); bool Execute(PreparedStatement* stmt); @@ -111,7 +113,6 @@ class MySQLConnection MySQLPreparedStatement* GetPreparedStatement(uint32 index); void PrepareStatement(uint32 index, const char* sql, ConnectionFlags flags); - bool PrepareStatements(); virtual void DoPrepareStatements() = 0; protected: diff --git a/src/server/shared/Database/QueryHolder.cpp b/src/server/shared/Database/QueryHolder.cpp index 75b96e1996c..2fdb3825526 100644 --- a/src/server/shared/Database/QueryHolder.cpp +++ b/src/server/shared/Database/QueryHolder.cpp @@ -40,29 +40,6 @@ bool SQLQueryHolder::SetQuery(size_t index, const char *sql) return true; } -bool SQLQueryHolder::SetPQuery(size_t index, const char *format, ...) -{ - if (!format) - { - TC_LOG_ERROR("sql.sql", "Query (index: %u) is empty.", uint32(index)); - return false; - } - - va_list ap; - char szQuery [MAX_QUERY_LEN]; - va_start(ap, format); - int res = vsnprintf(szQuery, MAX_QUERY_LEN, format, ap); - va_end(ap); - - if (res == -1) - { - TC_LOG_ERROR("sql.sql", "SQL Query truncated (and not execute) for format: %s", format); - return false; - } - - return SetQuery(index, szQuery); -} - bool SQLQueryHolder::SetPreparedQuery(size_t index, PreparedStatement* stmt) { if (m_queries.size() <= index) diff --git a/src/server/shared/Database/QueryHolder.h b/src/server/shared/Database/QueryHolder.h index 6fd1901447e..4102bba1223 100644 --- a/src/server/shared/Database/QueryHolder.h +++ b/src/server/shared/Database/QueryHolder.h @@ -29,8 +29,9 @@ class SQLQueryHolder public: SQLQueryHolder() { } ~SQLQueryHolder(); - bool SetQuery(size_t index, const char *sql); - bool SetPQuery(size_t index, const char *format, ...) ATTR_PRINTF(3, 4); + bool SetQuery(size_t index, const char* sql); + template<typename... Args> + bool SetPQuery(size_t index, const char* sql, Args const&... args) { return SetQuery(index, Trinity::StringFormat(sql, args...).c_str()); } bool SetPreparedQuery(size_t index, PreparedStatement* stmt); void SetSize(size_t size); QueryResult GetResult(size_t index); diff --git a/src/server/shared/Database/Transaction.cpp b/src/server/shared/Database/Transaction.cpp index 9f36d198bde..f657411f716 100644 --- a/src/server/shared/Database/Transaction.cpp +++ b/src/server/shared/Database/Transaction.cpp @@ -30,17 +30,6 @@ void Transaction::Append(const char* sql) m_queries.push_back(data); } -void Transaction::PAppend(const char* sql, ...) -{ - va_list ap; - char szQuery [MAX_QUERY_LEN]; - va_start(ap, sql); - vsnprintf(szQuery, MAX_QUERY_LEN, sql, ap); - va_end(ap); - - Append(szQuery); -} - //- Append a prepared statement to the transaction void Transaction::Append(PreparedStatement* stmt) { diff --git a/src/server/shared/Database/Transaction.h b/src/server/shared/Database/Transaction.h index 83d59006ddc..a51c96ea93b 100644 --- a/src/server/shared/Database/Transaction.h +++ b/src/server/shared/Database/Transaction.h @@ -19,6 +19,7 @@ #define _TRANSACTION_H #include "SQLOperation.h" +#include "StringFormat.h" //- Forward declare (don't include header to prevent circular includes) class PreparedStatement; @@ -38,7 +39,8 @@ class Transaction void Append(PreparedStatement* statement); void Append(const char* sql); - void PAppend(const char* sql, ...); + template<typename... Args> + void PAppend(const char* sql, Args const&... args) { Append(Trinity::StringFormat(sql, args...).c_str()); } size_t GetSize() const { return m_queries.size(); } @@ -59,8 +61,8 @@ class TransactionTask : public SQLOperation friend class DatabaseWorker; public: - TransactionTask(SQLTransaction trans) : m_trans(trans) { } ; - ~TransactionTask(){ }; + TransactionTask(SQLTransaction trans) : m_trans(trans) { } + ~TransactionTask() { } protected: bool Execute() override; diff --git a/src/server/shared/Debugging/WheatyExceptionReport.cpp b/src/server/shared/Debugging/WheatyExceptionReport.cpp index e9f4f9ca9ac..e9f888f280d 100644 --- a/src/server/shared/Debugging/WheatyExceptionReport.cpp +++ b/src/server/shared/Debugging/WheatyExceptionReport.cpp @@ -61,6 +61,9 @@ HANDLE WheatyExceptionReport::m_hDumpFile; HANDLE WheatyExceptionReport::m_hProcess; SymbolPairs WheatyExceptionReport::symbols; std::stack<SymbolDetail> WheatyExceptionReport::symbolDetails; +bool WheatyExceptionReport::stackOverflowException; +bool WheatyExceptionReport::alreadyCrashed; +std::mutex WheatyExceptionReport::alreadyCrashedLock; // Declare global instance of class WheatyExceptionReport g_WheatyExceptionReport; @@ -72,6 +75,8 @@ WheatyExceptionReport::WheatyExceptionReport() // Constructor // Install the unhandled exception filter function m_previousFilter = SetUnhandledExceptionFilter(WheatyUnhandledExceptionFilter); m_hProcess = GetCurrentProcess(); + stackOverflowException = false; + alreadyCrashed = false; if (!IsDebuggerPresent()) { _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE); @@ -97,6 +102,16 @@ WheatyExceptionReport::~WheatyExceptionReport() LONG WINAPI WheatyExceptionReport::WheatyUnhandledExceptionFilter( PEXCEPTION_POINTERS pExceptionInfo) { + std::unique_lock<std::mutex> guard(alreadyCrashedLock); + // Handle only 1 exception in the whole process lifetime + if (alreadyCrashed) + return EXCEPTION_EXECUTE_HANDLER; + + alreadyCrashed = true; + + if (pExceptionInfo->ExceptionRecord->ExceptionCode == STATUS_STACK_OVERFLOW) + stackOverflowException = true; + TCHAR module_folder_name[MAX_PATH]; GetModuleFileName(0, module_folder_name, MAX_PATH); TCHAR* pos = _tcsrchr(module_folder_name, '\\'); @@ -419,107 +434,114 @@ void WheatyExceptionReport::printTracesForAllThreads(bool bWriteVariables) void WheatyExceptionReport::GenerateExceptionReport( PEXCEPTION_POINTERS pExceptionInfo) { - SYSTEMTIME systime; - GetLocalTime(&systime); - - // Start out with a banner - _tprintf(_T("Revision: %s\r\n"), _FULLVERSION); - _tprintf(_T("Date %u:%u:%u. Time %u:%u \r\n"), systime.wDay, systime.wMonth, systime.wYear, systime.wHour, systime.wMinute); - PEXCEPTION_RECORD pExceptionRecord = pExceptionInfo->ExceptionRecord; - - PrintSystemInfo(); - // First print information about the type of fault - _tprintf(_T("\r\n//=====================================================\r\n")); - _tprintf(_T("Exception code: %08X %s\r\n"), - pExceptionRecord->ExceptionCode, - GetExceptionString(pExceptionRecord->ExceptionCode)); - - // Now print information about where the fault occured - TCHAR szFaultingModule[MAX_PATH]; - DWORD section; - DWORD_PTR offset; - GetLogicalAddress(pExceptionRecord->ExceptionAddress, - szFaultingModule, - sizeof(szFaultingModule), - section, offset); + __try + { + SYSTEMTIME systime; + GetLocalTime(&systime); + + // Start out with a banner + _tprintf(_T("Revision: %s\r\n"), _FULLVERSION); + _tprintf(_T("Date %u:%u:%u. Time %u:%u \r\n"), systime.wDay, systime.wMonth, systime.wYear, systime.wHour, systime.wMinute); + PEXCEPTION_RECORD pExceptionRecord = pExceptionInfo->ExceptionRecord; + + PrintSystemInfo(); + // First print information about the type of fault + _tprintf(_T("\r\n//=====================================================\r\n")); + _tprintf(_T("Exception code: %08X %s\r\n"), + pExceptionRecord->ExceptionCode, + GetExceptionString(pExceptionRecord->ExceptionCode)); + + // Now print information about where the fault occured + TCHAR szFaultingModule[MAX_PATH]; + DWORD section; + DWORD_PTR offset; + GetLogicalAddress(pExceptionRecord->ExceptionAddress, + szFaultingModule, + sizeof(szFaultingModule), + section, offset); #ifdef _M_IX86 - _tprintf(_T("Fault address: %08X %02X:%08X %s\r\n"), - pExceptionRecord->ExceptionAddress, - section, offset, szFaultingModule); + _tprintf(_T("Fault address: %08X %02X:%08X %s\r\n"), + pExceptionRecord->ExceptionAddress, + section, offset, szFaultingModule); #endif #ifdef _M_X64 - _tprintf(_T("Fault address: %016I64X %02X:%016I64X %s\r\n"), - pExceptionRecord->ExceptionAddress, - section, offset, szFaultingModule); + _tprintf(_T("Fault address: %016I64X %02X:%016I64X %s\r\n"), + pExceptionRecord->ExceptionAddress, + section, offset, szFaultingModule); #endif - PCONTEXT pCtx = pExceptionInfo->ContextRecord; + PCONTEXT pCtx = pExceptionInfo->ContextRecord; - // Show the registers - #ifdef _M_IX86 // X86 Only! - _tprintf(_T("\r\nRegisters:\r\n")); + // Show the registers +#ifdef _M_IX86 // X86 Only! + _tprintf(_T("\r\nRegisters:\r\n")); - _tprintf(_T("EAX:%08X\r\nEBX:%08X\r\nECX:%08X\r\nEDX:%08X\r\nESI:%08X\r\nEDI:%08X\r\n") - , pCtx->Eax, pCtx->Ebx, pCtx->Ecx, pCtx->Edx, - pCtx->Esi, pCtx->Edi); + _tprintf(_T("EAX:%08X\r\nEBX:%08X\r\nECX:%08X\r\nEDX:%08X\r\nESI:%08X\r\nEDI:%08X\r\n") + , pCtx->Eax, pCtx->Ebx, pCtx->Ecx, pCtx->Edx, + pCtx->Esi, pCtx->Edi); - _tprintf(_T("CS:EIP:%04X:%08X\r\n"), pCtx->SegCs, pCtx->Eip); - _tprintf(_T("SS:ESP:%04X:%08X EBP:%08X\r\n"), - pCtx->SegSs, pCtx->Esp, pCtx->Ebp); - _tprintf(_T("DS:%04X ES:%04X FS:%04X GS:%04X\r\n"), - pCtx->SegDs, pCtx->SegEs, pCtx->SegFs, pCtx->SegGs); - _tprintf(_T("Flags:%08X\r\n"), pCtx->EFlags); - #endif + _tprintf(_T("CS:EIP:%04X:%08X\r\n"), pCtx->SegCs, pCtx->Eip); + _tprintf(_T("SS:ESP:%04X:%08X EBP:%08X\r\n"), + pCtx->SegSs, pCtx->Esp, pCtx->Ebp); + _tprintf(_T("DS:%04X ES:%04X FS:%04X GS:%04X\r\n"), + pCtx->SegDs, pCtx->SegEs, pCtx->SegFs, pCtx->SegGs); + _tprintf(_T("Flags:%08X\r\n"), pCtx->EFlags); +#endif - #ifdef _M_X64 - _tprintf(_T("\r\nRegisters:\r\n")); - _tprintf(_T("RAX:%016I64X\r\nRBX:%016I64X\r\nRCX:%016I64X\r\nRDX:%016I64X\r\nRSI:%016I64X\r\nRDI:%016I64X\r\n") - _T("R8: %016I64X\r\nR9: %016I64X\r\nR10:%016I64X\r\nR11:%016I64X\r\nR12:%016I64X\r\nR13:%016I64X\r\nR14:%016I64X\r\nR15:%016I64X\r\n") - , pCtx->Rax, pCtx->Rbx, pCtx->Rcx, pCtx->Rdx, - pCtx->Rsi, pCtx->Rdi, pCtx->R9, pCtx->R10, pCtx->R11, pCtx->R12, pCtx->R13, pCtx->R14, pCtx->R15); - _tprintf(_T("CS:RIP:%04X:%016I64X\r\n"), pCtx->SegCs, pCtx->Rip); - _tprintf(_T("SS:RSP:%04X:%016X RBP:%08X\r\n"), - pCtx->SegSs, pCtx->Rsp, pCtx->Rbp); - _tprintf(_T("DS:%04X ES:%04X FS:%04X GS:%04X\r\n"), - pCtx->SegDs, pCtx->SegEs, pCtx->SegFs, pCtx->SegGs); - _tprintf(_T("Flags:%08X\r\n"), pCtx->EFlags); - #endif +#ifdef _M_X64 + _tprintf(_T("\r\nRegisters:\r\n")); + _tprintf(_T("RAX:%016I64X\r\nRBX:%016I64X\r\nRCX:%016I64X\r\nRDX:%016I64X\r\nRSI:%016I64X\r\nRDI:%016I64X\r\n") + _T("R8: %016I64X\r\nR9: %016I64X\r\nR10:%016I64X\r\nR11:%016I64X\r\nR12:%016I64X\r\nR13:%016I64X\r\nR14:%016I64X\r\nR15:%016I64X\r\n") + , pCtx->Rax, pCtx->Rbx, pCtx->Rcx, pCtx->Rdx, + pCtx->Rsi, pCtx->Rdi, pCtx->R9, pCtx->R10, pCtx->R11, pCtx->R12, pCtx->R13, pCtx->R14, pCtx->R15); + _tprintf(_T("CS:RIP:%04X:%016I64X\r\n"), pCtx->SegCs, pCtx->Rip); + _tprintf(_T("SS:RSP:%04X:%016X RBP:%08X\r\n"), + pCtx->SegSs, pCtx->Rsp, pCtx->Rbp); + _tprintf(_T("DS:%04X ES:%04X FS:%04X GS:%04X\r\n"), + pCtx->SegDs, pCtx->SegEs, pCtx->SegFs, pCtx->SegGs); + _tprintf(_T("Flags:%08X\r\n"), pCtx->EFlags); +#endif - SymSetOptions(SYMOPT_DEFERRED_LOADS); + SymSetOptions(SYMOPT_DEFERRED_LOADS); - // Initialize DbgHelp - if (!SymInitialize(GetCurrentProcess(), 0, TRUE)) - { - _tprintf(_T("\n\rCRITICAL ERROR.\n\r Couldn't initialize the symbol handler for process.\n\rError [%s].\n\r\n\r"), - ErrorMessage(GetLastError())); - } + // Initialize DbgHelp + if (!SymInitialize(GetCurrentProcess(), 0, TRUE)) + { + _tprintf(_T("\n\rCRITICAL ERROR.\n\r Couldn't initialize the symbol handler for process.\n\rError [%s].\n\r\n\r"), + ErrorMessage(GetLastError())); + } - CONTEXT trashableContext = *pCtx; + CONTEXT trashableContext = *pCtx; - WriteStackDetails(&trashableContext, false, NULL); - printTracesForAllThreads(false); + WriteStackDetails(&trashableContext, false, NULL); + printTracesForAllThreads(false); -// #ifdef _M_IX86 // X86 Only! + // #ifdef _M_IX86 // X86 Only! - _tprintf(_T("========================\r\n")); - _tprintf(_T("Local Variables And Parameters\r\n")); + _tprintf(_T("========================\r\n")); + _tprintf(_T("Local Variables And Parameters\r\n")); - trashableContext = *pCtx; - WriteStackDetails(&trashableContext, true, NULL); - printTracesForAllThreads(true); + trashableContext = *pCtx; + WriteStackDetails(&trashableContext, true, NULL); + printTracesForAllThreads(true); - /*_tprintf(_T("========================\r\n")); - _tprintf(_T("Global Variables\r\n")); + /*_tprintf(_T("========================\r\n")); + _tprintf(_T("Global Variables\r\n")); - SymEnumSymbols(GetCurrentProcess(), + SymEnumSymbols(GetCurrentProcess(), (UINT_PTR)GetModuleHandle(szFaultingModule), 0, EnumerateSymbolsCallback, 0);*/ - // #endif // X86 Only! + // #endif // X86 Only! - SymCleanup(GetCurrentProcess()); + SymCleanup(GetCurrentProcess()); - _tprintf(_T("\r\n")); + _tprintf(_T("\r\n")); + } + __except (EXCEPTION_EXECUTE_HANDLER) + { + _tprintf(_T("Error writing the crash log\r\n")); + } } //====================================================================== @@ -1068,7 +1090,7 @@ bool logChildren) { case btChar: case btStdString: - FormatOutputValue(buffer, basicType, length, (PVOID)offset, sizeof(buffer)); + FormatOutputValue(buffer, basicType, length, (PVOID)offset, sizeof(buffer), elementsCount); symbolDetails.top().Value = buffer; break; default: @@ -1196,7 +1218,8 @@ void WheatyExceptionReport::FormatOutputValue(char * pszCurrBuffer, BasicType basicType, DWORD64 length, PVOID pAddress, -size_t bufferSize) +size_t bufferSize, +size_t countOverride) { __try { @@ -1204,10 +1227,15 @@ size_t bufferSize) { case btChar: { - if (strlen((char*)pAddress) > bufferSize - 6) + // Special case handling for char[] type + if (countOverride != 0) + length = countOverride; + else + length = strlen((char*)pAddress); + if (length > bufferSize - 6) pszCurrBuffer += sprintf(pszCurrBuffer, "\"%.*s...\"", bufferSize - 6, (char*)pAddress); else - pszCurrBuffer += sprintf(pszCurrBuffer, "\"%s\"", (char*)pAddress); + pszCurrBuffer += sprintf(pszCurrBuffer, "\"%.*s\"", length, (char*)pAddress); break; } case btStdString: @@ -1307,16 +1335,43 @@ DWORD_PTR WheatyExceptionReport::DereferenceUnsafePointer(DWORD_PTR address) //============================================================================ int __cdecl WheatyExceptionReport::_tprintf(const TCHAR * format, ...) { - TCHAR szBuff[WER_LARGE_BUFFER_SIZE]; int retValue; - DWORD cbWritten; va_list argptr; - va_start(argptr, format); + if (stackOverflowException) + { + retValue = heapprintf(format, argptr); + va_end(argptr); + } + else + { + retValue = stackprintf(format, argptr); + va_end(argptr); + } + + return retValue; +} + +int __cdecl WheatyExceptionReport::stackprintf(const TCHAR * format, va_list argptr) +{ + int retValue; + DWORD cbWritten; + + TCHAR szBuff[WER_LARGE_BUFFER_SIZE]; retValue = vsprintf(szBuff, format, argptr); - va_end(argptr); + WriteFile(m_hReportFile, szBuff, retValue * sizeof(TCHAR), &cbWritten, 0); + + return retValue; +} +int __cdecl WheatyExceptionReport::heapprintf(const TCHAR * format, va_list argptr) +{ + int retValue; + DWORD cbWritten; + TCHAR* szBuff = (TCHAR*)malloc(sizeof(TCHAR) * WER_LARGE_BUFFER_SIZE); + retValue = vsprintf(szBuff, format, argptr); WriteFile(m_hReportFile, szBuff, retValue * sizeof(TCHAR), &cbWritten, 0); + free(szBuff); return retValue; } diff --git a/src/server/shared/Debugging/WheatyExceptionReport.h b/src/server/shared/Debugging/WheatyExceptionReport.h index 9137b91aac9..ef6334add16 100644 --- a/src/server/shared/Debugging/WheatyExceptionReport.h +++ b/src/server/shared/Debugging/WheatyExceptionReport.h @@ -172,12 +172,14 @@ class WheatyExceptionReport static char * DumpTypeIndex(char *, DWORD64, DWORD, unsigned, DWORD_PTR, bool &, const char*, char*, bool, bool); - static void FormatOutputValue(char * pszCurrBuffer, BasicType basicType, DWORD64 length, PVOID pAddress, size_t bufferSize); + static void FormatOutputValue(char * pszCurrBuffer, BasicType basicType, DWORD64 length, PVOID pAddress, size_t bufferSize, size_t countOverride = 0); static BasicType GetBasicType(DWORD typeIndex, DWORD64 modBase); static DWORD_PTR DereferenceUnsafePointer(DWORD_PTR address); static int __cdecl _tprintf(const TCHAR * format, ...); + static int __cdecl stackprintf(const TCHAR * format, va_list argptr); + static int __cdecl heapprintf(const TCHAR * format, va_list argptr); static bool StoreSymbol(DWORD type , DWORD_PTR offset); static void ClearSymbols(); @@ -191,6 +193,9 @@ class WheatyExceptionReport static HANDLE m_hProcess; static SymbolPairs symbols; static std::stack<SymbolDetail> symbolDetails; + static bool stackOverflowException; + static bool alreadyCrashed; + static std::mutex alreadyCrashedLock; static char* PushSymbolDetail(char* pszCurrBuffer); static char* PopSymbolDetail(char* pszCurrBuffer); diff --git a/src/server/shared/Define.h b/src/server/shared/Define.h index add7995ad8f..97e07cef8b3 100644 --- a/src/server/shared/Define.h +++ b/src/server/shared/Define.h @@ -78,9 +78,9 @@ #endif //!COREDEBUG #if COMPILER == COMPILER_GNU -# define ATTR_NORETURN __attribute__((noreturn)) -# define ATTR_PRINTF(F, V) __attribute__ ((format (printf, F, V))) -# define ATTR_DEPRECATED __attribute__((deprecated)) +# define ATTR_NORETURN __attribute__((__noreturn__)) +# define ATTR_PRINTF(F, V) __attribute__ ((__format__ (__printf__, F, V))) +# define ATTR_DEPRECATED __attribute__((__deprecated__)) #else //COMPILER != COMPILER_GNU # define ATTR_NORETURN # define ATTR_PRINTF(F, V) diff --git a/src/server/shared/Dynamic/ObjectRegistry.h b/src/server/shared/Dynamic/ObjectRegistry.h index be253f947a7..1eb9368be61 100644 --- a/src/server/shared/Dynamic/ObjectRegistry.h +++ b/src/server/shared/Dynamic/ObjectRegistry.h @@ -47,12 +47,12 @@ class ObjectRegistry } /// Inserts a registry item - bool InsertItem(T *obj, Key key, bool override = false) + bool InsertItem(T *obj, Key key, bool _override = false) { typename RegistryMapType::iterator iter = i_registeredObjects.find(key); if ( iter != i_registeredObjects.end() ) { - if ( !override ) + if ( !_override ) return false; delete iter->second; i_registeredObjects.erase(iter); diff --git a/src/server/shared/Logging/Appender.cpp b/src/server/shared/Logging/Appender.cpp index dff931e3da8..31db3ae1b86 100644 --- a/src/server/shared/Logging/Appender.cpp +++ b/src/server/shared/Logging/Appender.cpp @@ -18,6 +18,10 @@ #include "Appender.h" #include "Common.h" #include "Util.h" +#include "StringFormat.h" + +#include <utility> +#include <sstream> std::string LogMessage::getTimeStr(time_t time) { @@ -68,38 +72,23 @@ void Appender::setLogLevel(LogLevel _level) level = _level; } -void Appender::write(LogMessage& message) +void Appender::write(LogMessage* message) { - if (!level || level > message.level) + if (!level || level > message->level) return; - message.prefix.clear(); + std::ostringstream ss; + if (flags & APPENDER_FLAGS_PREFIX_TIMESTAMP) - message.prefix.append(message.getTimeStr()); + ss << message->getTimeStr() << ' '; if (flags & APPENDER_FLAGS_PREFIX_LOGLEVEL) - { - if (!message.prefix.empty()) - message.prefix.push_back(' '); - - char text[MAX_QUERY_LEN]; - snprintf(text, MAX_QUERY_LEN, "%-5s", Appender::getLogLevelString(message.level)); - message.prefix.append(text); - } + ss << Trinity::StringFormat("%-5s ", Appender::getLogLevelString(message->level)); if (flags & APPENDER_FLAGS_PREFIX_LOGFILTERTYPE) - { - if (!message.prefix.empty()) - message.prefix.push_back(' '); - - message.prefix.push_back('['); - message.prefix.append(message.type); - message.prefix.push_back(']'); - } - - if (!message.prefix.empty()) - message.prefix.push_back(' '); + ss << '[' << message->type << "] "; + message->prefix = ss.str(); _write(message); } diff --git a/src/server/shared/Logging/Appender.h b/src/server/shared/Logging/Appender.h index 6f38eb6aaf7..73af351e41d 100644 --- a/src/server/shared/Logging/Appender.h +++ b/src/server/shared/Logging/Appender.h @@ -21,6 +21,7 @@ #include <unordered_map> #include <string> #include <time.h> +#include <type_traits> #include "Define.h" // Values assigned have their equivalent in enum ACE_Log_Priority @@ -57,16 +58,19 @@ enum AppenderFlags struct LogMessage { - LogMessage(LogLevel _level, std::string const& _type, std::string const& _text) - : level(_level), type(_type), text(_text), mtime(time(NULL)) + LogMessage(LogLevel _level, std::string const& _type, std::string&& _text) + : level(_level), type(_type), text(std::forward<std::string>(_text)), mtime(time(NULL)) { } + LogMessage(LogMessage const& /*other*/) = delete; + LogMessage& operator=(LogMessage const& /*other*/) = delete; + static std::string getTimeStr(time_t time); std::string getTimeStr(); - LogLevel level; - std::string type; - std::string text; + LogLevel const level; + std::string const type; + std::string const text; std::string prefix; std::string param1; time_t mtime; @@ -91,11 +95,11 @@ class Appender AppenderFlags getFlags() const; void setLogLevel(LogLevel); - void write(LogMessage& message); + void write(LogMessage* message); static const char* getLogLevelString(LogLevel level); private: - virtual void _write(LogMessage const& /*message*/) = 0; + virtual void _write(LogMessage const* /*message*/) = 0; uint8 id; std::string name; diff --git a/src/server/shared/Logging/AppenderConsole.cpp b/src/server/shared/Logging/AppenderConsole.cpp index ae27337fb9a..2efa4db4d2e 100644 --- a/src/server/shared/Logging/AppenderConsole.cpp +++ b/src/server/shared/Logging/AppenderConsole.cpp @@ -158,14 +158,14 @@ void AppenderConsole::ResetColor(bool stdout_stream) #endif } -void AppenderConsole::_write(LogMessage const& message) +void AppenderConsole::_write(LogMessage const* message) { - bool stdout_stream = !(message.level == LOG_LEVEL_ERROR || message.level == LOG_LEVEL_FATAL); + bool stdout_stream = !(message->level == LOG_LEVEL_ERROR || message->level == LOG_LEVEL_FATAL); if (_colored) { uint8 index; - switch (message.level) + switch (message->level) { case LOG_LEVEL_TRACE: index = 5; @@ -189,9 +189,9 @@ void AppenderConsole::_write(LogMessage const& message) } SetColor(stdout_stream, _colors[index]); - utf8printf(stdout_stream ? stdout : stderr, "%s%s", message.prefix.c_str(), message.text.c_str()); + utf8printf(stdout_stream ? stdout : stderr, "%s%s\n", message->prefix.c_str(), message->text.c_str()); ResetColor(stdout_stream); } else - utf8printf(stdout_stream ? stdout : stderr, "%s%s", message.prefix.c_str(), message.text.c_str()); + utf8printf(stdout_stream ? stdout : stderr, "%s%s\n", message->prefix.c_str(), message->text.c_str()); } diff --git a/src/server/shared/Logging/AppenderConsole.h b/src/server/shared/Logging/AppenderConsole.h index 0f9536b3111..0acf7636e35 100644 --- a/src/server/shared/Logging/AppenderConsole.h +++ b/src/server/shared/Logging/AppenderConsole.h @@ -51,7 +51,7 @@ class AppenderConsole: public Appender private: void SetColor(bool stdout_stream, ColorTypes color); void ResetColor(bool stdout_stream); - void _write(LogMessage const& message) override; + void _write(LogMessage const* message) override; bool _colored; ColorTypes _colors[MaxLogLevels]; }; diff --git a/src/server/shared/Logging/AppenderDB.cpp b/src/server/shared/Logging/AppenderDB.cpp index d297d6d08d3..8a329ea3a0f 100644 --- a/src/server/shared/Logging/AppenderDB.cpp +++ b/src/server/shared/Logging/AppenderDB.cpp @@ -23,18 +23,18 @@ AppenderDB::AppenderDB(uint8 id, std::string const& name, LogLevel level) AppenderDB::~AppenderDB() { } -void AppenderDB::_write(LogMessage const& message) +void AppenderDB::_write(LogMessage const* message) { // Avoid infinite loop, PExecute triggers Logging with "sql.sql" type - if (!enabled || !message.type.find("sql")) + if (!enabled || (message->type.find("sql") != std::string::npos)) return; PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_INS_LOG); - stmt->setUInt64(0, message.mtime); + stmt->setUInt64(0, message->mtime); stmt->setUInt32(1, realmId); - stmt->setString(2, message.type); - stmt->setUInt8(3, uint8(message.level)); - stmt->setString(4, message.text); + stmt->setString(2, message->type); + stmt->setUInt8(3, uint8(message->level)); + stmt->setString(4, message->text); LoginDatabase.Execute(stmt); } diff --git a/src/server/shared/Logging/AppenderDB.h b/src/server/shared/Logging/AppenderDB.h index e20ceaf77b4..09affdb46f1 100644 --- a/src/server/shared/Logging/AppenderDB.h +++ b/src/server/shared/Logging/AppenderDB.h @@ -31,7 +31,7 @@ class AppenderDB: public Appender private: uint32 realmId; bool enabled; - void _write(LogMessage const& message) override; + void _write(LogMessage const* message) override; }; #endif diff --git a/src/server/shared/Logging/AppenderFile.cpp b/src/server/shared/Logging/AppenderFile.cpp index 07a88a367ae..5a8d610a36b 100644 --- a/src/server/shared/Logging/AppenderFile.cpp +++ b/src/server/shared/Logging/AppenderFile.cpp @@ -43,21 +43,21 @@ AppenderFile::~AppenderFile() CloseFile(); } -void AppenderFile::_write(LogMessage const& message) +void AppenderFile::_write(LogMessage const* message) { - bool exceedMaxSize = maxFileSize > 0 && (fileSize.load() + message.Size()) > maxFileSize; + bool exceedMaxSize = maxFileSize > 0 && (fileSize.load() + message->Size()) > maxFileSize; if (dynamicName) { char namebuf[TRINITY_PATH_MAX]; - snprintf(namebuf, TRINITY_PATH_MAX, filename.c_str(), message.param1.c_str()); + snprintf(namebuf, TRINITY_PATH_MAX, filename.c_str(), message->param1.c_str()); // always use "a" with dynamic name otherwise it could delete the log we wrote in last _write() call FILE* file = OpenFile(namebuf, "a", backup || exceedMaxSize); if (!file) return; - fprintf(file, "%s%s", message.prefix.c_str(), message.text.c_str()); + fprintf(file, "%s%s\n", message->prefix.c_str(), message->text.c_str()); fflush(file); - fileSize += uint64(message.Size()); + fileSize += uint64(message->Size()); fclose(file); return; } @@ -67,9 +67,9 @@ void AppenderFile::_write(LogMessage const& message) if (!logfile) return; - fprintf(logfile, "%s%s", message.prefix.c_str(), message.text.c_str()); + fprintf(logfile, "%s%s\n", message->prefix.c_str(), message->text.c_str()); fflush(logfile); - fileSize += uint64(message.Size()); + fileSize += uint64(message->Size()); } FILE* AppenderFile::OpenFile(std::string const &filename, std::string const &mode, bool backup) diff --git a/src/server/shared/Logging/AppenderFile.h b/src/server/shared/Logging/AppenderFile.h index 23651fc1129..36afdd23ad1 100644 --- a/src/server/shared/Logging/AppenderFile.h +++ b/src/server/shared/Logging/AppenderFile.h @@ -30,7 +30,7 @@ class AppenderFile: public Appender private: void CloseFile(); - void _write(LogMessage const& message) override; + void _write(LogMessage const* message) override; FILE* logfile; std::string filename; std::string logDir; diff --git a/src/server/shared/Logging/Log.cpp b/src/server/shared/Logging/Log.cpp index 4bf4dacb302..a2150733c6b 100644 --- a/src/server/shared/Logging/Log.cpp +++ b/src/server/shared/Logging/Log.cpp @@ -25,7 +25,6 @@ #include "AppenderDB.h" #include "LogOperation.h" -#include <cstdarg> #include <cstdio> #include <sstream> @@ -199,6 +198,9 @@ void Log::CreateLoggerFromConfig(std::string const& appenderName) return; } + if (level < lowestLogLevel) + lowestLogLevel = level; + logger.Create(name, level); //fprintf(stdout, "Log::CreateLoggerFromConfig: Created Logger %s, Level %u\n", name.c_str(), level); @@ -261,30 +263,18 @@ void Log::ReadLoggersFromConfig() } } -void Log::vlog(std::string const& filter, LogLevel level, char const* str, va_list argptr) -{ - char text[MAX_QUERY_LEN]; - vsnprintf(text, MAX_QUERY_LEN, str, argptr); - write(new LogMessage(level, filter, text)); -} - -void Log::write(LogMessage* msg) const +void Log::write(std::unique_ptr<LogMessage>&& msg) const { Logger const* logger = GetLoggerByType(msg->type); - msg->text.append("\n"); if (_ioService) { - auto logOperation = std::shared_ptr<LogOperation>(new LogOperation(logger, msg)); + auto logOperation = std::shared_ptr<LogOperation>(new LogOperation(logger, std::forward<std::unique_ptr<LogMessage>>(msg))); _ioService->post(_strand->wrap([logOperation](){ logOperation->call(); })); - } else - { - logger->write(*msg); - delete msg; - } + logger->write(msg.get()); } std::string Log::GetTimestampStr() @@ -321,6 +311,9 @@ bool Log::SetLogLevel(std::string const& name, const char* newLevelc, bool isLog return false; it->second.setLogLevel(newLevel); + + if (newLevel != LOG_LEVEL_DISABLED && newLevel < lowestLogLevel) + lowestLogLevel = newLevel; } else { @@ -343,40 +336,20 @@ void Log::outCharDump(char const* str, uint32 accountId, uint64 guid, char const ss << "== START DUMP == (account: " << accountId << " guid: " << guid << " name: " << name << ")\n" << str << "\n== END DUMP ==\n"; - LogMessage* msg = new LogMessage(LOG_LEVEL_INFO, "entities.player.dump", ss.str()); + std::unique_ptr<LogMessage> msg(new LogMessage(LOG_LEVEL_INFO, "entities.player.dump", ss.str())); std::ostringstream param; param << guid << '_' << name; msg->param1 = param.str(); - write(msg); -} - -void Log::outCommand(uint32 account, const char * str, ...) -{ - if (!str || !ShouldLog("commands.gm", LOG_LEVEL_INFO)) - return; - - va_list ap; - va_start(ap, str); - char text[MAX_QUERY_LEN]; - vsnprintf(text, MAX_QUERY_LEN, str, ap); - va_end(ap); - - LogMessage* msg = new LogMessage(LOG_LEVEL_INFO, "commands.gm", text); - - std::ostringstream ss; - ss << account; - msg->param1 = ss.str(); - - write(msg); + write(std::move(msg)); } void Log::SetRealmId(uint32 id) { for (AppenderMap::iterator it = appenders.begin(); it != appenders.end(); ++it) if (it->second && it->second->getType() == APPENDER_DB) - ((AppenderDB *)it->second)->setRealmId(id); + static_cast<AppenderDB*>(it->second)->setRealmId(id); } void Log::Close() @@ -394,6 +367,7 @@ void Log::LoadFromConfig() { Close(); + lowestLogLevel = LOG_LEVEL_FATAL; AppenderId = 0; m_logsDir = sConfigMgr->GetStringDefault("LogsDir", ""); if (!m_logsDir.empty()) diff --git a/src/server/shared/Logging/Log.h b/src/server/shared/Logging/Log.h index 1d67ff87f76..e22a06e635e 100644 --- a/src/server/shared/Logging/Log.h +++ b/src/server/shared/Logging/Log.h @@ -22,12 +22,14 @@ #include "Define.h" #include "Appender.h" #include "Logger.h" -#include <stdarg.h> +#include "StringFormat.h" #include <boost/asio/io_service.hpp> #include <boost/asio/strand.hpp> +#include <stdarg.h> #include <unordered_map> #include <string> +#include <memory> #define LOGGER_ROOT "root" @@ -59,17 +61,32 @@ class Log bool ShouldLog(std::string const& type, LogLevel level) const; bool SetLogLevel(std::string const& name, char const* level, bool isLogger = true); - void outMessage(std::string const& f, LogLevel level, char const* str, ...) ATTR_PRINTF(4, 5); + template<typename... Args> + inline void outMessage(std::string const& filter, LogLevel const level, const char* fmt, Args const&... args) + { + write(std::unique_ptr<LogMessage>(new LogMessage(level, filter, Trinity::StringFormat(fmt, args...)))); + } + + template<typename... Args> + void outCommand(uint32 account, const char* fmt, Args const&... args) + { + if (!ShouldLog("commands.gm", LOG_LEVEL_INFO)) + return; + + std::unique_ptr<LogMessage> msg(new LogMessage(LOG_LEVEL_INFO, "commands.gm", std::move(Trinity::StringFormat(fmt, args...)))); + + msg->param1 = std::to_string(account); + + write(std::move(msg)); + } - void outCommand(uint32 account, const char * str, ...) ATTR_PRINTF(3, 4); void outCharDump(char const* str, uint32 account_id, uint64 guid, char const* name); void SetRealmId(uint32 id); private: static std::string GetTimestampStr(); - void vlog(std::string const& f, LogLevel level, char const* str, va_list argptr); - void write(LogMessage* msg) const; + void write(std::unique_ptr<LogMessage>&& msg) const; Logger const* GetLoggerByType(std::string const& type) const; Appender* GetAppenderByName(std::string const& name); @@ -82,6 +99,7 @@ class Log AppenderMap appenders; LoggerMap loggers; uint8 AppenderId; + LogLevel lowestLogLevel; std::string m_logsDir; std::string m_logsTimestamp; @@ -113,6 +131,10 @@ inline bool Log::ShouldLog(std::string const& type, LogLevel level) const // Speed up in cases where requesting "Type.sub1.sub2" but only configured // Logger "Type" + // Don't even look for a logger if the LogLevel is lower than lowest log levels across all loggers + if (level < lowestLogLevel) + return false; + Logger const* logger = GetLoggerByType(type); if (!logger) return false; @@ -121,23 +143,34 @@ inline bool Log::ShouldLog(std::string const& type, LogLevel level) const return logLevel != LOG_LEVEL_DISABLED && logLevel <= level; } -inline void Log::outMessage(std::string const& filter, LogLevel level, const char * str, ...) -{ - va_list ap; - va_start(ap, str); - - vlog(filter, level, str, ap); - - va_end(ap); -} - #define sLog Log::instance() +#define LOG_EXCEPTION_FREE(filterType__, level__, ...) \ + { \ + try \ + { \ + sLog->outMessage(filterType__, level__, __VA_ARGS__); \ + } \ + catch (std::exception& e) \ + { \ + sLog->outMessage("server", LOG_LEVEL_ERROR, "Wrong format occurred (%s) at %s:%u.", \ + e.what(), __FILE__, __LINE__); \ + } \ + } + #if PLATFORM != PLATFORM_WINDOWS +void check_args(const char* format, ...) ATTR_PRINTF(1, 2); + +// This will catch format errors on build time #define TC_LOG_MESSAGE_BODY(filterType__, level__, ...) \ do { \ if (sLog->ShouldLog(filterType__, level__)) \ - sLog->outMessage(filterType__, level__, __VA_ARGS__); \ + { \ + if (false) \ + check_args(__VA_ARGS__); \ + \ + LOG_EXCEPTION_FREE(filterType__, level__, __VA_ARGS__); \ + } \ } while (0) #else #define TC_LOG_MESSAGE_BODY(filterType__, level__, ...) \ @@ -145,7 +178,7 @@ inline void Log::outMessage(std::string const& filter, LogLevel level, const cha __pragma(warning(disable:4127)) \ do { \ if (sLog->ShouldLog(filterType__, level__)) \ - sLog->outMessage(filterType__, level__, __VA_ARGS__); \ + LOG_EXCEPTION_FREE(filterType__, level__, __VA_ARGS__); \ } while (0) \ __pragma(warning(pop)) #endif diff --git a/src/server/shared/Logging/LogOperation.cpp b/src/server/shared/Logging/LogOperation.cpp index 9afb28a49c2..bcd923c705e 100644 --- a/src/server/shared/Logging/LogOperation.cpp +++ b/src/server/shared/Logging/LogOperation.cpp @@ -18,14 +18,8 @@ #include "LogOperation.h" #include "Logger.h" -LogOperation::~LogOperation() -{ - delete msg; -} - int LogOperation::call() { - if (logger && msg) - logger->write(*msg); + logger->write(msg.get()); return 0; } diff --git a/src/server/shared/Logging/LogOperation.h b/src/server/shared/Logging/LogOperation.h index b8655413273..ffdd35c3c09 100644 --- a/src/server/shared/Logging/LogOperation.h +++ b/src/server/shared/Logging/LogOperation.h @@ -18,23 +18,25 @@ #ifndef LOGOPERATION_H #define LOGOPERATION_H +#include <memory> + class Logger; struct LogMessage; class LogOperation { public: - LogOperation(Logger const* _logger, LogMessage* _msg) - : logger(_logger), msg(_msg) + LogOperation(Logger const* _logger, std::unique_ptr<LogMessage>&& _msg) + : logger(_logger), msg(std::forward<std::unique_ptr<LogMessage>>(_msg)) { } - ~LogOperation(); + ~LogOperation() { } int call(); protected: Logger const* logger; - LogMessage* msg; + std::unique_ptr<LogMessage> msg; }; #endif diff --git a/src/server/shared/Logging/Logger.cpp b/src/server/shared/Logging/Logger.cpp index 615732deb30..3b02eb47575 100644 --- a/src/server/shared/Logging/Logger.cpp +++ b/src/server/shared/Logging/Logger.cpp @@ -50,9 +50,9 @@ void Logger::setLogLevel(LogLevel _level) level = _level; } -void Logger::write(LogMessage& message) const +void Logger::write(LogMessage* message) const { - if (!level || level > message.level || message.text.empty()) + if (!level || level > message->level || message->text.empty()) { //fprintf(stderr, "Logger::write: Logger %s, Level %u. Msg %s Level %u WRONG LEVEL MASK OR EMPTY MSG\n", getName().c_str(), getLogLevel(), message.text.c_str(), message.level); return; diff --git a/src/server/shared/Logging/Logger.h b/src/server/shared/Logging/Logger.h index a81ee8d7bd2..1aee75c5d72 100644 --- a/src/server/shared/Logging/Logger.h +++ b/src/server/shared/Logging/Logger.h @@ -32,7 +32,7 @@ class Logger std::string const& getName() const; LogLevel getLogLevel() const; void setLogLevel(LogLevel level); - void write(LogMessage& message) const; + void write(LogMessage* message) const; private: std::string name; diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h index 90afaf35796..95e26974626 100644 --- a/src/server/shared/Networking/MessageBuffer.h +++ b/src/server/shared/Networking/MessageBuffer.h @@ -55,9 +55,9 @@ public: uint8* GetBasePointer() { return _storage.data(); } - uint8* GetReadPointer() { return &_storage[_rpos]; } + uint8* GetReadPointer() { return GetBasePointer() + _rpos; } - uint8* GetWritePointer() { return &_storage[_wpos]; } + uint8* GetWritePointer() { return GetBasePointer() + _wpos; } void ReadCompleted(size_type bytes) { _rpos += bytes; } diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index f6bf3976b85..1989411bccb 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -34,7 +34,9 @@ using boost::asio::ip::tcp; #define READ_BLOCK_SIZE 4096 -#define TC_SOCKET_USE_IOCP BOOST_ASIO_HAS_IOCP +#ifdef BOOST_ASIO_HAS_IOCP +#define TC_SOCKET_USE_IOCP +#endif template<class T> class Socket : public std::enable_shared_from_this<T> @@ -48,6 +50,7 @@ public: virtual ~Socket() { + _closed = true; boost::system::error_code error; _socket.close(error); } @@ -60,7 +63,7 @@ public: return false; #ifndef TC_SOCKET_USE_IOCP - std::unique_lock<std::mutex> guard(_writeLock, std::try_to_lock); + std::unique_lock<std::mutex> guard(_writeLock); if (!guard) return true; @@ -95,26 +98,6 @@ public: std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } - void ReadData(std::size_t size) - { - if (!IsOpen()) - return; - - boost::system::error_code error; - - std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readBuffer.GetWritePointer(), size), error); - - _readBuffer.WriteCompleted(bytesRead); - - if (error || bytesRead != size) - { - TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(), - error.message().c_str()); - - CloseSocket(); - } - } - void QueuePacket(MessageBuffer&& buffer, std::unique_lock<std::mutex>& guard) { _writeQueue.push(std::move(buffer)); @@ -138,6 +121,8 @@ public: if (shutdownError) TC_LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(), shutdownError.value(), shutdownError.message().c_str()); + + OnClose(); } /// Marks the socket for closing after write buffer becomes empty @@ -146,6 +131,8 @@ public: MessageBuffer& GetReadBuffer() { return _readBuffer; } protected: + virtual void OnClose() { } + virtual void ReadHandler() = 0; bool AsyncProcessQueue(std::unique_lock<std::mutex>&) diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h index df57baf257e..2078ae90c62 100644 --- a/src/server/shared/Networking/SocketMgr.h +++ b/src/server/shared/Networking/SocketMgr.h @@ -99,7 +99,7 @@ public: } catch (boost::system::system_error const& err) { - TC_LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what()); + TC_LOG_WARN("network", "Failed to retrieve client's remote address %s", err.what()); } } diff --git a/src/server/shared/PrecompiledHeaders/sharedPCH.h b/src/server/shared/PrecompiledHeaders/sharedPCH.h index d0c15b17f0c..87af9f44eb7 100644 --- a/src/server/shared/PrecompiledHeaders/sharedPCH.h +++ b/src/server/shared/PrecompiledHeaders/sharedPCH.h @@ -6,3 +6,4 @@ #include "SQLOperation.h" #include "Errors.h" #include "TypeList.h" +#include "TaskScheduler.h" diff --git a/src/server/shared/Threading/ProducerConsumerQueue.h b/src/server/shared/Threading/ProducerConsumerQueue.h index ab01568309b..e2f13e5c339 100644 --- a/src/server/shared/Threading/ProducerConsumerQueue.h +++ b/src/server/shared/Threading/ProducerConsumerQueue.h @@ -70,7 +70,9 @@ public: { std::unique_lock<std::mutex> lock(_queueLock); - _condition.wait(lock, [this]() { return !_queue.empty() || _shutdown; }); + // we could be using .wait(lock, predicate) overload here but some threading error analysis tools produce false positives + while (_queue.empty() && !_shutdown) + _condition.wait(lock); if (_queue.empty() || _shutdown) return; diff --git a/src/server/shared/Updater/DBUpdater.cpp b/src/server/shared/Updater/DBUpdater.cpp new file mode 100644 index 00000000000..20ded669cec --- /dev/null +++ b/src/server/shared/Updater/DBUpdater.cpp @@ -0,0 +1,414 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "DBUpdater.h" +#include "Log.h" +#include "revision.h" +#include "UpdateFetcher.h" +#include "DatabaseLoader.h" +#include "Config.h" + +#include <fstream> +#include <iostream> +#include <unordered_map> +#include <boost/process.hpp> +#include <boost/iostreams/device/file_descriptor.hpp> +#include <boost/system/system_error.hpp> + +using namespace boost::process; +using namespace boost::process::initializers; +using namespace boost::iostreams; + +template<class T> +std::string DBUpdater<T>::GetSourceDirectory() +{ + std::string const entry = sConfigMgr->GetStringDefault("Updates.SourcePath", ""); + if (!entry.empty()) + return entry; + else + return _SOURCE_DIRECTORY; +} + +template<class T> +std::string DBUpdater<T>::GetMySqlCli() +{ + std::string const entry = sConfigMgr->GetStringDefault("Updates.MySqlCLIPath", ""); + if (!entry.empty()) + return entry; + else + return _MYSQL_EXECUTABLE; +} + +// Auth Database +template<> +std::string DBUpdater<LoginDatabaseConnection>::GetConfigEntry() +{ + return "Updates.Auth"; +} + +template<> +std::string DBUpdater<LoginDatabaseConnection>::GetTableName() +{ + return "Auth"; +} + +template<> +std::string DBUpdater<LoginDatabaseConnection>::GetBaseFile() +{ + return DBUpdater<LoginDatabaseConnection>::GetSourceDirectory() + "/sql/base/auth_database.sql"; +} + +template<> +bool DBUpdater<LoginDatabaseConnection>::IsEnabled(uint32 const updateMask) +{ + // This way silences warnings under msvc + return (updateMask & DatabaseLoader::DATABASE_LOGIN) ? true : false; +} + +// World Database +template<> +std::string DBUpdater<WorldDatabaseConnection>::GetConfigEntry() +{ + return "Updates.World"; +} + +template<> +std::string DBUpdater<WorldDatabaseConnection>::GetTableName() +{ + return "World"; +} + +template<> +std::string DBUpdater<WorldDatabaseConnection>::GetBaseFile() +{ + return _FULL_DATABASE; +} + +template<> +bool DBUpdater<WorldDatabaseConnection>::IsEnabled(uint32 const updateMask) +{ + // This way silences warnings under msvc + return (updateMask & DatabaseLoader::DATABASE_WORLD) ? true : false; +} + +template<> +BaseLocation DBUpdater<WorldDatabaseConnection>::GetBaseLocationType() +{ + return LOCATION_DOWNLOAD; +} + +// Character Database +template<> +std::string DBUpdater<CharacterDatabaseConnection>::GetConfigEntry() +{ + return "Updates.Character"; +} + +template<> +std::string DBUpdater<CharacterDatabaseConnection>::GetTableName() +{ + return "Character"; +} + +template<> +std::string DBUpdater<CharacterDatabaseConnection>::GetBaseFile() +{ + return DBUpdater<CharacterDatabaseConnection>::GetSourceDirectory() + "/sql/base/characters_database.sql"; +} + +template<> +bool DBUpdater<CharacterDatabaseConnection>::IsEnabled(uint32 const updateMask) +{ + // This way silences warnings under msvc + return (updateMask & DatabaseLoader::DATABASE_CHARACTER) ? true : false; +} + +// All +template<class T> +BaseLocation DBUpdater<T>::GetBaseLocationType() +{ + return LOCATION_REPOSITORY; +} + +template<class T> +bool DBUpdater<T>::CheckExecutable() +{ + DBUpdater<T>::Path const exe(DBUpdater<T>::GetMySqlCli()); + if (!exists(exe)) + { + // Check for mysql in path + std::vector<std::string> args = {"--version"}; + uint32 ret; + try + { + child c = execute(run_exe("mysql"), set_args(args), throw_on_error(), close_stdout()); + ret = wait_for_exit(c); + } + catch (boost::system::system_error&) + { + ret = EXIT_FAILURE; + } + + if (ret == EXIT_FAILURE) + { + TC_LOG_FATAL("sql.updates", "Didn't find executeable mysql binary at \'%s\', correct the path in the *.conf (\"Updates.MySqlCLIPath\").", + absolute(exe).generic_string().c_str()); + + return false; + } + } + return true; +} + +template<class T> +bool DBUpdater<T>::Create(DatabaseWorkerPool<T>& pool) +{ + TC_LOG_INFO("sql.updates", "Database \"%s\" does not exist, do you want to create it? [yes (default) / no]: ", + pool.GetConnectionInfo()->database.c_str()); + + std::string answer; + std::getline(std::cin, answer); + if (!answer.empty() && !(answer.substr(0, 1) == "y")) + return false; + + TC_LOG_INFO("sql.updates", "Creating database \"%s\"...", pool.GetConnectionInfo()->database.c_str()); + + // Path of temp file + static Path const temp("create_table.sql"); + + // Create temporary query to use external mysql cli + std::ofstream file(temp.generic_string()); + if (!file.is_open()) + { + TC_LOG_FATAL("sql.updates", "Failed to create temporary query file \"%s\"!", temp.generic_string().c_str()); + return false; + } + + file << "CREATE DATABASE `" << pool.GetConnectionInfo()->database << "` DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci\n\n"; + + file.close(); + + try + { + DBUpdater<T>::ApplyFile(pool, pool.GetConnectionInfo()->host, pool.GetConnectionInfo()->user, pool.GetConnectionInfo()->password, + pool.GetConnectionInfo()->port_or_socket, "", temp); + } + catch (UpdateException&) + { + TC_LOG_FATAL("sql.updates", "Failed to create database %s! Has the user `CREATE` priviliges?", pool.GetConnectionInfo()->database.c_str()); + boost::filesystem::remove(temp); + return false; + } + + TC_LOG_INFO("sql.updates", "Done."); + boost::filesystem::remove(temp); + return true; +} + +template<class T> +bool DBUpdater<T>::Update(DatabaseWorkerPool<T>& pool) +{ + if (!DBUpdater<T>::CheckExecutable()) + return false; + + TC_LOG_INFO("sql.updates", "Updating %s database...", DBUpdater<T>::GetTableName().c_str()); + + Path const sourceDirectory(GetSourceDirectory()); + + if (!is_directory(sourceDirectory)) + { + TC_LOG_ERROR("sql.updates", "DBUpdater: Given source directory %s does not exist, skipped!", sourceDirectory.generic_string().c_str()); + return false; + } + + UpdateFetcher updateFetcher(sourceDirectory, [&](std::string const& query) { DBUpdater<T>::Apply(pool, query); }, + [&](Path const& file) { DBUpdater<T>::ApplyFile(pool, file); }, + [&](std::string const& query) -> QueryResult { return DBUpdater<T>::Retrieve(pool, query); }); + + uint32 count; + try + { + count = updateFetcher.Update( + sConfigMgr->GetBoolDefault("Updates.Redundancy", true), + sConfigMgr->GetBoolDefault("Updates.AllowRehash", true), + sConfigMgr->GetBoolDefault("Updates.ArchivedRedundancy", false), + sConfigMgr->GetIntDefault("Updates.CleanDeadRefMaxCount", 3)); + } + catch (UpdateException&) + { + return false; + } + + if (!count) + TC_LOG_INFO("sql.updates", ">> %s database is up-to-date!", DBUpdater<T>::GetTableName().c_str()); + else + TC_LOG_INFO("sql.updates", ">> Applied %d %s.", count, count == 1 ? "query" : "queries"); + + return true; +} + +template<class T> +bool DBUpdater<T>::Populate(DatabaseWorkerPool<T>& pool) +{ + { + QueryResult const result = Retrieve(pool, "SHOW TABLES"); + if (result && (result->GetRowCount() > 0)) + return true; + } + + if (!DBUpdater<T>::CheckExecutable()) + return false; + + TC_LOG_INFO("sql.updates", "Database %s is empty, auto populating it...", DBUpdater<T>::GetTableName().c_str()); + + std::string const p = DBUpdater<T>::GetBaseFile(); + if (p.empty()) + { + TC_LOG_INFO("sql.updates", ">> No base file provided, skipped!"); + return true; + } + + Path const base(p); + if (!exists(base)) + { + switch (DBUpdater<T>::GetBaseLocationType()) + { + case LOCATION_REPOSITORY: + { + TC_LOG_ERROR("sql.updates", ">> Base file \"%s\" is missing, try to clone the source again.", + base.generic_string().c_str()); + + break; + } + case LOCATION_DOWNLOAD: + { + TC_LOG_ERROR("sql.updates", ">> File \"%s\" is missing, download it from \"http://www.trinitycore.org/f/files/category/1-database/\"" \ + " and place it in your server directory.", base.filename().generic_string().c_str()); + break; + } + } + return false; + } + + // Update database + TC_LOG_INFO("sql.updates", ">> Applying \'%s\'...", base.generic_string().c_str()); + try + { + ApplyFile(pool, base); + } + catch (UpdateException&) + { + return false; + } + + TC_LOG_INFO("sql.updates", ">> Done!"); + return true; +} + +template<class T> +QueryResult DBUpdater<T>::Retrieve(DatabaseWorkerPool<T>& pool, std::string const& query) +{ + return pool.PQuery(query.c_str()); +} + +template<class T> +void DBUpdater<T>::Apply(DatabaseWorkerPool<T>& pool, std::string const& query) +{ + pool.DirectExecute(query.c_str()); +} + +template<class T> +void DBUpdater<T>::ApplyFile(DatabaseWorkerPool<T>& pool, Path const& path) +{ + DBUpdater<T>::ApplyFile(pool, pool.GetConnectionInfo()->host, pool.GetConnectionInfo()->user, pool.GetConnectionInfo()->password, + pool.GetConnectionInfo()->port_or_socket, pool.GetConnectionInfo()->database, path); +} + +template<class T> +void DBUpdater<T>::ApplyFile(DatabaseWorkerPool<T>& pool, std::string const& host, std::string const& user, + std::string const& password, std::string const& port_or_socket, std::string const& database, Path const& path) +{ + std::vector<std::string> args; + args.reserve(7); + + // CLI Client connection info + args.push_back("-h" + host); + args.push_back("-u" + user); + args.push_back("-p" + password); + + // Check if we want to connect through ip or socket (Unix only) +#ifdef _WIN32 + + args.push_back("-P" + port_or_socket); + +#else + + if (!std::isdigit(port_or_socket[0])) + { + // We can't check here if host == "." because is named localhost if socket option is enabled + args.push_back("-P0"); + args.push_back("--protocol=SOCKET"); + args.push_back("-S" + port_or_socket); + } + else + // generic case + args.push_back("-P" + port_or_socket); + +#endif + + // Set the default charset to utf8 + args.push_back("--default-character-set=utf8"); + + // Set max allowed packet to 1 GB + args.push_back("--max-allowed-packet=1GB"); + + // Database + if (!database.empty()) + args.push_back(database); + + // ToDo: use the existing query in memory as virtual file if possible + file_descriptor_source source(path); + + uint32 ret; + try + { + child c = execute(run_exe(DBUpdater<T>::GetMySqlCli().empty() ? "mysql" : + boost::filesystem::absolute(DBUpdater<T>::GetMySqlCli()).generic_string()), + set_args(args), bind_stdin(source), throw_on_error()); + + ret = wait_for_exit(c); + } + catch (boost::system::system_error&) + { + ret = EXIT_FAILURE; + } + + source.close(); + + if (ret != EXIT_SUCCESS) + { + TC_LOG_FATAL("sql.updates", "Applying of file \'%s\' to database \'%s\' failed!" \ + " If you are an user pull the latest revision from the repository. If you are a developer fix your sql query.", + path.generic_string().c_str(), pool.GetConnectionInfo()->database.c_str()); + + throw UpdateException("update failed"); + } +} + +template class DBUpdater<LoginDatabaseConnection>; +template class DBUpdater<WorldDatabaseConnection>; +template class DBUpdater<CharacterDatabaseConnection>; diff --git a/src/server/shared/Updater/DBUpdater.h b/src/server/shared/Updater/DBUpdater.h new file mode 100644 index 00000000000..0caf8a438fb --- /dev/null +++ b/src/server/shared/Updater/DBUpdater.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef DBUpdater_h__ +#define DBUpdater_h__ + +#include "DatabaseEnv.h" + +#include <string> +#include <boost/filesystem.hpp> + +class UpdateException : public std::exception +{ +public: + UpdateException(std::string const& msg) : _msg(msg) { } + ~UpdateException() throw() { } + + char const* what() const throw() override { return _msg.c_str(); } + +private: + std::string const _msg; +}; + +enum BaseLocation +{ + LOCATION_REPOSITORY, + LOCATION_DOWNLOAD +}; + +template <class T> +class DBUpdater +{ +public: + using Path = boost::filesystem::path; + + static std::string GetSourceDirectory(); + + static inline std::string GetConfigEntry(); + + static inline std::string GetTableName(); + + static std::string GetBaseFile(); + + static bool IsEnabled(uint32 const updateMask); + + static BaseLocation GetBaseLocationType(); + + static bool Create(DatabaseWorkerPool<T>& pool); + + static bool Update(DatabaseWorkerPool<T>& pool); + + static bool Populate(DatabaseWorkerPool<T>& pool); + +private: + static std::string GetMySqlCli(); + static bool CheckExecutable(); + + static QueryResult Retrieve(DatabaseWorkerPool<T>& pool, std::string const& query); + static void Apply(DatabaseWorkerPool<T>& pool, std::string const& query); + static void ApplyFile(DatabaseWorkerPool<T>& pool, Path const& path); + static void ApplyFile(DatabaseWorkerPool<T>& pool, std::string const& host, std::string const& user, + std::string const& password, std::string const& port_or_socket, std::string const& database, Path const& path); +}; + +#endif // DBUpdater_h__ diff --git a/src/server/shared/Updater/UpdateFetcher.cpp b/src/server/shared/Updater/UpdateFetcher.cpp new file mode 100644 index 00000000000..a4bdd193743 --- /dev/null +++ b/src/server/shared/Updater/UpdateFetcher.cpp @@ -0,0 +1,401 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "UpdateFetcher.h" +#include "Log.h" +#include "Util.h" + +#include <fstream> +#include <chrono> +#include <vector> +#include <sstream> +#include <exception> +#include <unordered_map> +#include <openssl/sha.h> + +using namespace boost::filesystem; + +UpdateFetcher::UpdateFetcher(Path const& sourceDirectory, + std::function<void(std::string const&)> const& apply, + std::function<void(Path const& path)> const& applyFile, + std::function<QueryResult(std::string const&)> const& retrieve) : + _sourceDirectory(sourceDirectory), _apply(apply), _applyFile(applyFile), + _retrieve(retrieve) +{ +} + +UpdateFetcher::LocaleFileStorage UpdateFetcher::GetFileList() const +{ + LocaleFileStorage files; + DirectoryStorage directories = ReceiveIncludedDirectories(); + for (auto const& entry : directories) + FillFileListRecursively(entry.path, files, entry.state, 1); + + return files; +} + +void UpdateFetcher::FillFileListRecursively(Path const& path, LocaleFileStorage& storage, State const state, uint32 const depth) const +{ + static uint32 const MAX_DEPTH = 10; + static directory_iterator const end; + + for (directory_iterator itr(path); itr != end; ++itr) + { + if (is_directory(itr->path())) + { + if (depth < MAX_DEPTH) + FillFileListRecursively(itr->path(), storage, state, depth + 1); + } + else if (itr->path().extension() == ".sql") + { + TC_LOG_TRACE("sql.updates", "Added locale file \"%s\".", itr->path().filename().generic_string().c_str()); + + LocaleFileEntry const entry = { itr->path(), state }; + + // Check for doubled filenames + // Since elements are only compared through their filenames this is ok + if (storage.find(entry) != storage.end()) + { + TC_LOG_FATAL("sql.updates", "Duplicated filename occurred \"%s\", since updates are ordered " \ + "through its filename every name needs to be unique!", itr->path().generic_string().c_str()); + + throw UpdateException("Updating failed, see the log for details."); + } + + storage.insert(entry); + } + } +} + +UpdateFetcher::DirectoryStorage UpdateFetcher::ReceiveIncludedDirectories() const +{ + DirectoryStorage directories; + + QueryResult const result = _retrieve("SELECT `path`, `state` FROM `updates_include`"); + if (!result) + return directories; + + do + { + Field* fields = result->Fetch(); + + std::string path = fields[0].GetString(); + if (path.substr(0, 1) == "$") + path = _sourceDirectory.generic_string() + path.substr(1); + + Path const p(path); + + if (!is_directory(p)) + { + TC_LOG_WARN("sql.updates", "DBUpdater: Given update include directory \"%s\" isn't existing, skipped!", p.generic_string().c_str()); + continue; + } + + DirectoryEntry const entry = { p, AppliedFileEntry::StateConvert(fields[1].GetString()) }; + directories.push_back(entry); + + TC_LOG_TRACE("sql.updates", "Added applied file \"%s\" from remote.", p.filename().generic_string().c_str()); + + } while (result->NextRow()); + + return directories; +} + +UpdateFetcher::AppliedFileStorage UpdateFetcher::ReceiveAppliedFiles() const +{ + AppliedFileStorage map; + + QueryResult result = _retrieve("SELECT `name`, `hash`, `state`, UNIX_TIMESTAMP(`timestamp`) FROM `updates` ORDER BY `name` ASC"); + if (!result) + return map; + + do + { + Field* fields = result->Fetch(); + + AppliedFileEntry const entry = { fields[0].GetString(), fields[1].GetString(), + AppliedFileEntry::StateConvert(fields[2].GetString()), fields[3].GetUInt64() }; + + map.insert(std::make_pair(entry.name, entry)); + } + while (result->NextRow()); + + return map; +} + +UpdateFetcher::SQLUpdate UpdateFetcher::ReadSQLUpdate(boost::filesystem::path const& file) const +{ + std::ifstream in(file.c_str()); + WPFatal(in.is_open(), "Could not read an update file."); + + auto const start_pos = in.tellg(); + in.ignore(std::numeric_limits<std::streamsize>::max()); + auto const char_count = in.gcount(); + in.seekg(start_pos); + + SQLUpdate const update(new std::string(char_count, char{})); + + in.read(&(*update)[0], update->size()); + in.close(); + return update; +} + +uint32 UpdateFetcher::Update(bool const redundancyChecks, bool const allowRehash, bool const archivedRedundancy, int32 const cleanDeadReferencesMaxCount) const +{ + LocaleFileStorage const available = GetFileList(); + AppliedFileStorage applied = ReceiveAppliedFiles(); + + // Fill hash to name cache + HashToFileNameStorage hashToName; + for (auto entry : applied) + hashToName.insert(std::make_pair(entry.second.hash, entry.first)); + + uint32 importedUpdates = 0; + + for (auto const& availableQuery : available) + { + TC_LOG_DEBUG("sql.updates", "Checking update \"%s\"...", availableQuery.first.filename().generic_string().c_str()); + + AppliedFileStorage::const_iterator iter = applied.find(availableQuery.first.filename().string()); + if (iter != applied.end()) + { + // If redundancy is disabled skip it since the update is already applied. + if (!redundancyChecks) + { + TC_LOG_DEBUG("sql.updates", ">> Update is already applied, skipping redundancy checks."); + applied.erase(iter); + continue; + } + + // If the update is in an archived directory and is marked as archived in our database skip redundancy checks (archived updates never change). + if (!archivedRedundancy && (iter->second.state == ARCHIVED) && (availableQuery.second == ARCHIVED)) + { + TC_LOG_DEBUG("sql.updates", ">> Update is archived and marked as archived in database, skipping redundancy checks."); + applied.erase(iter); + continue; + } + } + + // Read update from file + SQLUpdate const update = ReadSQLUpdate(availableQuery.first); + + // Calculate hash + std::string const hash = CalculateHash(update); + + UpdateMode mode = MODE_APPLY; + + // Update is not in our applied list + if (iter == applied.end()) + { + // Catch renames (different filename but same hash) + HashToFileNameStorage::const_iterator const hashIter = hashToName.find(hash); + if (hashIter != hashToName.end()) + { + // Check if the original file was removed if not we've got a problem. + LocaleFileStorage::const_iterator localeIter; + // Push localeIter forward + for (localeIter = available.begin(); (localeIter != available.end()) && + (localeIter->first.filename().string() != hashIter->second); ++localeIter); + + // Conflict! + if (localeIter != available.end()) + { + TC_LOG_WARN("sql.updates", ">> Seems like update \"%s\" \'%s\' was renamed, but the old file is still there! " \ + "Trade it as a new file! (Probably its an unmodified copy of file \"%s\")", + availableQuery.first.filename().string().c_str(), hash.substr(0, 7).c_str(), + localeIter->first.filename().string().c_str()); + } + // Its save to trade the file as renamed here + else + { + TC_LOG_INFO("sql.updates", ">> Renaming update \"%s\" to \"%s\" \'%s\'.", + hashIter->second.c_str(), availableQuery.first.filename().string().c_str(), hash.substr(0, 7).c_str()); + + RenameEntry(hashIter->second, availableQuery.first.filename().string()); + applied.erase(hashIter->second); + continue; + } + } + // Apply the update if it was never seen before. + else + { + TC_LOG_INFO("sql.updates", ">> Applying update \"%s\" \'%s\'...", + availableQuery.first.filename().string().c_str(), hash.substr(0, 7).c_str()); + } + } + // Rehash the update entry if it is contained in our database but with an empty hash. + else if (allowRehash && iter->second.hash.empty()) + { + mode = MODE_REHASH; + + TC_LOG_INFO("sql.updates", ">> Re-hashing update \"%s\" \'%s\'...", availableQuery.first.filename().string().c_str(), + hash.substr(0, 7).c_str()); + } + else + { + // If the hash of the files differs from the one stored in our database reapply the update (because it was changed). + if (iter->second.hash != hash) + { + TC_LOG_INFO("sql.updates", ">> Reapplying update \"%s\" \'%s\' -> \'%s\' (it changed)...", availableQuery.first.filename().string().c_str(), + iter->second.hash.substr(0, 7).c_str(), hash.substr(0, 7).c_str()); + } + else + { + // If the file wasn't changed and just moved update its state if necessary. + if (iter->second.state != availableQuery.second) + { + TC_LOG_DEBUG("sql.updates", ">> Updating state of \"%s\" to \'%s\'...", + availableQuery.first.filename().string().c_str(), AppliedFileEntry::StateConvert(availableQuery.second).c_str()); + + UpdateState(availableQuery.first.filename().string(), availableQuery.second); + } + + TC_LOG_DEBUG("sql.updates", ">> Update is already applied and is matching hash \'%s\'.", hash.substr(0, 7).c_str()); + + applied.erase(iter); + continue; + } + } + + uint32 speed = 0; + AppliedFileEntry const file = { availableQuery.first.filename().string(), hash, availableQuery.second, 0 }; + + switch (mode) + { + case MODE_APPLY: + speed = Apply(availableQuery.first); + /*no break*/ + case MODE_REHASH: + UpdateEntry(file, speed); + break; + } + + if (iter != applied.end()) + applied.erase(iter); + + if (mode == MODE_APPLY) + ++importedUpdates; + } + + // Cleanup up orphaned entries if enabled + if (!applied.empty()) + { + bool const doCleanup = (cleanDeadReferencesMaxCount < 0) || (applied.size() <= static_cast<size_t>(cleanDeadReferencesMaxCount)); + + for (auto const& entry : applied) + { + TC_LOG_WARN("sql.updates", ">> File \'%s\' was applied to the database but is missing in" \ + " your update directory now!", entry.first.c_str()); + + if (doCleanup) + TC_LOG_INFO("sql.updates", "Deleting orphaned entry \'%s\'...", entry.first.c_str()); + } + + if (doCleanup) + CleanUp(applied); + else + { + TC_LOG_ERROR("sql.updates", "Cleanup is disabled! There are " SZFMTD " dirty files that were applied to your database " \ + "but are now missing in your source directory!", applied.size()); + } + } + + return importedUpdates; +} + +std::string UpdateFetcher::CalculateHash(SQLUpdate const& query) const +{ + // Calculate a Sha1 hash based on query content. + unsigned char digest[SHA_DIGEST_LENGTH]; + SHA1((unsigned char*)query->c_str(), query->length(), (unsigned char*)&digest); + + return ByteArrayToHexStr(digest, SHA_DIGEST_LENGTH); +} + +uint32 UpdateFetcher::Apply(Path const& path) const +{ + using Time = std::chrono::high_resolution_clock; + using ms = std::chrono::milliseconds; + + // Benchmark query speed + auto const begin = Time::now(); + + // Update database + _applyFile(path); + + // Return time the query took to apply + return std::chrono::duration_cast<ms>(Time::now() - begin).count(); +} + +void UpdateFetcher::UpdateEntry(AppliedFileEntry const& entry, uint32 const speed) const +{ + std::string const update = "REPLACE INTO `updates` (`name`, `hash`, `state`, `speed`) VALUES (\"" + + entry.name + "\", \"" + entry.hash + "\", \'" + entry.GetStateAsString() + "\', " + std::to_string(speed) + ")"; + + // Update database + _apply(update); +} + +void UpdateFetcher::RenameEntry(std::string const& from, std::string const& to) const +{ + // Delete target if it exists + { + std::string const update = "DELETE FROM `updates` WHERE `name`=\"" + to + "\""; + + // Update database + _apply(update); + } + + // Rename + { + std::string const update = "UPDATE `updates` SET `name`=\"" + to + "\" WHERE `name`=\"" + from + "\""; + + // Update database + _apply(update); + } +} + +void UpdateFetcher::CleanUp(AppliedFileStorage const& storage) const +{ + if (storage.empty()) + return; + + std::stringstream update; + size_t remaining = storage.size(); + + update << "DELETE FROM `updates` WHERE `name` IN("; + + for (auto const& entry : storage) + { + update << "\"" << entry.first << "\""; + if ((--remaining) > 0) + update << ", "; + } + + update << ")"; + + // Update database + _apply(update.str()); +} + +void UpdateFetcher::UpdateState(std::string const& name, State const state) const +{ + std::string const update = "UPDATE `updates` SET `state`=\'" + AppliedFileEntry::StateConvert(state) + "\' WHERE `name`=\"" + name + "\""; + + // Update database + _apply(update); +} diff --git a/src/server/shared/Updater/UpdateFetcher.h b/src/server/shared/Updater/UpdateFetcher.h new file mode 100644 index 00000000000..fa142547873 --- /dev/null +++ b/src/server/shared/Updater/UpdateFetcher.h @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef UpdateFetcher_h__ +#define UpdateFetcher_h__ + +#include <DBUpdater.h> + +#include <functional> +#include <string> +#include <memory> +#include <vector> + +class UpdateFetcher +{ + typedef boost::filesystem::path Path; + +public: + UpdateFetcher(Path const& updateDirectory, + std::function<void(std::string const&)> const& apply, + std::function<void(Path const& path)> const& applyFile, + std::function<QueryResult(std::string const&)> const& retrieve); + + uint32 Update(bool const redundancyChecks, bool const allowRehash, + bool const archivedRedundancy, int32 const cleanDeadReferencesMaxCount) const; + +private: + enum UpdateMode + { + MODE_APPLY, + MODE_REHASH + }; + + enum State + { + RELEASED, + ARCHIVED + }; + + struct AppliedFileEntry + { + AppliedFileEntry(std::string const& name_, std::string const& hash_, State state_, uint64 timestamp_) + : name(name_), hash(hash_), state(state_), timestamp(timestamp_) { } + + std::string const name; + + std::string const hash; + + State const state; + + uint64 const timestamp; + + static inline State StateConvert(std::string const& state) + { + return (state == "RELEASED") ? RELEASED : ARCHIVED; + } + + static inline std::string StateConvert(State const state) + { + return (state == RELEASED) ? "RELEASED" : "ARCHIVED"; + } + + std::string GetStateAsString() const + { + return StateConvert(state); + } + }; + + struct DirectoryEntry + { + DirectoryEntry(Path const& path_, State state_) : path(path_), state(state_) { } + + Path const path; + + State const state; + }; + + typedef std::pair<Path, State> LocaleFileEntry; + + struct PathCompare + { + inline bool operator() (LocaleFileEntry const& left, LocaleFileEntry const& right) const + { + return left.first.filename().string() < right.first.filename().string(); + } + }; + + typedef std::set<LocaleFileEntry, PathCompare> LocaleFileStorage; + typedef std::unordered_map<std::string, std::string> HashToFileNameStorage; + typedef std::unordered_map<std::string, AppliedFileEntry> AppliedFileStorage; + typedef std::vector<UpdateFetcher::DirectoryEntry> DirectoryStorage; + typedef std::shared_ptr<std::string> SQLUpdate; + + LocaleFileStorage GetFileList() const; + void FillFileListRecursively(Path const& path, LocaleFileStorage& storage, State const state, uint32 const depth) const; + + DirectoryStorage ReceiveIncludedDirectories() const; + AppliedFileStorage ReceiveAppliedFiles() const; + + SQLUpdate ReadSQLUpdate(Path const& file) const; + std::string CalculateHash(SQLUpdate const& query) const; + + uint32 Apply(Path const& path) const; + + void UpdateEntry(AppliedFileEntry const& entry, uint32 const speed = 0) const; + void RenameEntry(std::string const& from, std::string const& to) const; + void CleanUp(AppliedFileStorage const& storage) const; + + void UpdateState(std::string const& name, State const state) const; + + Path const _sourceDirectory; + + std::function<void(std::string const&)> const _apply; + std::function<void(Path const& path)> const _applyFile; + std::function<QueryResult(std::string const&)> const _retrieve; +}; + +#endif // UpdateFetcher_h__ diff --git a/src/server/shared/Utilities/ServiceWin32.cpp b/src/server/shared/Utilities/ServiceWin32.cpp index c73949fc6a3..3e5e416b1a3 100644 --- a/src/server/shared/Utilities/ServiceWin32.cpp +++ b/src/server/shared/Utilities/ServiceWin32.cpp @@ -255,7 +255,7 @@ bool WinServiceRun() if (!StartServiceCtrlDispatcher(serviceTable)) { - TC_LOG_ERROR("server.worldserver", "StartService Failed. Error [%u]", ::GetLastError()); + TC_LOG_ERROR("server.worldserver", "StartService Failed. Error [%u]", uint32(::GetLastError())); return false; } return true; diff --git a/src/server/shared/Utilities/StringFormat.h b/src/server/shared/Utilities/StringFormat.h new file mode 100644 index 00000000000..70d9aefb14d --- /dev/null +++ b/src/server/shared/Utilities/StringFormat.h @@ -0,0 +1,34 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_STRING_FORMAT_H +#define TRINITYCORE_STRING_FORMAT_H + +#include <format.h> + +namespace Trinity +{ + //! Default TC string format function + template<typename... Args> + inline std::string StringFormat(const char* fmt, Args const&... args) + { + return fmt::sprintf(fmt, args...); + } +} + +#endif diff --git a/src/server/shared/Utilities/TaskScheduler.cpp b/src/server/shared/Utilities/TaskScheduler.cpp new file mode 100644 index 00000000000..4b261413fd9 --- /dev/null +++ b/src/server/shared/Utilities/TaskScheduler.cpp @@ -0,0 +1,208 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "TaskScheduler.h" + +TaskScheduler& TaskScheduler::Update() +{ + _now = clock_t::now(); + Dispatch(); + return *this; +} + +TaskScheduler& TaskScheduler::Update(size_t const milliseconds) +{ + return Update(std::chrono::milliseconds(milliseconds)); +} + +TaskScheduler& TaskScheduler::Async(std::function<void()> const& callable) +{ + _asyncHolder.push(callable); + return *this; +} + +TaskScheduler& TaskScheduler::CancelAll() +{ + /// Clear the task holder + _task_holder.Clear(); + _asyncHolder = AsyncHolder(); + return *this; +} + +TaskScheduler& TaskScheduler::CancelGroup(group_t const group) +{ + _task_holder.RemoveIf([group](TaskContainer const& task) -> bool + { + return task->IsInGroup(group); + }); + return *this; +} + +TaskScheduler& TaskScheduler::CancelGroupsOf(std::vector<group_t> const& groups) +{ + std::for_each(groups.begin(), groups.end(), + std::bind(&TaskScheduler::CancelGroup, this, std::placeholders::_1)); + + return *this; +} + +TaskScheduler& TaskScheduler::InsertTask(TaskContainer task) +{ + _task_holder.Push(std::move(task)); + return *this; +} + +void TaskScheduler::Dispatch() +{ + // Process all asyncs + while (!_asyncHolder.empty()) + { + _asyncHolder.front()(); + _asyncHolder.pop(); + } + + while (!_task_holder.IsEmpty()) + { + if (_task_holder.First()->_end > _now) + break; + + // Perfect forward the context to the handler + // Use weak references to catch destruction before callbacks. + TaskContext context(_task_holder.Pop(), std::weak_ptr<TaskScheduler>(self_reference)); + + // Invoke the context + context.Invoke(); + } +} + +void TaskScheduler::TaskQueue::Push(TaskContainer&& task) +{ + container.insert(task); +} + +auto TaskScheduler::TaskQueue::Pop() -> TaskContainer +{ + TaskContainer result = *container.begin(); + container.erase(container.begin()); + return result; +} + +auto TaskScheduler::TaskQueue::First() const -> TaskContainer const& +{ + return *container.begin(); +} + +void TaskScheduler::TaskQueue::Clear() +{ + container.clear(); +} + +void TaskScheduler::TaskQueue::RemoveIf(std::function<bool(TaskContainer const&)> const& filter) +{ + for (auto itr = container.begin(); itr != container.end();) + if (filter(*itr)) + itr = container.erase(itr); + else + ++itr; +} + +void TaskScheduler::TaskQueue::ModifyIf(std::function<bool(TaskContainer const&)> const& filter) +{ + std::vector<TaskContainer> cache; + for (auto itr = container.begin(); itr != container.end();) + if (filter(*itr)) + { + cache.push_back(*itr); + itr = container.erase(itr); + } + else + ++itr; + + container.insert(cache.begin(), cache.end()); +} + +bool TaskScheduler::TaskQueue::IsEmpty() const +{ + return container.empty(); +} + +TaskContext& TaskContext::Dispatch(std::function<TaskScheduler&(TaskScheduler&)> const& apply) +{ + if (auto const owner = _owner.lock()) + apply(*owner); + + return *this; +} + +bool TaskContext::IsExpired() const +{ + return _owner.expired(); +} + +bool TaskContext::IsInGroup(TaskScheduler::group_t const group) const +{ + return _task->IsInGroup(group); +} + +TaskContext& TaskContext::SetGroup(TaskScheduler::group_t const group) +{ + _task->_group = group; + return *this; +} + +TaskContext& TaskContext::ClearGroup() +{ + _task->_group = boost::none; + return *this; +} + +TaskScheduler::repeated_t TaskContext::GetRepeatCounter() const +{ + return _task->_repeated; +} + +TaskContext& TaskContext::Async(std::function<void()> const& callable) +{ + return Dispatch(std::bind(&TaskScheduler::Async, std::placeholders::_1, callable)); +} + +TaskContext& TaskContext::CancelAll() +{ + return Dispatch(std::mem_fn(&TaskScheduler::CancelAll)); +} + +TaskContext& TaskContext::CancelGroup(TaskScheduler::group_t const group) +{ + return Dispatch(std::bind(&TaskScheduler::CancelGroup, std::placeholders::_1, group)); +} + +TaskContext& TaskContext::CancelGroupsOf(std::vector<TaskScheduler::group_t> const& groups) +{ + return Dispatch(std::bind(&TaskScheduler::CancelGroupsOf, std::placeholders::_1, std::cref(groups))); +} + +void TaskContext::AssertOnConsumed() const +{ + // This was adapted to TC to prevent static analysis tools from complaining. + // If you encounter this assertion check if you repeat a TaskContext more then 1 time! + ASSERT(!(*_consumed) && "Bad task logic, task context was consumed already!"); +} + +void TaskContext::Invoke() +{ + _task->_task(*this); +} diff --git a/src/server/shared/Utilities/TaskScheduler.h b/src/server/shared/Utilities/TaskScheduler.h new file mode 100644 index 00000000000..98e210e55b1 --- /dev/null +++ b/src/server/shared/Utilities/TaskScheduler.h @@ -0,0 +1,627 @@ +/* + * Copyright (C) 2008-2015 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef _TASK_SCHEDULER_H_ +#define _TASK_SCHEDULER_H_ + +#include <algorithm> +#include <chrono> +#include <vector> +#include <queue> +#include <memory> +#include <utility> +#include <set> + +#include <boost/optional.hpp> + +#include "Util.h" + +class TaskContext; + +/// The TaskScheduler class provides the ability to schedule std::function's in the near future. +/// Use TaskScheduler::Update to update the scheduler. +/// Popular methods are: +/// * Schedule (Schedules a std::function which will be executed in the near future). +/// * Schedules an asynchronous function which will be executed at the next update tick. +/// * Cancel, Delay & Reschedule (Methods to manipulate already scheduled tasks). +/// Tasks are organized in groups (uint), multiple tasks can have the same group id, +/// you can provide a group or not, but keep in mind that you can only manipulate specific tasks through its group id! +/// Tasks callbacks use the function signature void(TaskContext) where TaskContext provides +/// access to the function schedule plan which makes it possible to repeat the task +/// with the same duration or a new one. +/// It also provides access to the repeat counter which is useful for task that repeat itself often +/// but behave different every time (spoken event dialogs for example). +class TaskScheduler +{ + friend class TaskContext; + + // Time definitions (use steady clock) + typedef std::chrono::steady_clock clock_t; + typedef clock_t::time_point timepoint_t; + typedef clock_t::duration duration_t; + + // Task group type + typedef uint32 group_t; + // Task repeated type + typedef uint32 repeated_t; + // Task handle type + typedef std::function<void(TaskContext)> task_handler_t; + + class Task + { + friend class TaskContext; + friend class TaskScheduler; + + timepoint_t _end; + duration_t _duration; + boost::optional<group_t> _group; + repeated_t _repeated; + task_handler_t _task; + + public: + // All Argument construct + Task(timepoint_t const& end, duration_t const& duration, boost::optional<group_t> const& group, + repeated_t const repeated, task_handler_t const& task) + : _end(end), _duration(duration), _group(group), _repeated(repeated), _task(task) { } + + // Minimal Argument construct + Task(timepoint_t const& end, duration_t const& duration, task_handler_t const& task) + : _end(end), _duration(duration), _group(boost::none), _repeated(0), _task(task) { } + + // Copy construct + Task(Task const&) = delete; + // Move construct + Task(Task&&) = delete; + // Copy Assign + Task& operator= (Task const&) = default; + // Move Assign + Task& operator= (Task&& right) = delete; + + // Order tasks by its end + inline bool operator< (Task const& other) const + { + return _end < other._end; + } + + inline bool operator> (Task const& other) const + { + return _end > other._end; + } + + // Compare tasks with its end + inline bool operator== (Task const& other) + { + return _end == other._end; + } + + // Returns true if the task is in the given group + inline bool IsInGroup(group_t const group) const + { + return _group == group; + } + }; + + typedef std::shared_ptr<Task> TaskContainer; + + /// Container which provides Task order, insert and reschedule operations. + struct Compare + { + bool operator() (TaskContainer const& left, TaskContainer const& right) + { + return (*left.get()) < (*right.get()); + }; + }; + + class TaskQueue + { + std::multiset<TaskContainer, Compare> container; + + public: + // Pushes the task in the container + void Push(TaskContainer&& task); + + /// Pops the task out of the container + TaskContainer Pop(); + + TaskContainer const& First() const; + + void Clear(); + + void RemoveIf(std::function<bool(TaskContainer const&)> const& filter); + + void ModifyIf(std::function<bool(TaskContainer const&)> const& filter); + + bool IsEmpty() const; + }; + + /// Contains a self reference to track if this object was deleted or not. + std::shared_ptr<TaskScheduler> self_reference; + + /// The current time point (now) + timepoint_t _now; + + /// The Task Queue which contains all task objects. + TaskQueue _task_holder; + + typedef std::queue<std::function<void()>> AsyncHolder; + + /// Contains all asynchronous tasks which will be invoked at + /// the next update tick. + AsyncHolder _asyncHolder; + +public: + TaskScheduler() : self_reference(this, [](TaskScheduler const*) { }), + _now(clock_t::now()) { } + + TaskScheduler(TaskScheduler const&) = delete; + TaskScheduler(TaskScheduler&&) = delete; + TaskScheduler& operator= (TaskScheduler const&) = delete; + TaskScheduler& operator= (TaskScheduler&&) = delete; + + /// Update the scheduler to the current time. + TaskScheduler& Update(); + + /// Update the scheduler with a difftime in ms. + TaskScheduler& Update(size_t const milliseconds); + + /// Update the scheduler with a difftime. + template<class _Rep, class _Period> + TaskScheduler& Update(std::chrono::duration<_Rep, _Period> const& difftime) + { + _now += difftime; + Dispatch(); + return *this; + } + + /// Schedule an callable function that is executed at the next update tick. + /// Its safe to modify the TaskScheduler from within the callable. + TaskScheduler& Async(std::function<void()> const& callable); + + /// Schedule an event with a fixed rate. + /// Never call this from within a task context! Use TaskContext::Schedule instead! + template<class _Rep, class _Period> + TaskScheduler& Schedule(std::chrono::duration<_Rep, _Period> const& time, + task_handler_t const& task) + { + return ScheduleAt(_now, time, task); + } + + /// Schedule an event with a fixed rate. + /// Never call this from within a task context! Use TaskContext::Schedule instead! + template<class _Rep, class _Period> + TaskScheduler& Schedule(std::chrono::duration<_Rep, _Period> const& time, + group_t const group, task_handler_t const& task) + { + return ScheduleAt(_now, time, group, task); + } + + /// Schedule an event with a randomized rate between min and max rate. + /// Never call this from within a task context! Use TaskContext::Schedule instead! + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskScheduler& Schedule(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max, task_handler_t const& task) + { + return Schedule(RandomDurationBetween(min, max), task); + } + + /// Schedule an event with a fixed rate. + /// Never call this from within a task context! Use TaskContext::Schedule instead! + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskScheduler& Schedule(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max, group_t const group, + task_handler_t const& task) + { + return Schedule(RandomDurationBetween(min, max), group, task); + } + + /// Cancels all tasks. + /// Never call this from within a task context! Use TaskContext::CancelAll instead! + TaskScheduler& CancelAll(); + + /// Cancel all tasks of a single group. + /// Never call this from within a task context! Use TaskContext::CancelGroup instead! + TaskScheduler& CancelGroup(group_t const group); + + /// Cancels all groups in the given std::vector. + /// Hint: Use std::initializer_list for this: "{1, 2, 3, 4}" + TaskScheduler& CancelGroupsOf(std::vector<group_t> const& groups); + + /// Delays all tasks with the given duration. + template<class _Rep, class _Period> + TaskScheduler& DelayAll(std::chrono::duration<_Rep, _Period> const& duration) + { + _task_holder.ModifyIf([&duration](TaskContainer const& task) -> bool + { + task->_end += duration; + return true; + }); + return *this; + } + + /// Delays all tasks with a random duration between min and max. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskScheduler& DelayAll(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return DelayAll(RandomDurationBetween(min, max)); + } + + /// Delays all tasks of a group with the given duration. + template<class _Rep, class _Period> + TaskScheduler& DelayGroup(group_t const group, std::chrono::duration<_Rep, _Period> const& duration) + { + _task_holder.ModifyIf([&duration, group](TaskContainer const& task) -> bool + { + if (task->IsInGroup(group)) + { + task->_end += duration; + return true; + } + else + return false; + }); + return *this; + } + + /// Delays all tasks of a group with a random duration between min and max. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskScheduler& DelayGroup(group_t const group, + std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return DelayGroup(group, RandomDurationBetween(min, max)); + } + + /// Reschedule all tasks with a given duration. + template<class _Rep, class _Period> + TaskScheduler& RescheduleAll(std::chrono::duration<_Rep, _Period> const& duration) + { + auto const end = _now + duration; + _task_holder.ModifyIf([end](TaskContainer const& task) -> bool + { + task->_end = end; + return true; + }); + return *this; + } + + /// Reschedule all tasks with a random duration between min and max. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskScheduler& RescheduleAll(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return RescheduleAll(RandomDurationBetween(min, max)); + } + + /// Reschedule all tasks of a group with the given duration. + template<class _Rep, class _Period> + TaskScheduler& RescheduleGroup(group_t const group, std::chrono::duration<_Rep, _Period> const& duration) + { + auto const end = _now + duration; + _task_holder.ModifyIf([end, group](TaskContainer const& task) -> bool + { + if (task->IsInGroup(group)) + { + task->_end = end; + return true; + } + else + return false; + }); + return *this; + } + + /// Reschedule all tasks of a group with a random duration between min and max. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskScheduler& RescheduleGroup(group_t const group, + std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return RescheduleGroup(group, RandomDurationBetween(min, max)); + } + +private: + /// Insert a new task to the enqueued tasks. + TaskScheduler& InsertTask(TaskContainer task); + + template<class _Rep, class _Period> + TaskScheduler& ScheduleAt(timepoint_t const& end, + std::chrono::duration<_Rep, _Period> const& time, task_handler_t const& task) + { + return InsertTask(TaskContainer(new Task(end + time, time, task))); + } + + /// Schedule an event with a fixed rate. + /// Never call this from within a task context! Use TaskContext::schedule instead! + template<class _Rep, class _Period> + TaskScheduler& ScheduleAt(timepoint_t const& end, + std::chrono::duration<_Rep, _Period> const& time, + group_t const group, task_handler_t const& task) + { + static repeated_t const DEFAULT_REPEATED = 0; + return InsertTask(TaskContainer(new Task(end + time, time, group, DEFAULT_REPEATED, task))); + } + + // Returns a random duration between min and max + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + static std::chrono::milliseconds + RandomDurationBetween(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + auto const milli_min = std::chrono::duration_cast<std::chrono::milliseconds>(min); + auto const milli_max = std::chrono::duration_cast<std::chrono::milliseconds>(max); + + // TC specific: use SFMT URandom + return std::chrono::milliseconds(urand(milli_min.count(), milli_max.count())); + } + + /// Dispatch remaining tasks + void Dispatch(); +}; + +class TaskContext +{ + friend class TaskScheduler; + + /// Associated task + TaskScheduler::TaskContainer _task; + + /// Owner + std::weak_ptr<TaskScheduler> _owner; + + /// Marks the task as consumed + std::shared_ptr<bool> _consumed; + + /// Dispatches an action safe on the TaskScheduler + TaskContext& Dispatch(std::function<TaskScheduler&(TaskScheduler&)> const& apply); + +public: + // Empty constructor + TaskContext() + : _task(), _owner(), _consumed(std::make_shared<bool>(true)) { } + + // Construct from task and owner + explicit TaskContext(TaskScheduler::TaskContainer&& task, std::weak_ptr<TaskScheduler>&& owner) + : _task(task), _owner(owner), _consumed(std::make_shared<bool>(false)) { } + + // Copy construct + TaskContext(TaskContext const& right) + : _task(right._task), _owner(right._owner), _consumed(right._consumed) { } + + // Move construct + TaskContext(TaskContext&& right) + : _task(std::move(right._task)), _owner(std::move(right._owner)), _consumed(std::move(right._consumed)) { } + + // Copy assign + TaskContext& operator= (TaskContext const& right) + { + _task = right._task; + _owner = right._owner; + _consumed = right._consumed; + return *this; + } + + // Move assign + TaskContext& operator= (TaskContext&& right) + { + _task = std::move(right._task); + _owner = std::move(right._owner); + _consumed = std::move(right._consumed); + return *this; + } + + /// Returns true if the owner was deallocated and this context has expired. + bool IsExpired() const; + + /// Returns true if the event is in the given group + bool IsInGroup(TaskScheduler::group_t const group) const; + + /// Sets the event in the given group + TaskContext& SetGroup(TaskScheduler::group_t const group); + + /// Removes the group from the event + TaskContext& ClearGroup(); + + /// Returns the repeat counter which increases every time the task is repeated. + TaskScheduler::repeated_t GetRepeatCounter() const; + + /// Repeats the event and sets a new duration. + /// std::chrono::seconds(5) for example. + /// This will consume the task context, its not possible to repeat the task again + /// from the same task context! + template<class _Rep, class _Period> + TaskContext& Repeat(std::chrono::duration<_Rep, _Period> const& duration) + { + AssertOnConsumed(); + + // Set new duration, in-context timing and increment repeat counter + _task->_duration = duration; + _task->_end += duration; + _task->_repeated += 1; + (*_consumed) = true; + return Dispatch(std::bind(&TaskScheduler::InsertTask, std::placeholders::_1, _task)); + } + + /// Repeats the event with the same duration. + /// This will consume the task context, its not possible to repeat the task again + /// from the same task context! + TaskContext& Repeat() + { + return Repeat(_task->_duration); + } + + /// Repeats the event and set a new duration that is randomized between min and max. + /// std::chrono::seconds(5) for example. + /// This will consume the task context, its not possible to repeat the task again + /// from the same task context! + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskContext& Repeat(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return Repeat(TaskScheduler::RandomDurationBetween(min, max)); + } + + /// Schedule a callable function that is executed at the next update tick from within the context. + /// Its safe to modify the TaskScheduler from within the callable. + TaskContext& Async(std::function<void()> const& callable); + + /// Schedule an event with a fixed rate from within the context. + /// Its possible that the new event is executed immediately! + /// Use TaskScheduler::Async to create a task + /// which will be called at the next update tick. + template<class _Rep, class _Period> + TaskContext& Schedule(std::chrono::duration<_Rep, _Period> const& time, + TaskScheduler::task_handler_t const& task) + { + auto const end = _task->_end; + return Dispatch([end, time, task](TaskScheduler& scheduler) -> TaskScheduler& + { + return scheduler.ScheduleAt<_Rep, _Period>(end, time, task); + }); + } + + /// Schedule an event with a fixed rate from within the context. + /// Its possible that the new event is executed immediately! + /// Use TaskScheduler::Async to create a task + /// which will be called at the next update tick. + template<class _Rep, class _Period> + TaskContext& Schedule(std::chrono::duration<_Rep, _Period> const& time, + TaskScheduler::group_t const group, TaskScheduler::task_handler_t const& task) + { + auto const end = _task->_end; + return Dispatch([end, time, group, task](TaskScheduler& scheduler) -> TaskScheduler& + { + return scheduler.ScheduleAt<_Rep, _Period>(end, time, group, task); + }); + } + + /// Schedule an event with a randomized rate between min and max rate from within the context. + /// Its possible that the new event is executed immediately! + /// Use TaskScheduler::Async to create a task + /// which will be called at the next update tick. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskContext& Schedule(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max, TaskScheduler::task_handler_t const& task) + { + return Schedule(TaskScheduler::RandomDurationBetween(min, max), task); + } + + /// Schedule an event with a randomized rate between min and max rate from within the context. + /// Its possible that the new event is executed immediately! + /// Use TaskScheduler::Async to create a task + /// which will be called at the next update tick. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskContext& Schedule(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max, TaskScheduler::group_t const group, + TaskScheduler::task_handler_t const& task) + { + return Schedule(TaskScheduler::RandomDurationBetween(min, max), group, task); + } + + /// Cancels all tasks from within the context. + TaskContext& CancelAll(); + + /// Cancel all tasks of a single group from within the context. + TaskContext& CancelGroup(TaskScheduler::group_t const group); + + /// Cancels all groups in the given std::vector from within the context. + /// Hint: Use std::initializer_list for this: "{1, 2, 3, 4}" + TaskContext& CancelGroupsOf(std::vector<TaskScheduler::group_t> const& groups); + + /// Delays all tasks with the given duration from within the context. + template<class _Rep, class _Period> + TaskContext& DelayAll(std::chrono::duration<_Rep, _Period> const& duration) + { + return Dispatch(std::bind(&TaskScheduler::DelayAll<_Rep, _Period>, std::placeholders::_1, duration)); + } + + /// Delays all tasks with a random duration between min and max from within the context. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskContext& DelayAll(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return DelayAll(TaskScheduler::RandomDurationBetween(min, max)); + } + + /// Delays all tasks of a group with the given duration from within the context. + template<class _Rep, class _Period> + TaskContext& DelayGroup(TaskScheduler::group_t const group, std::chrono::duration<_Rep, _Period> const& duration) + { + return Dispatch(std::bind(&TaskScheduler::DelayGroup<_Rep, _Period>, std::placeholders::_1, group, duration)); + } + + /// Delays all tasks of a group with a random duration between min and max from within the context. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskContext& DelayGroup(TaskScheduler::group_t const group, + std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return DelayGroup(group, TaskScheduler::RandomDurationBetween(min, max)); + } + + /// Reschedule all tasks with the given duration. + template<class _Rep, class _Period> + TaskContext& RescheduleAll(std::chrono::duration<_Rep, _Period> const& duration) + { + return Dispatch(std::bind(&TaskScheduler::RescheduleAll, std::placeholders::_1, duration)); + } + + /// Reschedule all tasks with a random duration between min and max. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskContext& RescheduleAll(std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return RescheduleAll(TaskScheduler::RandomDurationBetween(min, max)); + } + + /// Reschedule all tasks of a group with the given duration. + template<class _Rep, class _Period> + TaskContext& RescheduleGroup(TaskScheduler::group_t const group, std::chrono::duration<_Rep, _Period> const& duration) + { + return Dispatch(std::bind(&TaskScheduler::RescheduleGroup<_Rep, _Period>, std::placeholders::_1, group, duration)); + } + + /// Reschedule all tasks of a group with a random duration between min and max. + template<class _RepLeft, class _PeriodLeft, class _RepRight, class _PeriodRight> + TaskContext& RescheduleGroup(TaskScheduler::group_t const group, + std::chrono::duration<_RepLeft, _PeriodLeft> const& min, + std::chrono::duration<_RepRight, _PeriodRight> const& max) + { + return RescheduleGroup(group, TaskScheduler::RandomDurationBetween(min, max)); + } + +private: + /// Asserts if the task was consumed already. + void AssertOnConsumed() const; + + /// Invokes the associated hook of the task. + void Invoke(); +}; + +/// Milliseconds shorthand typedef. +typedef std::chrono::milliseconds Milliseconds; + +/// Seconds shorthand typedef. +typedef std::chrono::seconds Seconds; + +/// Minutes shorthand typedef. +typedef std::chrono::minutes Minutes; + +/// Hours shorthand typedef. +typedef std::chrono::hours Hours; + +#endif /// _TASK_SCHEDULER_H_ diff --git a/src/server/shared/Utilities/Util.h b/src/server/shared/Utilities/Util.h index 0a4e5e4a4a0..cd523511c1d 100644 --- a/src/server/shared/Utilities/Util.h +++ b/src/server/shared/Utilities/Util.h @@ -688,8 +688,9 @@ class EventMap /** * @name RepeatEvent - * @brief Repeats the mostly recently executed event. - * @param time Time until the event occurs. Equivalent to Repeat(urand(minTime, maxTime). + * @brief Repeats the mostly recently executed event, Equivalent to Repeat(urand(minTime, maxTime). + * @param minTime Minimum time until the event occurs. + * @param maxTime Maximum time until the event occurs. */ void Repeat(uint32 minTime, uint32 maxTime) { @@ -839,7 +840,7 @@ class EventMap /** * @name GetTimeUntilEvent * @brief Returns time in milliseconds until next event. - * @param Id of the event. + * @param eventId of the event. * @return Time of next event. */ uint32 GetTimeUntilEvent(uint32 eventId) const; |
