diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/server/authserver/Server/AuthSession.cpp | 81 | ||||
-rw-r--r-- | src/server/authserver/Server/AuthSession.h | 4 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.cpp | 109 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.h | 4 | ||||
-rw-r--r-- | src/server/shared/Common.h | 40 | ||||
-rw-r--r-- | src/server/shared/Database/DatabaseWorkerPool.h | 95 | ||||
-rw-r--r-- | src/server/shared/Database/MySQLConnection.cpp | 8 | ||||
-rw-r--r-- | src/server/shared/Networking/MessageBuffer.h | 93 | ||||
-rw-r--r-- | src/server/shared/Networking/Socket.h | 129 | ||||
-rw-r--r-- | src/server/shared/Packets/ByteBuffer.cpp | 5 | ||||
-rw-r--r-- | src/server/shared/Packets/ByteBuffer.h | 12 | ||||
-rw-r--r-- | src/server/shared/Packets/WorldPacket.h | 2 | ||||
-rw-r--r-- | src/server/worldserver/Main.cpp | 57 | ||||
-rw-r--r-- | src/tools/mmaps_generator/PathGenerator.cpp | 2 |
14 files changed, 374 insertions, 267 deletions
diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp index cdd8298174f..76f8b8c27b0 100644 --- a/src/server/authserver/Server/AuthSession.cpp +++ b/src/server/authserver/Server/AuthSession.cpp @@ -21,6 +21,7 @@ #include "AuthCodes.h" #include "Database/DatabaseEnv.h" #include "SHA1.h" +#include "TOTP.h" #include "openssl/crypto.h" #include "Configuration/Config.h" #include "RealmList.h" @@ -52,7 +53,6 @@ enum eStatus typedef struct AUTH_LOGON_CHALLENGE_C { - uint8 cmd; uint8 error; uint16 size; uint8 gamename[4]; @@ -71,7 +71,6 @@ typedef struct AUTH_LOGON_CHALLENGE_C typedef struct AUTH_LOGON_PROOF_C { - uint8 cmd; uint8 A[32]; uint8 M1[20]; uint8 crc_hash[20]; @@ -99,7 +98,6 @@ typedef struct AUTH_LOGON_PROOF_S_OLD typedef struct AUTH_RECONNECT_PROOF_C { - uint8 cmd; uint8 R1[16]; uint8 R2[20]; uint8 R3[20]; @@ -114,10 +112,10 @@ enum class BufferSizes : uint32 SRP_6_S = 0x20, }; -#define REALM_LIST_PACKET_SIZE 5 -#define XFER_ACCEPT_SIZE 1 -#define XFER_RESUME_SIZE 9 -#define XFER_CANCEL_SIZE 1 +#define REALM_LIST_PACKET_SIZE 4 +#define XFER_ACCEPT_SIZE 0 +#define XFER_RESUME_SIZE 8 +#define XFER_CANCEL_SIZE 0 std::unordered_map<uint8, AuthHandler> AuthSession::InitHandlers() { @@ -137,44 +135,36 @@ std::unordered_map<uint8, AuthHandler> AuthSession::InitHandlers() std::unordered_map<uint8, AuthHandler> const Handlers = AuthSession::InitHandlers(); -void AuthSession::ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) +void AuthSession::ReadHeaderHandler() { - if (!error && transferedBytes == 1) + uint8 cmd = GetHeaderBuffer()[0]; + auto itr = Handlers.find(cmd); + if (itr != Handlers.end()) { - uint8 cmd = GetReadBuffer()[0]; - auto itr = Handlers.find(cmd); - if (itr != Handlers.end()) + // Handle dynamic size packet + if (cmd == AUTH_LOGON_CHALLENGE || cmd == AUTH_RECONNECT_CHALLENGE) { - // Handle dynamic size packet - if (cmd == AUTH_LOGON_CHALLENGE || cmd == AUTH_RECONNECT_CHALLENGE) - { - ReadData(sizeof(uint8) + sizeof(uint16), sizeof(cmd)); //error + size - sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); + ReadData(sizeof(uint8) + sizeof(uint16)); //error + size + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer()); - AsyncReadData(challenge->size, sizeof(uint8) + sizeof(uint8) + sizeof(uint16)); // cmd + error + size - } - else - AsyncReadData(itr->second.packetSize, sizeof(uint8)); + AsyncReadData(challenge->size); } + else + AsyncReadData(itr->second.packetSize); } else CloseSocket(); } -void AuthSession::ReadDataHandler(boost::system::error_code error, size_t transferedBytes) +void AuthSession::ReadDataHandler() { - if (!error && transferedBytes > 0) + if (!(*this.*Handlers.at(GetHeaderBuffer()[0]).handler)()) { - if (!(*this.*Handlers.at(GetReadBuffer()[0]).handler)()) - { - CloseSocket(); - return; - } - - AsyncReadHeader(); - } - else CloseSocket(); + return; + } + + AsyncReadHeader(); } void AuthSession::AsyncWrite(ByteBuffer& packet) @@ -191,7 +181,7 @@ void AuthSession::AsyncWrite(ByteBuffer& packet) bool AuthSession::HandleLogonChallenge() { - sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -410,7 +400,7 @@ bool AuthSession::HandleLogonProof() TC_LOG_DEBUG("server.authserver", "Entering _HandleLogonProof"); // Read the packet - sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetReadBuffer()); + sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetDataBuffer()); // If the client has no valid version if (_expversion == NO_VALID_EXP_FLAG) @@ -522,17 +512,12 @@ bool AuthSession::HandleLogonProof() // Check auth token if ((logonProof->securityFlags & 0x04) || !_tokenKey.empty()) { - // TODO To be fixed - - /* - uint8 size; - socket().recv((char*)&size, 1); - char* token = new char[size + 1]; - token[size] = '\0'; - socket().recv(token, size); - unsigned int validToken = TOTP::GenerateToken(_tokenKey.c_str()); - unsigned int incomingToken = atoi(token); - delete[] token; + ReadData(1); + uint8 size = *(GetDataBuffer() + sizeof(sAuthLogonProof_C)); + ReadData(size); + std::string token(reinterpret_cast<char*>(GetDataBuffer() + sizeof(sAuthLogonProof_C) + sizeof(size)), size); + uint32 validToken = TOTP::GenerateToken(_tokenKey.c_str()); + uint32 incomingToken = atoi(token.c_str()); if (validToken != incomingToken) { ByteBuffer packet; @@ -542,7 +527,7 @@ bool AuthSession::HandleLogonProof() packet << uint8(0); AsyncWrite(packet); return false; - }*/ + } } ByteBuffer packet; @@ -650,7 +635,7 @@ bool AuthSession::HandleLogonProof() bool AuthSession::HandleReconnectChallenge() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectChallenge"); - sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer()); + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -701,7 +686,7 @@ bool AuthSession::HandleReconnectChallenge() bool AuthSession::HandleReconnectProof() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectProof"); - sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetReadBuffer()); + sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetDataBuffer()); if (_login.empty() || !_reconnectProof.GetNumBytes() || !K.GetNumBytes()) return false; diff --git a/src/server/authserver/Server/AuthSession.h b/src/server/authserver/Server/AuthSession.h index 5a05ee6f8e9..3497e3a030c 100644 --- a/src/server/authserver/Server/AuthSession.h +++ b/src/server/authserver/Server/AuthSession.h @@ -53,8 +53,8 @@ public: void AsyncWrite(ByteBuffer& packet); protected: - void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) override; - void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) override; + void ReadHeaderHandler() override; + void ReadDataHandler() override; private: bool HandleLogonChallenge(); diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index 65a424d5d75..046cdc0acd3 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -54,89 +54,72 @@ void WorldSocket::HandleSendAuthSession() AsyncWrite(packet); } -void WorldSocket::ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) +void WorldSocket::ReadHeaderHandler() { - if (!error && transferedBytes == sizeof(ClientPktHeader)) - { - _authCrypt.DecryptRecv(GetReadBuffer(), sizeof(ClientPktHeader)); + _authCrypt.DecryptRecv(GetHeaderBuffer(), sizeof(ClientPktHeader)); - ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetReadBuffer()); - EndianConvertReverse(header->size); - EndianConvert(header->cmd); + ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetHeaderBuffer()); + EndianConvertReverse(header->size); + EndianConvert(header->cmd); - AsyncReadData(header->size - sizeof(header->cmd), sizeof(ClientPktHeader)); - } - else - CloseSocket(); + AsyncReadData(header->size - sizeof(header->cmd)); } -void WorldSocket::ReadDataHandler(boost::system::error_code error, size_t transferedBytes) +void WorldSocket::ReadDataHandler() { - ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetReadBuffer()); + ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetHeaderBuffer()); - if (!error && transferedBytes == (header->size - sizeof(header->cmd))) - { - header->size -= sizeof(header->cmd); - - uint16 opcode = uint16(header->cmd); + header->size -= sizeof(header->cmd); - std::string opcodeName = GetOpcodeNameForLogging(opcode); + uint16 opcode = uint16(header->cmd); - WorldPacket packet(opcode, header->size); + std::string opcodeName = GetOpcodeNameForLogging(opcode); - if (header->size > 0) - { - packet.resize(header->size); + WorldPacket packet(opcode, MoveData()); - std::memcpy(packet.contents(), &(GetReadBuffer()[sizeof(ClientPktHeader)]), header->size); - } - - if (sPacketLog->CanLogPacket()) - sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort()); + if (sPacketLog->CanLogPacket()) + sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort()); - TC_LOG_TRACE("network.opcode", "C->S: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()).c_str(), GetOpcodeNameForLogging(opcode).c_str()); + TC_LOG_TRACE("network.opcode", "C->S: %s %s", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()).c_str(), opcodeName.c_str()); - switch (opcode) - { - case CMSG_PING: - HandlePing(packet); + switch (opcode) + { + case CMSG_PING: + HandlePing(packet); + break; + case CMSG_AUTH_SESSION: + if (_worldSession) + { + TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); break; - case CMSG_AUTH_SESSION: - if (_worldSession) - { - TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); - break; - } + } - sScriptMgr->OnPacketReceive(shared_from_this(), packet); - HandleAuthSession(packet); - break; - case CMSG_KEEP_ALIVE: - TC_LOG_DEBUG("network", "%s", opcodeName.c_str()); - sScriptMgr->OnPacketReceive(shared_from_this(), packet); - break; - default: + sScriptMgr->OnPacketReceive(shared_from_this(), packet); + HandleAuthSession(packet); + break; + case CMSG_KEEP_ALIVE: + TC_LOG_DEBUG("network", "%s", opcodeName.c_str()); + sScriptMgr->OnPacketReceive(shared_from_this(), packet); + break; + default: + { + if (!_worldSession) { - if (!_worldSession) - { - TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); - break; - } - - // Our Idle timer will reset on any non PING opcodes. - // Catches people idling on the login screen and any lingering ingame connections. - _worldSession->ResetTimeOutTime(); - - // Copy the packet to the heap before enqueuing - _worldSession->QueuePacket(new WorldPacket(packet)); + TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); break; } - } - AsyncReadHeader(); + // Our Idle timer will reset on any non PING opcodes. + // Catches people idling on the login screen and any lingering ingame connections. + _worldSession->ResetTimeOutTime(); + + // Copy the packet to the heap before enqueuing + _worldSession->QueuePacket(new WorldPacket(std::move(packet))); + break; + } } - else - CloseSocket(); + + AsyncReadHeader(); } void WorldSocket::AsyncWrite(WorldPacket& packet) diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index 7275da5ff29..8d452677650 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -108,8 +108,8 @@ public: void AsyncWrite(WorldPacket& packet); protected: - void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) override; - void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) override; + void ReadHeaderHandler() override; + void ReadDataHandler() override; private: void HandleSendAuthSession(); diff --git a/src/server/shared/Common.h b/src/server/shared/Common.h index 0a1389c1f38..ab268835046 100644 --- a/src/server/shared/Common.h +++ b/src/server/shared/Common.h @@ -19,46 +19,6 @@ #ifndef TRINITYCORE_COMMON_H #define TRINITYCORE_COMMON_H -// config.h needs to be included 1st -/// @todo this thingy looks like hack, but its not, need to -// make separate header however, because It makes mess here. -#ifdef HAVE_CONFIG_H -// Remove Some things that we will define -// This is in case including another config.h -// before trinity config.h -#ifdef PACKAGE -#undef PACKAGE -#endif //PACKAGE -#ifdef PACKAGE_BUGREPORT -#undef PACKAGE_BUGREPORT -#endif //PACKAGE_BUGREPORT -#ifdef PACKAGE_NAME -#undef PACKAGE_NAME -#endif //PACKAGE_NAME -#ifdef PACKAGE_STRING -#undef PACKAGE_STRING -#endif //PACKAGE_STRING -#ifdef PACKAGE_TARNAME -#undef PACKAGE_TARNAME -#endif //PACKAGE_TARNAME -#ifdef PACKAGE_VERSION -#undef PACKAGE_VERSION -#endif //PACKAGE_VERSION -#ifdef VERSION -#undef VERSION -#endif //VERSION - -# include "Config.h" - -#undef PACKAGE -#undef PACKAGE_BUGREPORT -#undef PACKAGE_NAME -#undef PACKAGE_STRING -#undef PACKAGE_TARNAME -#undef PACKAGE_VERSION -#undef VERSION -#endif //HAVE_CONFIG_H - #include "Define.h" #include <unordered_map> diff --git a/src/server/shared/Database/DatabaseWorkerPool.h b/src/server/shared/Database/DatabaseWorkerPool.h index 39f1a8da3c2..e95dfc1e484 100644 --- a/src/server/shared/Database/DatabaseWorkerPool.h +++ b/src/server/shared/Database/DatabaseWorkerPool.h @@ -45,6 +45,14 @@ class PingOperation : public SQLOperation template <class T> class DatabaseWorkerPool { + private: + enum InternalIndex + { + IDX_ASYNC, + IDX_SYNCH, + IDX_SIZE + }; + public: /* Activity state */ DatabaseWorkerPool() : _connectionInfo(NULL) @@ -74,34 +82,17 @@ class DatabaseWorkerPool TC_LOG_INFO("sql.driver", "Opening DatabasePool '%s'. Asynchronous connections: %u, synchronous connections: %u.", GetDatabaseName(), async_threads, synch_threads); - //! Open asynchronous connections (delayed operations) - _connections[IDX_ASYNC].resize(async_threads); - for (uint8 i = 0; i < async_threads; ++i) - { - T* t = new T(_queue, *_connectionInfo); - res &= t->Open(); - if (res) // only check mysql version if connection is valid - WPFatal(mysql_get_server_version(t->GetHandle()) >= MIN_MYSQL_SERVER_VERSION, "TrinityCore does not support MySQL versions below 5.1"); - _connections[IDX_ASYNC][i] = t; - ++_connectionCount[IDX_ASYNC]; - } + res = OpenConnections(IDX_ASYNC, async_threads); - //! Open synchronous connections (direct, blocking operations) - _connections[IDX_SYNCH].resize(synch_threads); - for (uint8 i = 0; i < synch_threads; ++i) - { - T* t = new T(*_connectionInfo); - res &= t->Open(); - _connections[IDX_SYNCH][i] = t; - ++_connectionCount[IDX_SYNCH]; - } + if (!res) + return res; + + res = OpenConnections(IDX_SYNCH, synch_threads); if (res) TC_LOG_INFO("sql.driver", "DatabasePool '%s' opened successfully. %u total connections running.", GetDatabaseName(), (_connectionCount[IDX_SYNCH] + _connectionCount[IDX_ASYNC])); - else - TC_LOG_ERROR("sql.driver", "DatabasePool %s NOT opened. There were errors opening the MySQL connections. Check your SQLDriverLogFile " - "for specific errors. Read wiki at http://collab.kpsn.org/display/tc/TrinityCore+Home", GetDatabaseName()); + return res; } @@ -112,8 +103,6 @@ class DatabaseWorkerPool for (uint8 i = 0; i < _connectionCount[IDX_ASYNC]; ++i) { T* t = _connections[IDX_ASYNC][i]; - DatabaseWorker* worker = t->m_worker; - delete worker; t->Close(); //! Closes the actualy MySQL connection. } @@ -442,7 +431,7 @@ class DatabaseWorkerPool if (str.empty()) return; - char* buf = new char[str.size()*2+1]; + char* buf = new char[str.size() * 2 + 1]; EscapeString(buf, str.c_str(), str.size()); str = buf; delete[] buf; @@ -470,6 +459,52 @@ class DatabaseWorkerPool } private: + bool OpenConnections(InternalIndex type, uint8 numConnections) + { + _connections[type].resize(numConnections); + for (uint8 i = 0; i < numConnections; ++i) + { + T* t; + + if (type == IDX_ASYNC) + t = new T(_queue, *_connectionInfo); + else if (type == IDX_SYNCH) + t = new T(*_connectionInfo); + + _connections[type][i] = t; + ++_connectionCount[type]; + + bool res = t->Open(); + + if (res) + { + if (mysql_get_server_version(t->GetHandle()) < MIN_MYSQL_SERVER_VERSION) + { + TC_LOG_ERROR("sql.driver", "TrinityCore does not support MySQL versions below 5.1"); + res = false; + } + } + + // Failed to open a connection or invalid version, abort and cleanup + if (!res) + { + TC_LOG_ERROR("sql.driver", "DatabasePool %s NOT opened. There were errors opening the MySQL connections. Check your SQLDriverLogFile " + "for specific errors. Read wiki at http://collab.kpsn.org/display/tc/TrinityCore+Home", GetDatabaseName()); + + while (_connectionCount[type] != 0) + { + T* t = _connections[type][i--]; + delete t; + --_connectionCount[type]; + } + + return false; + } + } + + return true; + } + unsigned long EscapeString(char *to, const char *from, unsigned long length) { if (!to || !from || !length) @@ -507,14 +542,6 @@ class DatabaseWorkerPool return _connectionInfo->database.c_str(); } - private: - enum _internalIndex - { - IDX_ASYNC, - IDX_SYNCH, - IDX_SIZE - }; - ProducerConsumerQueue<SQLOperation*>* _queue; //! Queue shared by async worker threads. std::vector< std::vector<T*> > _connections; uint32 _connectionCount[2]; //! Counter of MySQL connections; diff --git a/src/server/shared/Database/MySQLConnection.cpp b/src/server/shared/Database/MySQLConnection.cpp index e9fc20aef82..4e46ff0e3a1 100644 --- a/src/server/shared/Database/MySQLConnection.cpp +++ b/src/server/shared/Database/MySQLConnection.cpp @@ -57,12 +57,14 @@ m_connectionFlags(CONNECTION_ASYNC) MySQLConnection::~MySQLConnection() { - ASSERT (m_Mysql); /// MySQL context must be present at this point - for (size_t i = 0; i < m_stmts.size(); ++i) delete m_stmts[i]; - mysql_close(m_Mysql); + if (m_Mysql) + mysql_close(m_Mysql); + + if (m_worker) + delete m_worker; } void MySQLConnection::Close() diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h new file mode 100644 index 00000000000..fff94b86c1e --- /dev/null +++ b/src/server/shared/Networking/MessageBuffer.h @@ -0,0 +1,93 @@ +/* +* Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/> +* +* This program is free software; you can redistribute it and/or modify it +* under the terms of the GNU General Public License as published by the +* Free Software Foundation; either version 2 of the License, or (at your +* option) any later version. +* +* This program is distributed in the hope that it will be useful, but WITHOUT +* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for +* more details. +* +* You should have received a copy of the GNU General Public License along +* with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +#ifndef __MESSAGEBUFFER_H_ +#define __MESSAGEBUFFER_H_ + +#include "Define.h" +#include <vector> + +class MessageBuffer +{ + typedef std::vector<uint8>::size_type size_type; + +public: + MessageBuffer() : _wpos(0), _storage() { } + + MessageBuffer(MessageBuffer const& right) : _wpos(right._wpos), _storage(right._storage) { } + + MessageBuffer(MessageBuffer&& right) : _wpos(right._wpos), _storage(right.Move()) { } + + void Reset() + { + _storage.clear(); + _wpos = 0; + } + + bool IsMessageReady() const { return _wpos == _storage.size(); } + + size_type GetMissingSize() const { return _storage.size() - _wpos; } + + uint8* Data() { return _storage.data(); } + + void Grow(size_type bytes) + { + _storage.resize(_storage.size() + bytes); + } + + uint8* GetWritePointer() { return &_storage[_wpos]; } + + void WriteCompleted(size_type bytes) { _wpos += bytes; } + + void ResetWritePointer() { _wpos = 0; } + + size_type GetSize() { return _storage.size(); } + + std::vector<uint8>&& Move() + { + _wpos = 0; + return std::move(_storage); + } + + MessageBuffer& operator=(MessageBuffer& right) + { + if (this != &right) + { + _wpos = right._wpos; + _storage = right._storage; + } + + return *this; + } + + MessageBuffer& operator=(MessageBuffer&& right) + { + if (this != &right) + { + _wpos = right._wpos; + _storage = right.Move(); + } + + return *this; + } + +private: + size_type _wpos; + std::vector<uint8> _storage; +}; + +#endif /* __MESSAGEBUFFER_H_ */ diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index 6bf67e06d9c..f86890ed7ef 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -18,8 +18,9 @@ #ifndef __SOCKET_H__ #define __SOCKET_H__ -#include "Define.h" +#include "MessageBuffer.h" #include "Log.h" +#include <atomic> #include <vector> #include <mutex> #include <queue> @@ -28,19 +29,22 @@ #include <type_traits> #include <boost/asio/ip/tcp.hpp> #include <boost/asio/write.hpp> +#include <boost/asio/read.hpp> using boost::asio::ip::tcp; +#define READ_BLOCK_SIZE 4096 + template<class T, class PacketType> class Socket : public std::enable_shared_from_this<T> { typedef typename std::conditional<std::is_pointer<PacketType>::value, PacketType, PacketType const&>::type WritePacketType; public: - Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _headerSize(headerSize) + Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), + _remotePort(_socket.remote_endpoint().port()), _readHeaderBuffer(), _readDataBuffer(), _closed(false) { - _remotePort = _socket.remote_endpoint().port(); - _remoteAddress = _socket.remote_endpoint().address(); + _readHeaderBuffer.Grow(headerSize); } virtual void Start() = 0; @@ -57,25 +61,39 @@ public: void AsyncReadHeader() { - _socket.async_read_some(boost::asio::buffer(_readBuffer, _headerSize), std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(), - std::placeholders::_1, std::placeholders::_2)); + _readHeaderBuffer.ResetWritePointer(); + _readDataBuffer.Reset(); + + AsyncReadMissingHeaderData(); } - void AsyncReadData(std::size_t size, std::size_t bufferOffset) + void AsyncReadData(std::size_t size) { - _socket.async_read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size), std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(), - std::placeholders::_1, std::placeholders::_2)); + if (!size) + { + // if this is a packet with 0 length body just invoke handler directly + ReadDataHandler(); + return; + } + + _readDataBuffer.Grow(size); + AsyncReadMissingData(); } - void ReadData(std::size_t size, std::size_t bufferOffset) + void ReadData(std::size_t size) { boost::system::error_code error; - _socket.read_some(boost::asio::buffer(&_readBuffer[bufferOffset], size), error); + _readDataBuffer.Grow(size); - if (error) + std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readDataBuffer.GetWritePointer(), size), error); + + _readDataBuffer.WriteCompleted(bytesRead); + + if (error || !_readDataBuffer.IsMessageReady()) { - TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(), error.message().c_str()); + TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(), + error.message().c_str()); CloseSocket(); } @@ -83,18 +101,22 @@ public: void AsyncWrite(WritePacketType data) { - boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(), std::placeholders::_1, - std::placeholders::_2)); + boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(), + std::placeholders::_1, std::placeholders::_2)); } - bool IsOpen() const { return _socket.is_open(); } + bool IsOpen() const { return !_closed; } + void CloseSocket() { + if (_closed.exchange(true)) + return; + boost::system::error_code shutdownError; _socket.shutdown(boost::asio::socket_base::shutdown_both, shutdownError); if (shutdownError) TC_LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(), - shutdownError.value(), shutdownError.message().c_str()); + shutdownError.value(), shutdownError.message().c_str()); boost::system::error_code error; _socket.close(error); @@ -103,18 +125,72 @@ public: error.value(), error.message().c_str()); } - uint8* GetReadBuffer() { return _readBuffer; } + virtual bool IsHeaderReady() const { return _readHeaderBuffer.IsMessageReady(); } + virtual bool IsDataReady() const { return _readDataBuffer.IsMessageReady(); } + + uint8* GetHeaderBuffer() { return _readHeaderBuffer.Data(); } + uint8* GetDataBuffer() { return _readDataBuffer.Data(); } + + MessageBuffer&& MoveHeader() { return std::move(_readHeaderBuffer); } + MessageBuffer&& MoveData() { return std::move(_readDataBuffer); } protected: - virtual void ReadHeaderHandler(boost::system::error_code error, size_t transferedBytes) = 0; - virtual void ReadDataHandler(boost::system::error_code error, size_t transferedBytes) = 0; + virtual void ReadHeaderHandler() = 0; + virtual void ReadDataHandler() = 0; std::mutex _writeLock; std::queue<PacketType> _writeQueue; private: - void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadHeaderHandler(error, transferedBytes); } - void ReadDataHandlerInternal(boost::system::error_code error, size_t transferedBytes) { ReadDataHandler(error, transferedBytes); } + void AsyncReadMissingHeaderData() + { + _socket.async_read_some(boost::asio::buffer(_readHeaderBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readHeaderBuffer.GetMissingSize())), + std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + } + + void AsyncReadMissingData() + { + _socket.async_read_some(boost::asio::buffer(_readDataBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readDataBuffer.GetMissingSize())), + std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + } + + void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferredBytes) + { + if (error) + { + CloseSocket(); + return; + } + + _readHeaderBuffer.WriteCompleted(transferredBytes); + if (!IsHeaderReady()) + { + // incomplete, read more + AsyncReadMissingHeaderData(); + return; + } + + ReadHeaderHandler(); + } + + void ReadDataHandlerInternal(boost::system::error_code error, size_t transferredBytes) + { + if (error) + { + CloseSocket(); + return; + } + + _readDataBuffer.WriteCompleted(transferredBytes); + if (!IsDataReady()) + { + // incomplete, read more + AsyncReadMissingData(); + return; + } + + ReadDataHandler(); + } void WriteHandler(boost::system::error_code error, size_t /*transferedBytes*/) { @@ -140,12 +216,13 @@ private: tcp::socket _socket; - uint8 _readBuffer[4096]; - - uint16 _remotePort; boost::asio::ip::address _remoteAddress; + uint16 _remotePort; + + MessageBuffer _readHeaderBuffer; + MessageBuffer _readDataBuffer; - std::size_t _headerSize; + std::atomic<bool> _closed; }; #endif // __SOCKET_H__ diff --git a/src/server/shared/Packets/ByteBuffer.cpp b/src/server/shared/Packets/ByteBuffer.cpp index 86234039a4a..3785d1c29fa 100644 --- a/src/server/shared/Packets/ByteBuffer.cpp +++ b/src/server/shared/Packets/ByteBuffer.cpp @@ -17,11 +17,16 @@ */ #include "ByteBuffer.h" +#include "MessageBuffer.h" #include "Common.h" #include "Log.h" #include <sstream> +ByteBuffer::ByteBuffer(MessageBuffer&& buffer) : _rpos(0), _wpos(0), _storage(buffer.Move()) +{ +} + ByteBufferPositionException::ByteBufferPositionException(bool add, size_t pos, size_t size, size_t valueSize) { diff --git a/src/server/shared/Packets/ByteBuffer.h b/src/server/shared/Packets/ByteBuffer.h index c678e9dce06..456223d744d 100644 --- a/src/server/shared/Packets/ByteBuffer.h +++ b/src/server/shared/Packets/ByteBuffer.h @@ -34,6 +34,8 @@ #include <math.h> #include <boost/asio/buffer.hpp> +class MessageBuffer; + // Root of ByteBuffer exception hierarchy class ByteBufferException : public std::exception { @@ -82,14 +84,12 @@ class ByteBuffer } ByteBuffer(ByteBuffer&& buf) : _rpos(buf._rpos), _wpos(buf._wpos), - _storage(std::move(buf._storage)) - { - } + _storage(std::move(buf._storage)) { } ByteBuffer(ByteBuffer const& right) : _rpos(right._rpos), _wpos(right._wpos), - _storage(right._storage) - { - } + _storage(right._storage) { } + + ByteBuffer(MessageBuffer&& buffer); ByteBuffer& operator=(ByteBuffer const& right) { diff --git a/src/server/shared/Packets/WorldPacket.h b/src/server/shared/Packets/WorldPacket.h index 8851b9f3e45..848a00739fe 100644 --- a/src/server/shared/Packets/WorldPacket.h +++ b/src/server/shared/Packets/WorldPacket.h @@ -51,6 +51,8 @@ class WorldPacket : public ByteBuffer return *this; } + WorldPacket(uint16 opcode, MessageBuffer&& buffer) : ByteBuffer(std::move(buffer)), m_opcode(opcode) { } + void Initialize(uint16 opcode, size_t newres=200) { clear(); diff --git a/src/server/worldserver/Main.cpp b/src/server/worldserver/Main.cpp index 3afa9e84e8b..2c393215f7d 100644 --- a/src/server/worldserver/Main.cpp +++ b/src/server/worldserver/Main.cpp @@ -87,6 +87,7 @@ bool StartDB(); void StopDB(); void WorldUpdateLoop(); void ClearOnlineAccounts(); +void ShutdownThreadPool(std::vector<std::thread>& threadPool); variables_map GetConsoleArguments(int argc, char** argv, std::string& cfg_file, std::string& cfg_service); /// Launch the Trinity server @@ -179,7 +180,10 @@ extern int main(int argc, char** argv) // Start the databases if (!StartDB()) + { + ShutdownThreadPool(threadPool); return 1; + } // Set server offline (not connectable) LoginDatabase.DirectPExecute("UPDATE realmlist SET flag = (flag & ~%u) | %u WHERE id = '%d'", REALM_FLAG_OFFLINE, REALM_FLAG_INVALID, realmID); @@ -236,13 +240,7 @@ extern int main(int argc, char** argv) WorldUpdateLoop(); // Shutdown starts here - - _ioService.stop(); - - for (auto& thread : threadPool) - { - thread.join(); - } + ShutdownThreadPool(threadPool); sScriptMgr->OnShutdown(); @@ -281,41 +279,7 @@ extern int main(int argc, char** argv) if (cliThread != nullptr) { #ifdef _WIN32 - - // this only way to terminate CLI thread exist at Win32 (alt. way exist only in Windows Vista API) - //_exit(1); - // send keyboard input to safely unblock the CLI thread - INPUT_RECORD b[4]; - HANDLE hStdIn = GetStdHandle(STD_INPUT_HANDLE); - b[0].EventType = KEY_EVENT; - b[0].Event.KeyEvent.bKeyDown = TRUE; - b[0].Event.KeyEvent.uChar.AsciiChar = 'X'; - b[0].Event.KeyEvent.wVirtualKeyCode = 'X'; - b[0].Event.KeyEvent.wRepeatCount = 1; - - b[1].EventType = KEY_EVENT; - b[1].Event.KeyEvent.bKeyDown = FALSE; - b[1].Event.KeyEvent.uChar.AsciiChar = 'X'; - b[1].Event.KeyEvent.wVirtualKeyCode = 'X'; - b[1].Event.KeyEvent.wRepeatCount = 1; - - b[2].EventType = KEY_EVENT; - b[2].Event.KeyEvent.bKeyDown = TRUE; - b[2].Event.KeyEvent.dwControlKeyState = 0; - b[2].Event.KeyEvent.uChar.AsciiChar = '\r'; - b[2].Event.KeyEvent.wVirtualKeyCode = VK_RETURN; - b[2].Event.KeyEvent.wRepeatCount = 1; - b[2].Event.KeyEvent.wVirtualScanCode = 0x1c; - - b[3].EventType = KEY_EVENT; - b[3].Event.KeyEvent.bKeyDown = FALSE; - b[3].Event.KeyEvent.dwControlKeyState = 0; - b[3].Event.KeyEvent.uChar.AsciiChar = '\r'; - b[3].Event.KeyEvent.wVirtualKeyCode = VK_RETURN; - b[3].Event.KeyEvent.wVirtualScanCode = 0x1c; - b[3].Event.KeyEvent.wRepeatCount = 1; - DWORD numb; - WriteConsoleInput(hStdIn, b, 4, &numb); + CancelSynchronousIo(cliThread->native_handle()); #endif cliThread->join(); delete cliThread; @@ -330,6 +294,15 @@ extern int main(int argc, char** argv) return World::GetExitCode(); } +void ShutdownThreadPool(std::vector<std::thread>& threadPool) +{ + _ioService.stop(); + + for (auto& thread : threadPool) + { + thread.join(); + } +} void WorldUpdateLoop() { diff --git a/src/tools/mmaps_generator/PathGenerator.cpp b/src/tools/mmaps_generator/PathGenerator.cpp index 3e2025dace8..c2ca184905e 100644 --- a/src/tools/mmaps_generator/PathGenerator.cpp +++ b/src/tools/mmaps_generator/PathGenerator.cpp @@ -276,7 +276,7 @@ int main(int argc, char** argv) } if (!checkDirectories(debugOutput)) - return silent ? -3 : finish("Press any key to close...", -3); + return silent ? -3 : finish("Press ENTER to close...", -3); MapBuilder builder(maxAngle, skipLiquid, skipContinents, skipJunkMaps, skipBattlegrounds, debugOutput, bigBaseUnit, offMeshInputPath); |