diff options
author | Shauren <shauren.trinity@gmail.com> | 2016-02-19 19:23:04 +0100 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2016-02-19 19:23:04 +0100 |
commit | 97a79af4701621ec04b88c8b548dbc35d120e99e (patch) | |
tree | c3e7f3f2f7f5ee41565bf16ea884cf55aa75c911 | |
parent | 06ec1b8fe8dfe9bb8a225ed57a053eb546d386ad (diff) |
Core/Networking: Rewrite networking threading model
Each network thread has its own io_service - this means that all operations on a given socket except queueing packets run from a single thread, removing the need for locking
Sending packets now writes to a lockfree intermediate queue directly, encryption is applied in network thread if it was required at the time of sending the packet
-rw-r--r-- | src/common/Threading/MPSCQueue.h | 83 | ||||
-rw-r--r-- | src/server/bnetserver/Server/Session.cpp | 38 | ||||
-rw-r--r-- | src/server/bnetserver/Server/Session.h | 8 | ||||
-rw-r--r-- | src/server/bnetserver/Server/SessionManager.cpp | 7 | ||||
-rw-r--r-- | src/server/bnetserver/Server/SessionManager.h | 2 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.cpp | 105 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.h | 6 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocketMgr.cpp | 15 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocketMgr.h | 2 | ||||
-rw-r--r-- | src/server/game/World/World.cpp | 37 | ||||
-rw-r--r-- | src/server/game/World/World.h | 6 | ||||
-rw-r--r-- | src/server/shared/Networking/AsyncAcceptor.h | 26 | ||||
-rw-r--r-- | src/server/shared/Networking/MessageBuffer.h | 2 | ||||
-rw-r--r-- | src/server/shared/Networking/NetworkThread.h | 86 | ||||
-rw-r--r-- | src/server/shared/Networking/Socket.h | 75 | ||||
-rw-r--r-- | src/server/shared/Networking/SocketMgr.h | 27 |
16 files changed, 320 insertions, 205 deletions
diff --git a/src/common/Threading/MPSCQueue.h b/src/common/Threading/MPSCQueue.h new file mode 100644 index 00000000000..09648b844be --- /dev/null +++ b/src/common/Threading/MPSCQueue.h @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/> + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef MPSCQueue_h__ +#define MPSCQueue_h__ + +#include <atomic> +#include <utility> + +// C++ implementation of Dmitry Vyukov's lock free MPSC queue +// http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue +template<typename T> +class MPSCQueue +{ +public: + MPSCQueue() : _head(new Node()), _tail(_head.load(std::memory_order_relaxed)) + { + Node* front = _head.load(std::memory_order_relaxed); + front->Next.store(nullptr, std::memory_order_relaxed); + } + + ~MPSCQueue() + { + T* output; + while (this->Dequeue(output)) + ; + + Node* front = _head.load(std::memory_order_relaxed); + delete front; + } + + void Enqueue(T* input) + { + Node* node = new Node(input); + Node* prevHead = _head.exchange(node, std::memory_order_acq_rel); + prevHead->Next.store(node, std::memory_order_release); + } + + bool Dequeue(T*& result) + { + Node* tail = _tail.load(std::memory_order_relaxed); + Node* next = tail->Next.load(std::memory_order_acquire); + if (!next) + return false; + + result = next->Data; + _tail.store(next, std::memory_order_release); + delete tail; + return true; + } + +private: + struct Node + { + Node() = default; + explicit Node(T* data) : Data(data) { Next.store(nullptr, std::memory_order_relaxed); } + + T* Data; + std::atomic<Node*> Next; + }; + + std::atomic<Node*> _head; + std::atomic<Node*> _tail; + + MPSCQueue(MPSCQueue const&) = delete; + MPSCQueue& operator=(MPSCQueue const&) = delete; +}; + +#endif // MPSCQueue_h__ diff --git a/src/server/bnetserver/Server/Session.cpp b/src/server/bnetserver/Server/Session.cpp index a5ede8d6524..4d54562501f 100644 --- a/src/server/bnetserver/Server/Session.cpp +++ b/src/server/bnetserver/Server/Session.cpp @@ -654,12 +654,37 @@ void Battlenet::Session::CheckIpCallback(PreparedQueryResult result) bool Battlenet::Session::Update() { + EncryptableBuffer* queued; + MessageBuffer buffer((std::size_t(BufferSizes::Read))); + while (_bufferQueue.Dequeue(queued)) + { + std::size_t packetSize = queued->Buffer.GetActiveSize(); + if (queued->Encrypt) + _crypt.EncryptSend(queued->Buffer.GetReadPointer(), packetSize); + + if (buffer.GetRemainingSpace() < packetSize) + { + QueuePacket(std::move(buffer)); + buffer.Resize(std::size_t(BufferSizes::Read)); + } + + if (buffer.GetRemainingSpace() >= packetSize) + buffer.Write(queued->Buffer.GetReadPointer(), packetSize); + else // single packet larger than 16384 bytes - client will reject. + QueuePacket(std::move(queued->Buffer)); + + delete queued; + } + + if (buffer.GetActiveSize() > 0) + QueuePacket(std::move(buffer)); + if (!BattlenetSocket::Update()) return false; if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready) { - auto callback = std::move(_queryCallback); + auto callback = _queryCallback; _queryCallback = nullptr; callback(_queryFuture.get()); } @@ -679,15 +704,12 @@ void Battlenet::Session::AsyncWrite(ServerPacket* packet) packet->Write(); - MessageBuffer buffer; - buffer.Write(packet->GetData(), packet->GetSize()); + EncryptableBuffer* buffer = new EncryptableBuffer(); + buffer->Buffer.Write(packet->GetData(), packet->GetSize()); + buffer->Encrypt = _crypt.IsInitialized(); delete packet; - std::unique_lock<std::mutex> guard(_writeLock); - - _crypt.EncryptSend(buffer.GetReadPointer(), buffer.GetActiveSize()); - - QueuePacket(std::move(buffer), guard); + _bufferQueue.Enqueue(buffer); } inline void ReplaceResponse(Battlenet::ServerPacket** oldResponse, Battlenet::ServerPacket* newResponse) diff --git a/src/server/bnetserver/Server/Session.h b/src/server/bnetserver/Server/Session.h index 75c30096417..2443d694a80 100644 --- a/src/server/bnetserver/Server/Session.h +++ b/src/server/bnetserver/Server/Session.h @@ -23,6 +23,7 @@ #include "Socket.h" #include "BigNumber.h" #include "Callback.h" +#include "MPSCQueue.h" #include <memory> #include <boost/asio/ip/tcp.hpp> @@ -174,6 +175,13 @@ namespace Battlenet std::queue<ModuleType> _modulesWaitingForData; + struct EncryptableBuffer + { + MessageBuffer Buffer; + bool Encrypt; + }; + + MPSCQueue<EncryptableBuffer> _bufferQueue; PacketCrypt _crypt; bool _authed; bool _subscribedToRealmListUpdates; diff --git a/src/server/bnetserver/Server/SessionManager.cpp b/src/server/bnetserver/Server/SessionManager.cpp index 8201f4869b4..c53214495d4 100644 --- a/src/server/bnetserver/Server/SessionManager.cpp +++ b/src/server/bnetserver/Server/SessionManager.cpp @@ -22,7 +22,8 @@ bool Battlenet::SessionManager::StartNetwork(boost::asio::io_service& service, s if (!BaseSocketMgr::StartNetwork(service, bindIp, port)) return false; - _acceptor->AsyncAcceptManaged(&OnSocketAccept); + _acceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this)); + _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>(); return true; } @@ -31,9 +32,9 @@ NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() co return new NetworkThread<Session>[GetNetworkThreadCount()]; } -void Battlenet::SessionManager::OnSocketAccept(tcp::socket&& sock) +void Battlenet::SessionManager::OnSocketAccept(tcp::socket&& sock, uint32 threadIndex) { - sSessionMgr.OnSocketOpen(std::forward<tcp::socket>(sock)); + sSessionMgr.OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex); } void Battlenet::SessionManager::AddSession(Session* session) diff --git a/src/server/bnetserver/Server/SessionManager.h b/src/server/bnetserver/Server/SessionManager.h index fe262b29f4e..5cf0b199f15 100644 --- a/src/server/bnetserver/Server/SessionManager.h +++ b/src/server/bnetserver/Server/SessionManager.h @@ -75,7 +75,7 @@ namespace Battlenet NetworkThread<Session>* CreateThreads() const override; private: - static void OnSocketAccept(tcp::socket&& sock); + static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex); SessionMap _sessions; SessionByAccountMap _sessionsByAccountId; diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index 8f8c9d89502..acec25a7363 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -38,6 +38,17 @@ struct CompressedWorldPacket uint32 CompressedAdler; }; +class EncryptablePacket : public WorldPacket +{ +public: + EncryptablePacket(WorldPacket const& packet, bool encrypt) : WorldPacket(packet), _encrypt(encrypt) { } + + bool NeedsEncryption() const { return _encrypt; } + +private: + bool _encrypt; +}; + #pragma pack(pop) using boost::asio::ip::tcp; @@ -76,11 +87,8 @@ void WorldSocket::Start() stmt->setString(0, ip_address); stmt->setUInt32(1, inet_addr(ip_address.c_str())); - { - std::lock_guard<std::mutex> guard(_queryLock); - _queryCallback = io_service().wrap(std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1)); - _queryFuture = LoginDatabase.AsyncQuery(stmt); - } + _queryCallback = std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1); + _queryFuture = LoginDatabase.AsyncQuery(stmt); } void WorldSocket::CheckIpCallback(PreparedQueryResult result) @@ -116,23 +124,50 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result) initializer.Write(&header, sizeof(header.Setup.Size)); initializer.Write(ServerConnectionInitialize.c_str(), ServerConnectionInitialize.length()); - std::unique_lock<std::mutex> guard(_writeLock); - QueuePacket(std::move(initializer), guard); + // - io_service.run thread, safe. + QueuePacket(std::move(initializer)); } bool WorldSocket::Update() { + EncryptablePacket* queued; + MessageBuffer buffer; + while (_bufferQueue.Dequeue(queued)) + { + uint32 sizeOfHeader = SizeOfServerHeader[queued->NeedsEncryption()]; + uint32 packetSize = queued->size(); + if (packetSize > MinSizeForCompression && queued->NeedsEncryption()) + packetSize = compressBound(packetSize) + sizeof(CompressedWorldPacket); + + if (buffer.GetRemainingSpace() < packetSize + sizeOfHeader) + { + QueuePacket(std::move(buffer)); + buffer.Resize(4096); + } + + if (buffer.GetRemainingSpace() >= packetSize + sizeOfHeader) + WritePacketToBuffer(*queued, buffer); + else // single packet larger than 4096 bytes + { + MessageBuffer packetBuffer(packetSize + sizeOfHeader); + WritePacketToBuffer(*queued, packetBuffer); + QueuePacket(std::move(packetBuffer)); + } + + delete queued; + } + + if (buffer.GetActiveSize() > 0) + QueuePacket(std::move(buffer)); + if (!BaseSocket::Update()) return false; + if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready) { - std::lock_guard<std::mutex> guard(_queryLock); - if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready) - { - auto callback = std::move(_queryCallback); - _queryCallback = nullptr; - callback(_queryFuture.get()); - } + auto callback = _queryCallback; + _queryCallback = nullptr; + callback(_queryFuture.get()); } return true; @@ -428,29 +463,13 @@ void WorldSocket::SendPacket(WorldPacket const& packet) if (sPacketLog->CanLogPacket()) sPacketLog->LogPacket(packet, SERVER_TO_CLIENT, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType()); - uint32 packetSize = packet.size(); - uint32 sizeOfHeader = SizeOfServerHeader[_authCrypt.IsInitialized()]; - if (packetSize > MinSizeForCompression && _authCrypt.IsInitialized()) - packetSize = compressBound(packetSize) + sizeof(CompressedWorldPacket); - - std::unique_lock<std::mutex> guard(_writeLock); - -#ifndef TC_SOCKET_USE_IOCP - if (_writeQueue.empty() && _writeBuffer.GetRemainingSpace() >= sizeOfHeader + packetSize) - WritePacketToBuffer(packet, _writeBuffer); - else -#endif - { - MessageBuffer buffer(sizeOfHeader + packetSize); - WritePacketToBuffer(packet, buffer); - QueuePacket(std::move(buffer), guard); - } + _bufferQueue.Enqueue(new EncryptablePacket(packet, _authCrypt.IsInitialized())); } -void WorldSocket::WritePacketToBuffer(WorldPacket const& packet, MessageBuffer& buffer) +void WorldSocket::WritePacketToBuffer(EncryptablePacket const& packet, MessageBuffer& buffer) { ServerPktHeader header; - uint32 sizeOfHeader = SizeOfServerHeader[_authCrypt.IsInitialized()]; + uint32 sizeOfHeader = SizeOfServerHeader[packet.NeedsEncryption()]; uint32 opcode = packet.GetOpcode(); uint32 packetSize = packet.size(); @@ -458,7 +477,7 @@ void WorldSocket::WritePacketToBuffer(WorldPacket const& packet, MessageBuffer& uint8* headerPos = buffer.GetWritePointer(); buffer.WriteCompleted(sizeOfHeader); - if (packetSize > MinSizeForCompression && _authCrypt.IsInitialized()) + if (packetSize > MinSizeForCompression && packet.NeedsEncryption()) { CompressedWorldPacket cmp; cmp.UncompressedSize = packetSize + 4; @@ -481,7 +500,7 @@ void WorldSocket::WritePacketToBuffer(WorldPacket const& packet, MessageBuffer& else if (!packet.empty()) buffer.Write(packet.contents(), packet.size()); - if (_authCrypt.IsInitialized()) + if (packet.NeedsEncryption()) { header.Normal.Size = packetSize; header.Normal.Command = opcode; @@ -598,11 +617,8 @@ void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSess stmt->setInt32(0, int32(realm.Id.Realm)); stmt->setString(1, authSession->Account); - { - std::lock_guard<std::mutex> guard(_queryLock); - _queryCallback = io_service().wrap(std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1)); - _queryFuture = LoginDatabase.AsyncQuery(stmt); - } + _queryCallback = std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1); + _queryFuture = LoginDatabase.AsyncQuery(stmt); } void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession, PreparedQueryResult result) @@ -768,7 +784,7 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth:: if (wardenActive) _worldSession->InitWarden(&account.Game.SessionKey, account.BattleNet.OS); - _queryCallback = io_service().wrap(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1)); + _queryCallback = std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1); _queryFuture = _worldSession->LoadPermissionsAsync(); AsyncRead(); } @@ -801,11 +817,8 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth: PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION); stmt->setUInt32(0, accountId); - { - std::lock_guard<std::mutex> guard(_queryLock); - _queryCallback = io_service().wrap(std::bind(&WorldSocket::HandleAuthContinuedSessionCallback, this, authSession, std::placeholders::_1)); - _queryFuture = LoginDatabase.AsyncQuery(stmt); - } + _queryCallback = std::bind(&WorldSocket::HandleAuthContinuedSessionCallback, this, authSession, std::placeholders::_1); + _queryFuture = LoginDatabase.AsyncQuery(stmt); } void WorldSocket::HandleAuthContinuedSessionCallback(std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession, PreparedQueryResult result) diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index d6d29fb2826..205494ca4ea 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -26,11 +26,13 @@ #include "Util.h" #include "WorldPacket.h" #include "WorldSession.h" +#include "MPSCQueue.h" #include <chrono> #include <boost/asio/ip/tcp.hpp> using boost::asio::ip::tcp; struct z_stream_s; +class EncryptablePacket; namespace WorldPackets { @@ -111,7 +113,7 @@ private: void LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const; /// sends and logs network.opcode without accessing WorldSession void SendPacketAndLogOpcode(WorldPacket const& packet); - void WritePacketToBuffer(WorldPacket const& packet, MessageBuffer& buffer); + void WritePacketToBuffer(EncryptablePacket const& packet, MessageBuffer& buffer); uint32 CompressPacket(uint8* buffer, WorldPacket const& packet); void HandleSendAuthSession(); @@ -142,12 +144,12 @@ private: MessageBuffer _headerBuffer; MessageBuffer _packetBuffer; + MPSCQueue<EncryptablePacket> _bufferQueue; z_stream_s* _compressionStream; bool _initialized; - std::mutex _queryLock; PreparedQueryResultFuture _queryFuture; std::function<void(PreparedQueryResult&&)> _queryCallback; std::string _ipCountry; diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp index 937483e1179..94c5a8f6979 100644 --- a/src/server/game/Server/WorldSocketMgr.cpp +++ b/src/server/game/Server/WorldSocketMgr.cpp @@ -24,9 +24,9 @@ #include <boost/system/error_code.hpp> -static void OnSocketAccept(tcp::socket&& sock) +static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex) { - sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock)); + sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex); } class WorldSocketThread : public NetworkThread<WorldSocket> @@ -73,8 +73,11 @@ bool WorldSocketMgr::StartNetwork(boost::asio::io_service& service, std::string BaseSocketMgr::StartNetwork(service, bindIp, port); _instanceAcceptor = new AsyncAcceptor(service, bindIp, uint16(sWorld->getIntConfig(CONFIG_PORT_INSTANCE))); - _acceptor->AsyncAcceptManaged(&OnSocketAccept); - _instanceAcceptor->AsyncAcceptManaged(&OnSocketAccept); + _acceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this)); + _instanceAcceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this)); + + _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>(); + _instanceAcceptor->AsyncAcceptWithCallback<&OnSocketAccept>(); sScriptMgr->OnNetworkStart(); return true; @@ -87,7 +90,7 @@ void WorldSocketMgr::StopNetwork() sScriptMgr->OnNetworkStop(); } -void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock) +void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) { // set some options here if (_socketSendBufferSize >= 0) @@ -115,7 +118,7 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock) //sock->m_OutBufferSize = static_cast<size_t> (m_SockOutUBuff); - BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock)); + BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex); } NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h index d4bf4115deb..2079b62d14f 100644 --- a/src/server/game/Server/WorldSocketMgr.h +++ b/src/server/game/Server/WorldSocketMgr.h @@ -49,7 +49,7 @@ public: /// Stops all network threads, It will wait for all running threads . void StopNetwork() override; - void OnSocketOpen(tcp::socket&& sock) override; + void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) override; protected: WorldSocketMgr(); diff --git a/src/server/game/World/World.cpp b/src/server/game/World/World.cpp index 331c2114b4e..d2ccb025165 100644 --- a/src/server/game/World/World.cpp +++ b/src/server/game/World/World.cpp @@ -227,7 +227,7 @@ void World::AddSession(WorldSession* s) addSessQueue.add(s); } -void World::AddInstanceSocket(std::shared_ptr<WorldSocket> sock, uint64 connectToKey) +void World::AddInstanceSocket(std::weak_ptr<WorldSocket> sock, uint64 connectToKey) { _linkSocketQueue.add(std::make_pair(sock, connectToKey)); } @@ -298,25 +298,28 @@ void World::AddSession_(WorldSession* s) } } -void World::ProcessLinkInstanceSocket(std::pair<std::shared_ptr<WorldSocket>, uint64> linkInfo) +void World::ProcessLinkInstanceSocket(std::pair<std::weak_ptr<WorldSocket>, uint64> linkInfo) { - if (!linkInfo.first->IsOpen()) - return; + if (std::shared_ptr<WorldSocket> sock = linkInfo.first.lock()) + { + if (!sock->IsOpen()) + return; - WorldSession::ConnectToKey key; - key.Raw = linkInfo.second; + WorldSession::ConnectToKey key; + key.Raw = linkInfo.second; - WorldSession* session = FindSession(uint32(key.Fields.AccountId)); - if (!session || session->GetConnectToInstanceKey() != linkInfo.second) - { - linkInfo.first->SendAuthResponseError(AUTH_SESSION_EXPIRED); - linkInfo.first->DelayedCloseSocket(); - return; - } + WorldSession* session = FindSession(uint32(key.Fields.AccountId)); + if (!session || session->GetConnectToInstanceKey() != linkInfo.second) + { + sock->SendAuthResponseError(AUTH_SESSION_EXPIRED); + sock->DelayedCloseSocket(); + return; + } - linkInfo.first->SetWorldSession(session); - session->AddInstanceConnection(linkInfo.first); - session->HandleContinuePlayerLogin(); + sock->SetWorldSession(session); + session->AddInstanceConnection(sock); + session->HandleContinuePlayerLogin(); + } } bool World::HasRecentlyDisconnected(WorldSession* session) @@ -2821,7 +2824,7 @@ void World::SendServerMessage(ServerMessageType messageID, std::string stringPar void World::UpdateSessions(uint32 diff) { - std::pair<std::shared_ptr<WorldSocket>, uint64> linkInfo; + std::pair<std::weak_ptr<WorldSocket>, uint64> linkInfo; while (_linkSocketQueue.next(linkInfo)) ProcessLinkInstanceSocket(std::move(linkInfo)); diff --git a/src/server/game/World/World.h b/src/server/game/World/World.h index d28738ffc8c..0f9a27c733f 100644 --- a/src/server/game/World/World.h +++ b/src/server/game/World/World.h @@ -568,7 +568,7 @@ class World WorldSession* FindSession(uint32 id) const; void AddSession(WorldSession* s); - void AddInstanceSocket(std::shared_ptr<WorldSocket> sock, uint64 connectToKey); + void AddInstanceSocket(std::weak_ptr<WorldSocket> sock, uint64 connectToKey); void SendAutoBroadcast(); bool RemoveSession(uint32 id); /// Get the number of current active sessions @@ -878,8 +878,8 @@ class World void AddSession_(WorldSession* s); LockedQueue<WorldSession*> addSessQueue; - void ProcessLinkInstanceSocket(std::pair<std::shared_ptr<WorldSocket>, uint64> linkInfo); - LockedQueue<std::pair<std::shared_ptr<WorldSocket>, uint64>> _linkSocketQueue; + void ProcessLinkInstanceSocket(std::pair<std::weak_ptr<WorldSocket>, uint64> linkInfo); + LockedQueue<std::pair<std::weak_ptr<WorldSocket>, uint64>> _linkSocketQueue; // used versions std::string m_DBVersion; diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h index 2fa1e448ff8..d21801a64ac 100644 --- a/src/server/shared/Networking/AsyncAcceptor.h +++ b/src/server/shared/Networking/AsyncAcceptor.h @@ -20,34 +20,39 @@ #include "Log.h" #include <boost/asio.hpp> +#include <functional> using boost::asio::ip::tcp; class AsyncAcceptor { public: - typedef void(*ManagerAcceptHandler)(tcp::socket&& newSocket); + typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex); AsyncAcceptor(boost::asio::io_service& ioService, std::string const& bindIp, uint16 port) : _acceptor(ioService, tcp::endpoint(boost::asio::ip::address::from_string(bindIp), port)), - _socket(ioService), _closed(false) + _socket(ioService), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this)) { } - template <class T> + template<class T> void AsyncAccept(); - void AsyncAcceptManaged(ManagerAcceptHandler mgrHandler) + template<AcceptCallback acceptCallback> + void AsyncAcceptWithCallback() { - _acceptor.async_accept(_socket, [this, mgrHandler](boost::system::error_code error) + tcp::socket* socket; + uint32 threadIndex; + std::tie(socket, threadIndex) = _socketFactory(); + _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error) { if (!error) { try { - _socket.non_blocking(true); + socket->non_blocking(true); - mgrHandler(std::move(_socket)); + acceptCallback(std::move(*socket), threadIndex); } catch (boost::system::system_error const& err) { @@ -56,7 +61,7 @@ public: } if (!_closed) - AsyncAcceptManaged(mgrHandler); + this->AsyncAcceptWithCallback<acceptCallback>(); }); } @@ -69,10 +74,15 @@ public: _acceptor.close(err); } + void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; } + private: + std::pair<tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } + tcp::acceptor _acceptor; tcp::socket _socket; std::atomic<bool> _closed; + std::function<std::pair<tcp::socket*, uint32>()> _socketFactory; }; template<class T> diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h index a6ed9b31e8f..42b65be8398 100644 --- a/src/server/shared/Networking/MessageBuffer.h +++ b/src/server/shared/Networking/MessageBuffer.h @@ -105,7 +105,7 @@ public: return std::move(_storage); } - MessageBuffer& operator=(MessageBuffer& right) + MessageBuffer& operator=(MessageBuffer const& right) { if (this != &right) { diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h index e183209e989..5eb2fcb2f6a 100644 --- a/src/server/shared/Networking/NetworkThread.h +++ b/src/server/shared/Networking/NetworkThread.h @@ -22,6 +22,8 @@ #include "Errors.h" #include "Log.h" #include "Timer.h" +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/deadline_timer.hpp> #include <atomic> #include <chrono> #include <memory> @@ -29,11 +31,14 @@ #include <set> #include <thread> +using boost::asio::ip::tcp; + template<class SocketType> class NetworkThread { public: - NetworkThread() : _connections(0), _stopped(false), _thread(nullptr) + NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), + _acceptSocket(_io_service), _updateTimer(_io_service) { } @@ -50,6 +55,7 @@ public: void Stop() { _stopped = true; + _io_service.stop(); } bool Start() @@ -80,10 +86,12 @@ public: std::lock_guard<std::mutex> lock(_newSocketsLock); ++_connections; - _newSockets.insert(sock); + _newSockets.push_back(sock); SocketAdded(sock); } + tcp::socket* GetSocketForAccept() { return &_acceptSocket; } + protected: virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { } virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { } @@ -95,16 +103,15 @@ protected: if (_newSockets.empty()) return; - for (typename SocketSet::const_iterator i = _newSockets.begin(); i != _newSockets.end(); ++i) + for (std::shared_ptr<SocketType> sock : _newSockets) { - if (!(*i)->IsOpen()) + if (!sock->IsOpen()) { - SocketRemoved(*i); - + SocketRemoved(sock); --_connections; } else - _Sockets.insert(*i); + _sockets.push_back(sock); } _newSockets.clear(); @@ -114,55 +121,58 @@ protected: { TC_LOG_DEBUG("misc", "Network Thread Starting"); - typename SocketSet::iterator i, t; + _updateTimer.expires_from_now(boost::posix_time::milliseconds(10)); + _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this)); + _io_service.run(); - uint32 sleepTime = 10; - uint32 tickStart = 0, diff = 0; - while (!_stopped) - { - std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime)); + TC_LOG_DEBUG("misc", "Network Thread exits"); + _newSockets.clear(); + _sockets.clear(); + } + + void Update() + { + if (_stopped) + return; - tickStart = getMSTime(); + _updateTimer.expires_from_now(boost::posix_time::milliseconds(10)); + _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this)); - AddNewSockets(); + AddNewSockets(); - for (i = _Sockets.begin(); i != _Sockets.end();) + _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock) + { + if (!sock->Update()) { - if (!(*i)->Update()) - { - if ((*i)->IsOpen()) - (*i)->CloseSocket(); - - SocketRemoved(*i); - - --_connections; - _Sockets.erase(i++); - } - else - ++i; - } + if (sock->IsOpen()) + sock->CloseSocket(); - diff = GetMSTimeDiffToNow(tickStart); - sleepTime = diff > 10 ? 0 : 10 - diff; - } + SocketRemoved(sock); - TC_LOG_DEBUG("misc", "Network Thread exits"); - _newSockets.clear(); - _Sockets.clear(); + --_connections; + return true; + } + + return false; + }), _sockets.end()); } private: - typedef std::set<std::shared_ptr<SocketType> > SocketSet; + typedef std::vector<std::shared_ptr<SocketType>> SocketContainer; std::atomic<int32> _connections; std::atomic<bool> _stopped; std::thread* _thread; - SocketSet _Sockets; + SocketContainer _sockets; std::mutex _newSocketsLock; - SocketSet _newSockets; + SocketContainer _newSockets; + + boost::asio::io_service _io_service; + tcp::socket _acceptSocket; + boost::asio::deadline_timer _updateTimer; }; #endif // NetworkThread_h__ diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index 34ee50eb84e..07f427652aa 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -21,15 +21,11 @@ #include "MessageBuffer.h" #include "Log.h" #include <atomic> -#include <vector> -#include <mutex> #include <queue> #include <memory> #include <functional> #include <type_traits> #include <boost/asio/ip/tcp.hpp> -#include <boost/asio/write.hpp> -#include <boost/asio/read.hpp> using boost::asio::ip::tcp; @@ -63,14 +59,10 @@ public: return false; #ifndef TC_SOCKET_USE_IOCP - std::unique_lock<std::mutex> guard(_writeLock); - if (!guard) + if (_isWritingAsync || _writeQueue.empty()) return true; - if (_isWritingAsync || (!_writeBuffer.GetActiveSize() && _writeQueue.empty())) - return true; - - for (; WriteHandler(guard);) + for (; HandleQueue();) ; #endif @@ -98,14 +90,12 @@ public: std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } - void QueuePacket(MessageBuffer&& buffer, std::unique_lock<std::mutex>& guard) + void QueuePacket(MessageBuffer&& buffer) { _writeQueue.push(std::move(buffer)); #ifdef TC_SOCKET_USE_IOCP - AsyncProcessQueue(guard); -#else - (void)guard; + AsyncProcessQueue(); #endif } @@ -135,7 +125,7 @@ protected: virtual void ReadHandler() = 0; - bool AsyncProcessQueue(std::unique_lock<std::mutex>&) + bool AsyncProcessQueue() { if (_isWritingAsync) return false; @@ -157,19 +147,12 @@ protected: void SetNoDelay(bool enable) { boost::system::error_code err; - _socket.set_option(boost::asio::ip::tcp::no_delay(enable), err); + _socket.set_option(tcp::no_delay(enable), err); if (err) TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for %s - %d (%s)", GetRemoteIpAddress().to_string().c_str(), err.value(), err.message().c_str()); } - std::mutex _writeLock; - std::queue<MessageBuffer> _writeQueue; -#ifndef TC_SOCKET_USE_IOCP - MessageBuffer _writeBuffer; -#endif - - boost::asio::io_service& io_service() { return _socket.get_io_service(); } private: void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes) @@ -190,15 +173,13 @@ private: { if (!error) { - std::unique_lock<std::mutex> deleteGuard(_writeLock); - _isWritingAsync = false; _writeQueue.front().ReadCompleted(transferedBytes); if (!_writeQueue.front().GetActiveSize()) _writeQueue.pop(); if (!_writeQueue.empty()) - AsyncProcessQueue(deleteGuard); + AsyncProcessQueue(); else if (_closing) CloseSocket(); } @@ -210,48 +191,15 @@ private: void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/) { - std::unique_lock<std::mutex> guard(_writeLock); _isWritingAsync = false; - WriteHandler(guard); + HandleQueue(); } - bool WriteHandler(std::unique_lock<std::mutex>& guard) + bool HandleQueue() { if (!IsOpen()) return false; - std::size_t bytesToSend = _writeBuffer.GetActiveSize(); - - if (bytesToSend == 0) - return HandleQueue(guard); - - boost::system::error_code error; - std::size_t bytesWritten = _socket.write_some(boost::asio::buffer(_writeBuffer.GetReadPointer(), bytesToSend), error); - - if (error) - { - if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) - return AsyncProcessQueue(guard); - - return false; - } - else if (bytesWritten == 0) - return false; - else if (bytesWritten < bytesToSend) - { - _writeBuffer.ReadCompleted(bytesWritten); - _writeBuffer.Normalize(); - return AsyncProcessQueue(guard); - } - - // now bytesWritten == bytesToSend - _writeBuffer.Reset(); - - return HandleQueue(guard); - } - - bool HandleQueue(std::unique_lock<std::mutex>& guard) - { if (_writeQueue.empty()) return false; @@ -265,7 +213,7 @@ private: if (error) { if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) - return AsyncProcessQueue(guard); + return AsyncProcessQueue(); _writeQueue.pop(); return false; @@ -278,7 +226,7 @@ private: else if (bytesSent < bytesToSend) // now n > 0 { queuedMessage.ReadCompleted(bytesSent); - return AsyncProcessQueue(guard); + return AsyncProcessQueue(); } _writeQueue.pop(); @@ -293,6 +241,7 @@ private: uint16 _remotePort; MessageBuffer _readBuffer; + std::queue<MessageBuffer> _writeQueue; std::atomic<bool> _closed; std::atomic<bool> _closing; diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h index 4037c85baa1..b14aac4ca47 100644 --- a/src/server/shared/Networking/SocketMgr.h +++ b/src/server/shared/Networking/SocketMgr.h @@ -90,20 +90,14 @@ public: _threads[i].Wait(); } - virtual void OnSocketOpen(tcp::socket&& sock) + virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) { - size_t min = 0; - - for (int32 i = 1; i < _threadCount; ++i) - if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount()) - min = i; - try { std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock)); newSocket->Start(); - _threads[min].AddSocket(newSocket); + _threads[threadIndex].AddSocket(newSocket); } catch (boost::system::system_error const& err) { @@ -113,6 +107,23 @@ public: int32 GetNetworkThreadCount() const { return _threadCount; } + uint32 SelectThreadWithMinConnections() const + { + uint32 min = 0; + + for (int32 i = 1; i < _threadCount; ++i) + if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount()) + min = i; + + return min; + } + + std::pair<tcp::socket*, uint32> GetSocketForAccept() + { + uint32 threadIndex = SelectThreadWithMinConnections(); + return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex); + } + protected: SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(1) { |