diff options
author | Shauren <shauren.trinity@gmail.com> | 2014-09-09 19:19:25 +0200 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2014-09-09 19:19:25 +0200 |
commit | e0ce4528c5ffd43f651f88821723311541e9e461 (patch) | |
tree | c5f3906d17114120b2ea2b382cea39db07137d18 | |
parent | a2ba49afa428ed9297f98bf8a5e00f6f7a6f4c3a (diff) |
Core/NetworkIO: Use reactor style sending on linux to reduce locking overhead
-rw-r--r-- | src/server/authserver/Main.cpp | 3 | ||||
-rw-r--r-- | src/server/authserver/Server/AuthSession.cpp | 121 | ||||
-rw-r--r-- | src/server/authserver/Server/AuthSession.h | 14 | ||||
-rw-r--r-- | src/server/game/Server/WorldSession.cpp | 2 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.cpp | 122 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.h | 67 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocketMgr.cpp | 115 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocketMgr.h | 66 | ||||
-rw-r--r-- | src/server/shared/Networking/AsyncAcceptor.h | 52 | ||||
-rw-r--r-- | src/server/shared/Networking/MessageBuffer.h | 67 | ||||
-rw-r--r-- | src/server/shared/Networking/NetworkThread.h | 166 | ||||
-rw-r--r-- | src/server/shared/Networking/Socket.h | 255 | ||||
-rw-r--r-- | src/server/shared/Networking/SocketMgr.h | 111 | ||||
-rw-r--r-- | src/server/worldserver/Main.cpp | 18 |
14 files changed, 881 insertions, 298 deletions
diff --git a/src/server/authserver/Main.cpp b/src/server/authserver/Main.cpp index f26c0342654..01dbaa6aa9a 100644 --- a/src/server/authserver/Main.cpp +++ b/src/server/authserver/Main.cpp @@ -118,7 +118,8 @@ int main(int argc, char** argv) } std::string bindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0"); - AsyncAcceptor<AuthSession> authServer(_ioService, bindIp, port); + AsyncAcceptor authServer(_ioService, bindIp, port); + authServer.AsyncAccept<AuthSession>(); // Set signal handlers boost::asio::signal_set signals(_ioService, SIGINT, SIGTERM); diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp index 8276183fb10..e8241ac49b1 100644 --- a/src/server/authserver/Server/AuthSession.cpp +++ b/src/server/authserver/Server/AuthSession.cpp @@ -53,6 +53,7 @@ enum eStatus typedef struct AUTH_LOGON_CHALLENGE_C { + uint8 cmd; uint8 error; uint16 size; uint8 gamename[4]; @@ -71,6 +72,7 @@ typedef struct AUTH_LOGON_CHALLENGE_C typedef struct AUTH_LOGON_PROOF_C { + uint8 cmd; uint8 A[32]; uint8 M1[20]; uint8 crc_hash[20]; @@ -98,6 +100,7 @@ typedef struct AUTH_LOGON_PROOF_S_OLD typedef struct AUTH_RECONNECT_PROOF_C { + uint8 cmd; uint8 R1[16]; uint8 R2[20]; uint8 R3[20]; @@ -112,79 +115,88 @@ enum class BufferSizes : uint32 SRP_6_S = 0x20, }; -#define REALM_LIST_PACKET_SIZE 4 -#define XFER_ACCEPT_SIZE 0 -#define XFER_RESUME_SIZE 8 -#define XFER_CANCEL_SIZE 0 +#define AUTH_LOGON_CHALLENGE_INITIAL_SIZE 4 +#define REALM_LIST_PACKET_SIZE 5 +#define XFER_ACCEPT_SIZE 1 +#define XFER_RESUME_SIZE 9 +#define XFER_CANCEL_SIZE 1 std::unordered_map<uint8, AuthHandler> AuthSession::InitHandlers() { std::unordered_map<uint8, AuthHandler> handlers; - handlers[AUTH_LOGON_CHALLENGE] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::HandleLogonChallenge }; - handlers[AUTH_LOGON_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_PROOF_C), &AuthSession::HandleLogonProof }; - handlers[AUTH_RECONNECT_CHALLENGE] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_CHALLENGE_C), &AuthSession::HandleReconnectChallenge }; - handlers[AUTH_RECONNECT_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_RECONNECT_PROOF_C), &AuthSession::HandleReconnectProof }; - handlers[REALM_LIST] = { STATUS_AUTHED, REALM_LIST_PACKET_SIZE, &AuthSession::HandleRealmList }; - handlers[XFER_ACCEPT] = { STATUS_AUTHED, XFER_ACCEPT_SIZE, &AuthSession::HandleXferAccept }; - handlers[XFER_RESUME] = { STATUS_AUTHED, XFER_RESUME_SIZE, &AuthSession::HandleXferResume }; - handlers[XFER_CANCEL] = { STATUS_AUTHED, XFER_CANCEL_SIZE, &AuthSession::HandleXferCancel }; + handlers[AUTH_LOGON_CHALLENGE] = { STATUS_CONNECTED, AUTH_LOGON_CHALLENGE_INITIAL_SIZE, &AuthSession::HandleLogonChallenge }; + handlers[AUTH_LOGON_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_LOGON_PROOF_C), &AuthSession::HandleLogonProof }; + handlers[AUTH_RECONNECT_CHALLENGE] = { STATUS_CONNECTED, AUTH_LOGON_CHALLENGE_INITIAL_SIZE, &AuthSession::HandleReconnectChallenge }; + handlers[AUTH_RECONNECT_PROOF] = { STATUS_CONNECTED, sizeof(AUTH_RECONNECT_PROOF_C), &AuthSession::HandleReconnectProof }; + handlers[REALM_LIST] = { STATUS_AUTHED, REALM_LIST_PACKET_SIZE, &AuthSession::HandleRealmList }; + handlers[XFER_ACCEPT] = { STATUS_AUTHED, XFER_ACCEPT_SIZE, &AuthSession::HandleXferAccept }; + handlers[XFER_RESUME] = { STATUS_AUTHED, XFER_RESUME_SIZE, &AuthSession::HandleXferResume }; + handlers[XFER_CANCEL] = { STATUS_AUTHED, XFER_CANCEL_SIZE, &AuthSession::HandleXferCancel }; return handlers; } std::unordered_map<uint8, AuthHandler> const Handlers = AuthSession::InitHandlers(); -void AuthSession::ReadHeaderHandler() +void AuthSession::ReadHandler() { - uint8 cmd = GetHeaderBuffer()[0]; - auto itr = Handlers.find(cmd); - if (itr != Handlers.end()) + MessageBuffer& packet = GetReadBuffer(); + while (packet.GetActiveSize()) { - // Handle dynamic size packet + uint8 cmd = packet.GetReadPointer()[0]; + auto itr = Handlers.find(cmd); + if (itr == Handlers.end()) + { + // well we dont handle this, lets just ignore it + packet.Reset(); + break; + } + + uint16 size = uint16(itr->second.packetSize); + if (packet.GetActiveSize() < size) + break; + if (cmd == AUTH_LOGON_CHALLENGE || cmd == AUTH_RECONNECT_CHALLENGE) { - ReadData(sizeof(uint8) + sizeof(uint16)); //error + size - sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer()); + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(packet.GetReadPointer()); + size += challenge->size; + } - AsyncReadData(challenge->size); + if (packet.GetActiveSize() < size) + break; + + if (!(*this.*Handlers.at(cmd).handler)()) + { + CloseSocket(); + return; } - else - AsyncReadData(itr->second.packetSize); - } - else - CloseSocket(); -} -void AuthSession::ReadDataHandler() -{ - if (!(*this.*Handlers.at(GetHeaderBuffer()[0]).handler)()) - { - CloseSocket(); - return; + packet.ReadCompleted(size); } - AsyncReadHeader(); + AsyncRead(); } -void AuthSession::AsyncWrite(ByteBuffer& packet) +void AuthSession::SendPacket(ByteBuffer& packet) { if (!IsOpen()) return; - std::lock_guard<std::mutex> guard(_writeLock); - - bool needsWriteStart = _writeQueue.empty(); + if (!packet.empty()) + { + MessageBuffer buffer; + buffer.Write(packet.contents(), packet.size()); - _writeQueue.push(std::move(packet)); + std::unique_lock<std::mutex> guard(_writeLock); - if (needsWriteStart) - Base::AsyncWrite(_writeQueue.front()); + QueuePacket(std::move(buffer), guard); + } } bool AuthSession::HandleLogonChallenge() { - sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer()); + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer().GetReadPointer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -393,7 +405,7 @@ bool AuthSession::HandleLogonChallenge() pkt << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); } - AsyncWrite(pkt); + SendPacket(pkt); return true; } @@ -403,7 +415,7 @@ bool AuthSession::HandleLogonProof() TC_LOG_DEBUG("server.authserver", "Entering _HandleLogonProof"); // Read the packet - sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetDataBuffer()); + sAuthLogonProof_C *logonProof = reinterpret_cast<sAuthLogonProof_C*>(GetReadBuffer().GetReadPointer()); // If the client has no valid version if (_expversion == NO_VALID_EXP_FLAG) @@ -515,10 +527,9 @@ bool AuthSession::HandleLogonProof() // Check auth token if ((logonProof->securityFlags & 0x04) || !_tokenKey.empty()) { - ReadData(1); - uint8 size = *(GetDataBuffer() + sizeof(sAuthLogonProof_C)); - ReadData(size); - std::string token(reinterpret_cast<char*>(GetDataBuffer() + sizeof(sAuthLogonProof_C) + sizeof(size)), size); + uint8 size = *(GetReadBuffer().GetReadPointer() + sizeof(sAuthLogonProof_C)); + std::string token(reinterpret_cast<char*>(GetReadBuffer().GetReadPointer() + sizeof(sAuthLogonProof_C) + sizeof(size)), size); + GetReadBuffer().ReadCompleted(sizeof(size) + size); uint32 validToken = TOTP::GenerateToken(_tokenKey.c_str()); uint32 incomingToken = atoi(token.c_str()); if (validToken != incomingToken) @@ -528,7 +539,7 @@ bool AuthSession::HandleLogonProof() packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); packet << uint8(3); packet << uint8(0); - AsyncWrite(packet); + SendPacket(packet); return false; } } @@ -559,7 +570,7 @@ bool AuthSession::HandleLogonProof() std::memcpy(packet.contents(), &proof, sizeof(proof)); } - AsyncWrite(packet); + SendPacket(packet); _isAuthenticated = true; } else @@ -569,7 +580,7 @@ bool AuthSession::HandleLogonProof() packet << uint8(WOW_FAIL_UNKNOWN_ACCOUNT); packet << uint8(3); packet << uint8(0); - AsyncWrite(packet); + SendPacket(packet); TC_LOG_DEBUG("server.authserver", "'%s:%d' [AuthChallenge] account %s tried to login with invalid password!", GetRemoteIpAddress().to_string().c_str(), GetRemotePort(), _login.c_str()); @@ -638,7 +649,7 @@ bool AuthSession::HandleLogonProof() bool AuthSession::HandleReconnectChallenge() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectChallenge"); - sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetDataBuffer()); + sAuthLogonChallenge_C* challenge = reinterpret_cast<sAuthLogonChallenge_C*>(GetReadBuffer().GetReadPointer()); //TC_LOG_DEBUG("server.authserver", "[AuthChallenge] got full packet, %#04x bytes", challenge->size); TC_LOG_DEBUG("server.authserver", "[AuthChallenge] name(%d): '%s'", challenge->I_len, challenge->I); @@ -682,14 +693,14 @@ bool AuthSession::HandleReconnectChallenge() pkt.append(_reconnectProof.AsByteArray(16).get(), 16); // 16 bytes random pkt << uint64(0x00) << uint64(0x00); // 16 bytes zeros - AsyncWrite(pkt); + SendPacket(pkt); return true; } bool AuthSession::HandleReconnectProof() { TC_LOG_DEBUG("server.authserver", "Entering _HandleReconnectProof"); - sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetDataBuffer()); + sAuthReconnectProof_C *reconnectProof = reinterpret_cast<sAuthReconnectProof_C*>(GetReadBuffer().GetReadPointer()); if (_login.empty() || !_reconnectProof.GetNumBytes() || !K.GetNumBytes()) return false; @@ -710,7 +721,7 @@ bool AuthSession::HandleReconnectProof() pkt << uint8(AUTH_RECONNECT_PROOF); pkt << uint8(0x00); pkt << uint16(0x00); // 2 bytes zeros - AsyncWrite(pkt); + SendPacket(pkt); _isAuthenticated = true; return true; } @@ -870,7 +881,7 @@ bool AuthSession::HandleRealmList() hdr << uint16(pkt.size() + RealmListSizeBuffer.size()); hdr.append(RealmListSizeBuffer); // append RealmList's size buffer hdr.append(pkt); // append realms in the realmlist - AsyncWrite(hdr); + SendPacket(hdr); return true; } diff --git a/src/server/authserver/Server/AuthSession.h b/src/server/authserver/Server/AuthSession.h index 3497e3a030c..07af61d9c1d 100644 --- a/src/server/authserver/Server/AuthSession.h +++ b/src/server/authserver/Server/AuthSession.h @@ -30,14 +30,12 @@ using boost::asio::ip::tcp; struct AuthHandler; -class AuthSession : public Socket<AuthSession, ByteBuffer> +class AuthSession : public Socket<AuthSession> { - typedef Socket<AuthSession, ByteBuffer> Base; - public: static std::unordered_map<uint8, AuthHandler> InitHandlers(); - AuthSession(tcp::socket&& socket) : Socket(std::move(socket), 1), + AuthSession(tcp::socket&& socket) : Socket(std::move(socket)), _isAuthenticated(false), _build(0), _expversion(0), _accountSecurityLevel(SEC_PLAYER) { N.SetHexStr("894B645E89E1535BBDAD5B8B290650530801B18EBFBF5E8FAB3C82872A3E9BB7"); @@ -46,15 +44,13 @@ public: void Start() override { - AsyncReadHeader(); + AsyncRead(); } - using Base::AsyncWrite; - void AsyncWrite(ByteBuffer& packet); + void SendPacket(ByteBuffer& packet); protected: - void ReadHeaderHandler() override; - void ReadDataHandler() override; + void ReadHandler() override; private: bool HandleLogonChallenge(); diff --git a/src/server/game/Server/WorldSession.cpp b/src/server/game/Server/WorldSession.cpp index 445e42a7f08..321bc707879 100644 --- a/src/server/game/Server/WorldSession.cpp +++ b/src/server/game/Server/WorldSession.cpp @@ -227,7 +227,7 @@ void WorldSession::SendPacket(WorldPacket* packet) sScriptMgr->OnPacketSend(this, *packet); - m_Socket->AsyncWrite(*packet); + m_Socket->SendPacket(*packet); } /// Add an incoming packet to the queue diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index ca8e2cd5a34..f8673e5d5b7 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -28,14 +28,14 @@ using boost::asio::ip::tcp; WorldSocket::WorldSocket(tcp::socket&& socket) - : Socket(std::move(socket), sizeof(ClientPktHeader)), _authSeed(rand32()), _OverSpeedPings(0), _worldSession(nullptr) + : Socket(std::move(socket)), _authSeed(rand32()), _OverSpeedPings(0), _worldSession(nullptr) { + _headerBuffer.Resize(sizeof(ClientPktHeader)); } void WorldSocket::Start() { - sScriptMgr->OnSocketOpen(shared_from_this()); - AsyncReadHeader(); + AsyncRead(); HandleSendAuthSession(); } @@ -53,14 +53,69 @@ void WorldSocket::HandleSendAuthSession() seed2.SetRand(16 * 8); packet.append(seed2.AsByteArray(16).get(), 16); // new encryption seeds - AsyncWrite(packet); + SendPacket(packet); } -void WorldSocket::ReadHeaderHandler() +void WorldSocket::ReadHandler() { - _authCrypt.DecryptRecv(GetHeaderBuffer(), sizeof(ClientPktHeader)); + if (!IsOpen()) + return; + + MessageBuffer& packet = GetReadBuffer(); + while (packet.GetActiveSize() > 0) + { + if (_headerBuffer.GetRemainingSpace() > 0) + { + // need to receive the header + std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _headerBuffer.GetRemainingSpace()); + _headerBuffer.Write(packet.GetReadPointer(), readHeaderSize); + packet.ReadCompleted(readHeaderSize); + + if (_headerBuffer.GetRemainingSpace() > 0) + { + // Couldn't receive the whole header this time. + ASSERT(packet.GetActiveSize() == 0); + break; + } + + // We just received nice new header + if (!ReadHeaderHandler()) + return; + } + + // We have full read header, now check the data payload + if (_packetBuffer.GetRemainingSpace() > 0) + { + // need more data in the payload + std::size_t readDataSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace()); + _packetBuffer.Write(packet.GetReadPointer(), readDataSize); + packet.ReadCompleted(readDataSize); + + if (_packetBuffer.GetRemainingSpace() > 0) + { + // Couldn't receive the whole data this time. + ASSERT(packet.GetActiveSize() == 0); + break; + } + } - ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetHeaderBuffer()); + // just received fresh new payload + if (!ReadDataHandler()) + return; + + _headerBuffer.Reset(); + } + + AsyncRead(); +} + +bool WorldSocket::ReadHeaderHandler() +{ + ASSERT(_headerBuffer.GetActiveSize() == sizeof(ClientPktHeader)); + + _authCrypt.DecryptRecv(_headerBuffer.GetReadPointer(), sizeof(ClientPktHeader)); + + ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer()); EndianConvertReverse(header->size); EndianConvert(header->cmd); @@ -74,24 +129,26 @@ void WorldSocket::ReadHeaderHandler() } else TC_LOG_ERROR("network", "WorldSocket::ReadHeaderHandler(): client %s sent malformed packet (size: %hu, cmd: %u)", - GetRemoteIpAddress().to_string().c_str(), header->size, header->cmd); + GetRemoteIpAddress().to_string().c_str(), header->size, header->cmd); CloseSocket(); - return; + return false; } - AsyncReadData(header->size - sizeof(header->cmd)); + header->size -= sizeof(header->cmd); + _packetBuffer.Resize(header->size); + return true; } -void WorldSocket::ReadDataHandler() +bool WorldSocket::ReadDataHandler() { - ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(GetHeaderBuffer()); + ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer()); uint16 opcode = uint16(header->cmd); std::string opcodeName = GetOpcodeNameForLogging(opcode); - WorldPacket packet(opcode, MoveData()); + WorldPacket packet(opcode, std::move(_packetBuffer)); if (sPacketLog->CanLogPacket()) sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort()); @@ -122,7 +179,7 @@ void WorldSocket::ReadDataHandler() { TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); CloseSocket(); - return; + return false; } // Our Idle timer will reset on any non PING opcodes. @@ -135,10 +192,10 @@ void WorldSocket::ReadDataHandler() } } - AsyncReadHeader(); + return true; } -void WorldSocket::AsyncWrite(WorldPacket& packet) +void WorldSocket::SendPacket(WorldPacket& packet) { if (!IsOpen()) return; @@ -150,15 +207,27 @@ void WorldSocket::AsyncWrite(WorldPacket& packet) ServerPktHeader header(packet.size() + 2, packet.GetOpcode()); - std::lock_guard<std::mutex> guard(_writeLock); + std::unique_lock<std::mutex> guard(_writeLock); - bool needsWriteStart = _writeQueue.empty(); _authCrypt.EncryptSend(header.header, header.getHeaderLength()); - _writeQueue.emplace(header, packet); +#ifndef BOOST_ASIO_HAS_IOCP + if (_writeQueue.empty() && _writeBuffer.GetRemainingSpace() >= header.getHeaderLength() + packet.size()) + { + _writeBuffer.Write(header.header, header.getHeaderLength()); + if (!packet.empty()) + _writeBuffer.Write(packet.contents(), packet.size()); + } + else +#endif + { + MessageBuffer buffer(header.getHeaderLength() + packet.size()); + buffer.Write(header.header, header.getHeaderLength()); + if (!packet.empty()) + buffer.Write(packet.contents(), packet.size()); - if (needsWriteStart) - AsyncWrite(_writeQueue.front()); + QueuePacket(std::move(buffer), guard); + } } void WorldSocket::HandleAuthSession(WorldPacket& recvPacket) @@ -410,7 +479,7 @@ void WorldSocket::SendAuthResponseError(uint8 code) WorldPacket packet(SMSG_AUTH_RESPONSE, 1); packet << uint8(code); - AsyncWrite(packet); + SendPacket(packet); } void WorldSocket::HandlePing(WorldPacket& recvPacket) @@ -471,12 +540,5 @@ void WorldSocket::HandlePing(WorldPacket& recvPacket) WorldPacket packet(SMSG_PONG, 4); packet << ping; - return AsyncWrite(packet); -} - -void WorldSocket::CloseSocket() -{ - sScriptMgr->OnSocketClose(shared_from_this()); - - Socket::CloseSocket(); + return SendPacket(packet); } diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index 0667c1b090a..d301e239340 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -19,16 +19,6 @@ #ifndef __WORLDSOCKET_H__ #define __WORLDSOCKET_H__ -// Forward declare buffer function here - Socket.h must know about it -struct WorldPacketBuffer; -namespace boost -{ - namespace asio - { - WorldPacketBuffer const& buffer(WorldPacketBuffer const& buf); - } -} - #include "Common.h" #include "AuthCrypt.h" #include "ServerPktHeader.h" @@ -54,50 +44,8 @@ struct ClientPktHeader #pragma pack(pop) -struct WorldPacketBuffer +class WorldSocket : public Socket<WorldSocket> { - typedef boost::asio::const_buffer value_type; - - typedef boost::asio::const_buffer const* const_iterator; - - WorldPacketBuffer(ServerPktHeader header, WorldPacket const& packet) : _header(header), _packet(packet) - { - _buffers[0] = boost::asio::const_buffer(_header.header, _header.getHeaderLength()); - if (!_packet.empty()) - _buffers[1] = boost::asio::const_buffer(_packet.contents(), _packet.size()); - } - - const_iterator begin() const - { - return _buffers; - } - - const_iterator end() const - { - return _buffers + (_packet.empty() ? 1 : 2); - } - -private: - boost::asio::const_buffer _buffers[2]; - ServerPktHeader _header; - WorldPacket _packet; -}; - -namespace boost -{ - namespace asio - { - inline WorldPacketBuffer const& buffer(WorldPacketBuffer const& buf) - { - return buf; - } - } -} - -class WorldSocket : public Socket<WorldSocket, WorldPacketBuffer> -{ - typedef Socket<WorldSocket, WorldPacketBuffer> Base; - public: WorldSocket(tcp::socket&& socket); @@ -106,14 +54,12 @@ public: void Start() override; - void CloseSocket() override; - - using Base::AsyncWrite; - void AsyncWrite(WorldPacket& packet); + void SendPacket(WorldPacket& packet); protected: - void ReadHeaderHandler() override; - void ReadDataHandler() override; + void ReadHandler() override; + bool ReadHeaderHandler(); + bool ReadDataHandler(); private: void HandleSendAuthSession(); @@ -129,6 +75,9 @@ private: uint32 _OverSpeedPings; WorldSession* _worldSession; + + MessageBuffer _headerBuffer; + MessageBuffer _packetBuffer; }; #endif diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp new file mode 100644 index 00000000000..21f62fa265c --- /dev/null +++ b/src/server/game/Server/WorldSocketMgr.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/> + * Copyright (C) 2005-2008 MaNGOS <http://getmangos.com/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "Config.h" +#include "NetworkThread.h" +#include "ScriptMgr.h" +#include "WorldSocket.h" +#include "WorldSocketMgr.h" +#include <boost/system/error_code.hpp> + +static void OnSocketAccept(tcp::socket&& sock) +{ + sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock)); +} + +class WorldSocketThread : public NetworkThread<WorldSocket> +{ +public: + void SocketAdded(std::shared_ptr<WorldSocket> sock) override + { + sScriptMgr->OnSocketOpen(sock); + } + + void SocketRemoved(std::shared_ptr<WorldSocket> sock) override + { + sScriptMgr->OnSocketClose(sock); + } +}; + +WorldSocketMgr::WorldSocketMgr() : BaseSocketMgr(), _socketSendBufferSize(-1), m_SockOutUBuff(65536), _tcpNoDelay(true) +{ +} + +bool WorldSocketMgr::StartNetwork(boost::asio::io_service& service, std::string const& bindIp, uint16 port) +{ + _tcpNoDelay = sConfigMgr->GetBoolDefault("Network.TcpNodelay", true); + + TC_LOG_DEBUG("misc", "Max allowed socket connections %d", boost::asio::socket_base::max_connections); + + // -1 means use default + _socketSendBufferSize = sConfigMgr->GetIntDefault("Network.OutKBuff", -1); + + m_SockOutUBuff = sConfigMgr->GetIntDefault("Network.OutUBuff", 65536); + + if (m_SockOutUBuff <= 0) + { + TC_LOG_ERROR("misc", "Network.OutUBuff is wrong in your config file"); + return false; + } + + BaseSocketMgr::StartNetwork(service, bindIp, port); + + _acceptor->AsyncAcceptManaged(&OnSocketAccept); + + sScriptMgr->OnNetworkStart(); + return true; +} + +void WorldSocketMgr::StopNetwork() +{ + BaseSocketMgr::StopNetwork(); + + sScriptMgr->OnNetworkStop(); +} + +void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock) +{ + // set some options here + if (_socketSendBufferSize >= 0) + { + boost::system::error_code err; + sock.set_option(boost::asio::socket_base::send_buffer_size(_socketSendBufferSize), err); + if (err && err != boost::system::errc::not_supported) + { + TC_LOG_ERROR("misc", "WorldSocketMgr::OnSocketOpen sock.set_option(boost::asio::socket_base::send_buffer_size) err = %s", err.message().c_str()); + return; + } + } + + // Set TCP_NODELAY. + if (_tcpNoDelay) + { + boost::system::error_code err; + sock.set_option(boost::asio::ip::tcp::no_delay(true), err); + if (err) + { + TC_LOG_ERROR("misc", "WorldSocketMgr::OnSocketOpen sock.set_option(boost::asio::ip::tcp::no_delay) err = %s", err.message().c_str()); + return; + } + } + + //sock->m_OutBufferSize = static_cast<size_t> (m_SockOutUBuff); + + BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock)); +} + +NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const +{ + return new WorldSocketThread[GetNetworkThreadCount()]; +} diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h new file mode 100644 index 00000000000..92a28d0c135 --- /dev/null +++ b/src/server/game/Server/WorldSocketMgr.h @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2008-2013 TrinityCore <http://www.trinitycore.org/> + * Copyright (C) 2005-2009 MaNGOS <http://getmangos.com/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +/** \addtogroup u2w User to World Communication + * @{ + * \file WorldSocketMgr.h + * \author Derex <derex101@gmail.com> + */ + +#ifndef __WORLDSOCKETMGR_H +#define __WORLDSOCKETMGR_H + +#include "SocketMgr.h" + +class WorldSocket; + +/// Manages all sockets connected to peers and network threads +class WorldSocketMgr : public SocketMgr<WorldSocket> +{ + typedef SocketMgr<WorldSocket> BaseSocketMgr; + +public: + static WorldSocketMgr& Instance() + { + static WorldSocketMgr instance; + return instance; + } + + /// Start network, listen at address:port . + bool StartNetwork(boost::asio::io_service& service, std::string const& bindIp, uint16 port) override; + + /// Stops all network threads, It will wait for all running threads . + void StopNetwork() override; + + void OnSocketOpen(tcp::socket&& sock) override; + +protected: + WorldSocketMgr(); + + NetworkThread<WorldSocket>* CreateThreads() const override; + +private: + int32 _socketSendBufferSize; + int32 m_SockOutUBuff; + bool _tcpNoDelay; +}; + +#define sWorldSocketMgr WorldSocketMgr::Instance() + +#endif +/// @} diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h index 64665c2b198..a8b688e6b26 100644 --- a/src/server/shared/Networking/AsyncAcceptor.h +++ b/src/server/shared/Networking/AsyncAcceptor.h @@ -23,37 +23,32 @@ using boost::asio::ip::tcp; -template <class T> class AsyncAcceptor { public: - AsyncAcceptor(boost::asio::io_service& ioService, std::string bindIp, int port) : - _acceptor(ioService, tcp::endpoint(boost::asio::ip::address::from_string(bindIp), port)), - _socket(ioService) - { - AsyncAccept(); - }; + typedef void(*ManagerAcceptHandler)(tcp::socket&& newSocket); - AsyncAcceptor(boost::asio::io_service& ioService, std::string bindIp, int port, bool tcpNoDelay) : + AsyncAcceptor(boost::asio::io_service& ioService, std::string const& bindIp, uint16 port) : _acceptor(ioService, tcp::endpoint(boost::asio::ip::address::from_string(bindIp), port)), _socket(ioService) { - _acceptor.set_option(boost::asio::ip::tcp::no_delay(tcpNoDelay)); + boost::system::error_code error; + _acceptor.non_blocking(true, error); + } - AsyncAccept(); - }; + template <class T> + void AsyncAccept(); -private: - void AsyncAccept() + void AsyncAcceptManaged(ManagerAcceptHandler mgrHandler) { - _acceptor.async_accept(_socket, [this](boost::system::error_code error) + _acceptor.async_accept(_socket, [this, mgrHandler](boost::system::error_code error) { if (!error) { try { // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class - std::make_shared<T>(std::move(this->_socket))->Start(); + mgrHandler(std::move(this->_socket)); } catch (boost::system::system_error const& err) { @@ -61,13 +56,36 @@ private: } } - // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face - this->AsyncAccept(); + AsyncAcceptManaged(mgrHandler); }); } +private: tcp::acceptor _acceptor; tcp::socket _socket; }; +template<class T> +void AsyncAcceptor::AsyncAccept() +{ + _acceptor.async_accept(_socket, [this](boost::system::error_code error) + { + if (!error) + { + try + { + // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class + std::make_shared<T>(std::move(this->_socket))->Start(); + } + catch (boost::system::system_error const& err) + { + TC_LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what()); + } + } + + // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face + this->AsyncAccept<T>(); + }); +} + #endif /* __ASYNCACCEPT_H_ */ diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h index c7f8ba31a71..2115bea3f47 100644 --- a/src/server/shared/Networking/MessageBuffer.h +++ b/src/server/shared/Networking/MessageBuffer.h @@ -26,42 +26,74 @@ class MessageBuffer typedef std::vector<uint8>::size_type size_type; public: - MessageBuffer() : _wpos(0), _storage() { } + MessageBuffer() : _wpos(0), _rpos(0), _storage() + { + _storage.resize(4096); + } - MessageBuffer(MessageBuffer const& right) : _wpos(right._wpos), _storage(right._storage) { } + explicit MessageBuffer(std::size_t initialSize) : _wpos(0), _rpos(0), _storage() + { + _storage.resize(initialSize); + } + + MessageBuffer(MessageBuffer const& right) : _wpos(right._wpos), _rpos(right._rpos), _storage(right._storage) + { + } - MessageBuffer(MessageBuffer&& right) : _wpos(right._wpos), _storage(right.Move()) { } + MessageBuffer(MessageBuffer&& right) : _wpos(right._wpos), _rpos(right._rpos), _storage(right.Move()) { } void Reset() { - _storage.clear(); _wpos = 0; + _rpos = 0; } - bool IsMessageReady() const { return _wpos == _storage.size(); } + void Resize(size_type bytes) + { + _storage.resize(bytes); + } - size_type GetSize() const { return _storage.size(); } + uint8* GetBasePointer() { return _storage.data(); } - size_type GetReadyDataSize() const { return _wpos; } + uint8* GetReadPointer() { return &_storage[_rpos]; } - size_type GetMissingSize() const { return _storage.size() - _wpos; } + uint8* GetWritePointer() { return &_storage[_wpos]; } - uint8* Data() { return _storage.data(); } + void ReadCompleted(size_type bytes) { _rpos += bytes; } - void Grow(size_type bytes) - { - _storage.resize(_storage.size() + bytes); - } + void WriteCompleted(size_type bytes) { _wpos += bytes; } - uint8* GetWritePointer() { return &_storage[_wpos]; } + size_type GetActiveSize() const { return _wpos - _rpos; } - void WriteCompleted(size_type bytes) { _wpos += bytes; } + size_type GetRemainingSpace() const { return _storage.size() - _wpos; } + + size_type GetBufferSize() const { return _storage.size(); } + + // Discards inactive data + void Normalize() + { + if (_rpos) + { + if (_rpos != _wpos) + memmove(GetBasePointer(), GetReadPointer(), GetActiveSize()); + _wpos -= _rpos; + _rpos = 0; + } + } - void ResetWritePointer() { _wpos = 0; } + void Write(void* data, std::size_t size) + { + if (size) + { + memcpy(GetWritePointer(), data, size); + WriteCompleted(size); + } + } std::vector<uint8>&& Move() { _wpos = 0; + _rpos = 0; return std::move(_storage); } @@ -70,6 +102,7 @@ public: if (this != &right) { _wpos = right._wpos; + _rpos = right._rpos; _storage = right._storage; } @@ -81,6 +114,7 @@ public: if (this != &right) { _wpos = right._wpos; + _rpos = right._rpos; _storage = right.Move(); } @@ -89,6 +123,7 @@ public: private: size_type _wpos; + size_type _rpos; std::vector<uint8> _storage; }; diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h new file mode 100644 index 00000000000..701d0d97f36 --- /dev/null +++ b/src/server/shared/Networking/NetworkThread.h @@ -0,0 +1,166 @@ +/* + * Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef NetworkThread_h__ +#define NetworkThread_h__ + +#include "Define.h" +#include "Errors.h" +#include "Log.h" +#include "Timer.h" +#include <atomic> +#include <chrono> +#include <memory> +#include <mutex> +#include <set> +#include <thread> + +template<class SocketType> +class NetworkThread +{ +public: + NetworkThread() : _connections(0), _stopped(false), _thread(nullptr) + { + } + + virtual ~NetworkThread() + { + Stop(); + if (_thread) + { + Wait(); + delete _thread; + } + } + + void Stop() + { + _stopped = true; + } + + bool Start() + { + if (_thread) + return false; + + _thread = new std::thread(&NetworkThread::Run, this); + return true; + } + + void Wait() + { + ASSERT(_thread); + + _thread->join(); + delete _thread; + _thread = nullptr; + } + + int32 GetConnectionCount() const + { + return _connections; + } + + virtual void AddSocket(std::shared_ptr<SocketType> sock) + { + std::lock_guard<std::mutex> lock(_newSocketsLock); + + ++_connections; + _newSockets.insert(sock); + SocketAdded(sock); + } + +protected: + virtual void SocketAdded(std::shared_ptr<SocketType> sock) { } + virtual void SocketRemoved(std::shared_ptr<SocketType> sock) { } + + void AddNewSockets() + { + std::lock_guard<std::mutex> lock(_newSocketsLock); + + if (_newSockets.empty()) + return; + + for (SocketSet::const_iterator i = _newSockets.begin(); i != _newSockets.end(); ++i) + { + if (!(*i)->IsOpen()) + { + SocketRemoved(*i); + + --_connections; + } + else + _Sockets.insert(*i); + } + + _newSockets.clear(); + } + + void Run() + { + TC_LOG_DEBUG("misc", "Network Thread Starting"); + + SocketSet::iterator i, t; + + uint32 sleepTime = 10; + uint32 tickStart = 0, diff = 0; + while (!_stopped) + { + std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); + + tickStart = getMSTime(); + + AddNewSockets(); + + for (i = _Sockets.begin(); i != _Sockets.end();) + { + if (!(*i)->Update()) + { + if ((*i)->IsOpen()) + (*i)->CloseSocket(); + + SocketRemoved(*i); + + --_connections; + _Sockets.erase(i++); + } + else + ++i; + } + + diff = GetMSTimeDiffToNow(tickStart); + sleepTime = diff > 10 ? 0 : 10 - diff; + } + + TC_LOG_DEBUG("misc", "Network Thread exits"); + } + +private: + typedef std::set<std::shared_ptr<SocketType> > SocketSet; + + std::atomic<int32> _connections; + std::atomic<bool> _stopped; + + std::thread* _thread; + + SocketSet _Sockets; + + std::mutex _newSocketsLock; + SocketSet _newSockets; +}; + +#endif // NetworkThread_h__ diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index 3bd30bd731b..17f48343485 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -35,32 +35,40 @@ using boost::asio::ip::tcp; #define READ_BLOCK_SIZE 4096 -template<class T, class PacketType> +template<class T> class Socket : public std::enable_shared_from_this<T> { - typedef typename std::conditional<std::is_pointer<PacketType>::value, PacketType, PacketType const&>::type WritePacketType; - public: - Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), - _remotePort(_socket.remote_endpoint().port()), _readHeaderBuffer(), _readDataBuffer(), _closed(false), _closing(false) + explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), + _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false) { - _readHeaderBuffer.Grow(headerSize); + _readBuffer.Resize(READ_BLOCK_SIZE); } virtual ~Socket() { boost::system::error_code error; _socket.close(error); - - while (!_writeQueue.empty()) - { - DeletePacket(_writeQueue.front()); - _writeQueue.pop(); - } } virtual void Start() = 0; + virtual bool Update() + { + if (!IsOpen()) + return false; + +#ifndef BOOST_ASIO_HAS_IOCP + if (_isWritingAsync || (!_writeBuffer.GetActiveSize() && _writeQueue.empty())) + return true; + + for (; WriteHandler(boost::system::error_code(), 0);) + ; +#endif + + return true; + } + boost::asio::ip::address GetRemoteIpAddress() const { return _remoteAddress; @@ -71,31 +79,14 @@ public: return _remotePort; } - void AsyncReadHeader() + void AsyncRead() { if (!IsOpen()) return; - _readHeaderBuffer.ResetWritePointer(); - _readDataBuffer.Reset(); - - AsyncReadMissingHeaderData(); - } - - void AsyncReadData(std::size_t size) - { - if (!IsOpen()) - return; - - if (!size) - { - // if this is a packet with 0 length body just invoke handler directly - ReadDataHandler(); - return; - } - - _readDataBuffer.Grow(size); - AsyncReadMissingData(); + _readBuffer.Normalize(); + _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), READ_BLOCK_SIZE), + std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } void ReadData(std::size_t size) @@ -105,13 +96,11 @@ public: boost::system::error_code error; - _readDataBuffer.Grow(size); - - std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readDataBuffer.GetWritePointer(), size), error); + std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readBuffer.GetWritePointer(), size), error); - _readDataBuffer.WriteCompleted(bytesRead); + _readBuffer.WriteCompleted(bytesRead); - if (error || !_readDataBuffer.IsMessageReady()) + if (error || bytesRead != size) { TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(), error.message().c_str()); @@ -120,15 +109,19 @@ public: } } - void AsyncWrite(WritePacketType data) + void QueuePacket(MessageBuffer&& buffer, std::unique_lock<std::mutex>& guard) { - boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(), - std::placeholders::_1, std::placeholders::_2)); + + _writeQueue.push(std::move(buffer)); + +#ifdef BOOST_ASIO_HAS_IOCP + AsyncProcessQueue(guard); +#endif } bool IsOpen() const { return !_closed && !_closing; } - virtual void CloseSocket() + void CloseSocket() { if (_closed.exchange(true)) return; @@ -143,39 +136,37 @@ public: /// Marks the socket for closing after write buffer becomes empty void DelayedCloseSocket() { _closing = true; } - virtual bool IsHeaderReady() const { return _readHeaderBuffer.IsMessageReady(); } - virtual bool IsDataReady() const { return _readDataBuffer.IsMessageReady(); } - - uint8* GetHeaderBuffer() { return _readHeaderBuffer.Data(); } - uint8* GetDataBuffer() { return _readDataBuffer.Data(); } - - size_t GetHeaderSize() const { return _readHeaderBuffer.GetReadyDataSize(); } - size_t GetDataSize() const { return _readDataBuffer.GetReadyDataSize(); } - - MessageBuffer&& MoveHeader() { return std::move(_readHeaderBuffer); } - MessageBuffer&& MoveData() { return std::move(_readDataBuffer); } + MessageBuffer& GetReadBuffer() { return _readBuffer; } protected: - virtual void ReadHeaderHandler() = 0; - virtual void ReadDataHandler() = 0; - - std::mutex _writeLock; - std::queue<PacketType> _writeQueue; + virtual void ReadHandler() = 0; -private: - void AsyncReadMissingHeaderData() + bool AsyncProcessQueue(std::unique_lock<std::mutex>&) { - _socket.async_read_some(boost::asio::buffer(_readHeaderBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readHeaderBuffer.GetMissingSize())), - std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + if (_isWritingAsync) + return true; + + _isWritingAsync = true; + +#ifdef BOOST_ASIO_HAS_IOCP + MessageBuffer& buffer = _writeQueue.front(); + _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler, + this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); +#else + _socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket<T>::WriteHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); +#endif + + return true; } - void AsyncReadMissingData() - { - _socket.async_read_some(boost::asio::buffer(_readDataBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readDataBuffer.GetMissingSize())), - std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); - } + std::mutex _writeLock; + std::queue<MessageBuffer> _writeQueue; +#ifndef BOOST_ASIO_HAS_IOCP + MessageBuffer _writeBuffer; +#endif - void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferredBytes) +private: + void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes) { if (error) { @@ -183,70 +174,130 @@ private: return; } - _readHeaderBuffer.WriteCompleted(transferredBytes); - if (!IsHeaderReady()) + _readBuffer.WriteCompleted(transferredBytes); + ReadHandler(); + } + +#ifdef BOOST_ASIO_HAS_IOCP + + void WriteHandler(boost::system::error_code error, std::size_t transferedBytes) + { + if (!error) { - // incomplete, read more - AsyncReadMissingHeaderData(); - return; - } + std::unique_lock<std::mutex> deleteGuard(_writeLock); + + _isWritingAsync = false; + _writeQueue.front().ReadCompleted(transferedBytes); + if (!_writeQueue.front().GetActiveSize()) + _writeQueue.pop(); - ReadHeaderHandler(); + if (!_writeQueue.empty()) + AsyncProcessQueue(deleteGuard); + else if (_closing) + CloseSocket(); + } + else + CloseSocket(); } - void ReadDataHandlerInternal(boost::system::error_code error, size_t transferredBytes) +#else + + bool WriteHandler(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/) { + std::unique_lock<std::mutex> guard(_writeLock, std::try_to_lock); + if (!guard) + return false; + + if (!IsOpen()) + return false; + + std::size_t bytesToSend = _writeBuffer.GetActiveSize(); + + if (bytesToSend == 0) + return HandleQueue(guard); + + boost::system::error_code error; + std::size_t bytesWritten = _socket.write_some(boost::asio::buffer(_writeBuffer.GetReadPointer(), bytesToSend), error); + if (error) { - CloseSocket(); - return; - } + if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) + return AsyncProcessQueue(guard); - _readDataBuffer.WriteCompleted(transferredBytes); - if (!IsDataReady()) + return false; + } + else if (bytesWritten == 0) + return false; + else if (bytesWritten < bytesToSend) //now n > 0 { - // incomplete, read more - AsyncReadMissingData(); - return; + _writeBuffer.ReadCompleted(bytesWritten); + _writeBuffer.Normalize(); + return AsyncProcessQueue(guard); } - ReadDataHandler(); + // now bytesWritten == bytesToSend + _writeBuffer.Reset(); + + return HandleQueue(guard); } - void WriteHandler(boost::system::error_code error, size_t /*transferedBytes*/) + bool HandleQueue(std::unique_lock<std::mutex>& guard) { - if (!error) + if (_writeQueue.empty()) + { + _isWritingAsync = false; + return false; + } + + MessageBuffer& queuedMessage = _writeQueue.front(); + + std::size_t bytesToSend = queuedMessage.GetActiveSize(); + + boost::system::error_code error; + std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error); + + if (error) { - std::lock_guard<std::mutex> deleteGuard(_writeLock); + if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) + return AsyncProcessQueue(guard); - DeletePacket(_writeQueue.front()); _writeQueue.pop(); + return false; + } + else if (bytesSent == 0) + { + _writeQueue.pop(); + return false; + } + else if (bytesSent < bytesToSend) // now n > 0 + { + queuedMessage.ReadCompleted(bytesSent); + return AsyncProcessQueue(guard); + } - if (!_writeQueue.empty()) - AsyncWrite(_writeQueue.front()); - else if (_closing) - CloseSocket(); + _writeQueue.pop(); + if (_writeQueue.empty()) + { + _isWritingAsync = false; + return false; } - else - CloseSocket(); - } - template<typename Q = PacketType> - typename std::enable_if<std::is_pointer<Q>::value>::type DeletePacket(PacketType& packet) { delete packet; } + return true; + } - template<typename Q = PacketType> - typename std::enable_if<!std::is_pointer<Q>::value>::type DeletePacket(PacketType const& /*packet*/) { } +#endif tcp::socket _socket; boost::asio::ip::address _remoteAddress; uint16 _remotePort; - MessageBuffer _readHeaderBuffer; - MessageBuffer _readDataBuffer; + MessageBuffer _readBuffer; std::atomic<bool> _closed; std::atomic<bool> _closing; + + bool _isWritingAsync; }; #endif // __SOCKET_H__ diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h new file mode 100644 index 00000000000..ed638ab89f3 --- /dev/null +++ b/src/server/shared/Networking/SocketMgr.h @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef SocketMgr_h__ +#define SocketMgr_h__ + +#include "AsyncAcceptor.h" +#include "Config.h" +#include "Errors.h" +#include "NetworkThread.h" +#include <boost/asio/ip/tcp.hpp> +#include <memory> + +using boost::asio::ip::tcp; + +template<class SocketType> +class SocketMgr +{ +public: + virtual ~SocketMgr() + { + delete[] _threads; + } + + virtual bool StartNetwork(boost::asio::io_service& service, std::string const& bindIp, uint16 port) + { + _threadCount = sConfigMgr->GetIntDefault("Network.Threads", 1); + + if (_threadCount <= 0) + { + TC_LOG_ERROR("misc", "Network.Threads is wrong in your config file"); + return false; + } + + _acceptor = new AsyncAcceptor(service, bindIp, port); + _threads = CreateThreads(); + + ASSERT(_threads); + + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Start(); + + return true; + } + + virtual void StopNetwork() + { + if (_threadCount != 0) + for (size_t i = 0; i < _threadCount; ++i) + _threads[i].Stop(); + + Wait(); + } + + void Wait() + { + if (_threadCount != 0) + for (size_t i = 0; i < _threadCount; ++i) + _threads[i].Wait(); + } + + virtual void OnSocketOpen(tcp::socket&& sock) + { + size_t min = 0; + + for (size_t i = 1; i < _threadCount; ++i) + if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount()) + min = i; + + try + { + std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock)); + newSocket->Start(); + + _threads[min].AddSocket(newSocket); + } + catch (boost::system::system_error const& err) + { + TC_LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what()); + } + } + + int32 GetNetworkThreadCount() const { return _threadCount; } + +protected: + SocketMgr() : _threads(nullptr), _threadCount(1) + { + } + + virtual NetworkThread<SocketType>* CreateThreads() const = 0; + + AsyncAcceptor* _acceptor; + NetworkThread<SocketType>* _threads; + int32 _threadCount; +}; + +#endif // SocketMgr_h__ diff --git a/src/server/worldserver/Main.cpp b/src/server/worldserver/Main.cpp index e149902af02..191e44b3e93 100644 --- a/src/server/worldserver/Main.cpp +++ b/src/server/worldserver/Main.cpp @@ -46,6 +46,7 @@ #include "CliRunnable.h" #include "SystemConfig.h" #include "WorldSocket.h" +#include "WorldSocketMgr.h" using namespace boost::program_options; @@ -82,7 +83,7 @@ uint32 realmID; ///< Id of the realm void SignalHandler(const boost::system::error_code& error, int signalNumber); void FreezeDetectorHandler(const boost::system::error_code& error); -AsyncAcceptor<RASession>* StartRaSocketAcceptor(boost::asio::io_service& ioService); +AsyncAcceptor* StartRaSocketAcceptor(boost::asio::io_service& ioService); bool StartDB(); void StopDB(); void WorldUpdateLoop(); @@ -203,7 +204,7 @@ extern int main(int argc, char** argv) } // Start the Remote Access port (acceptor) if enabled - AsyncAcceptor<RASession>* raAcceptor = nullptr; + AsyncAcceptor* raAcceptor = nullptr; if (sConfigMgr->GetBoolDefault("Ra.Enable", false)) raAcceptor = StartRaSocketAcceptor(_ioService); @@ -217,11 +218,8 @@ extern int main(int argc, char** argv) // Launch the worldserver listener socket uint16 worldPort = uint16(sWorld->getIntConfig(CONFIG_PORT_WORLD)); std::string worldListener = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0"); - bool tcpNoDelay = sConfigMgr->GetBoolDefault("Network.TcpNodelay", true); - AsyncAcceptor<WorldSocket> worldAcceptor(_ioService, worldListener, worldPort, tcpNoDelay); - - sScriptMgr->OnNetworkStart(); + sWorldSocketMgr.StartNetwork(_ioService, worldListener, worldPort); // Set server online (allow connecting now) LoginDatabase.DirectPExecute("UPDATE realmlist SET flag = flag & ~%u, population = 0 WHERE id = '%u'", REALM_FLAG_INVALID, realmID); @@ -252,6 +250,8 @@ extern int main(int argc, char** argv) // unload battleground templates before different singletons destroyed sBattlegroundMgr->DeleteAllBattlegrounds(); + sWorldSocketMgr.StopNetwork(); + sInstanceSaveMgr->Unload(); sMapMgr->UnloadAll(); // unload all grids (including locked in memory) sObjectAccessor->UnloadAll(); // unload 'i_player2corpse' storage and remove from world @@ -379,12 +379,14 @@ void FreezeDetectorHandler(const boost::system::error_code& error) } } -AsyncAcceptor<RASession>* StartRaSocketAcceptor(boost::asio::io_service& ioService) +AsyncAcceptor* StartRaSocketAcceptor(boost::asio::io_service& ioService) { uint16 raPort = uint16(sConfigMgr->GetIntDefault("Ra.Port", 3443)); std::string raListener = sConfigMgr->GetStringDefault("Ra.IP", "0.0.0.0"); - return new AsyncAcceptor<RASession>(ioService, raListener, raPort); + AsyncAcceptor* acceptor = new AsyncAcceptor(ioService, raListener, raPort); + acceptor->AsyncAccept<RASession>(); + return acceptor; } /// Initialize connection to the databases |