aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2016-02-20 13:08:03 +0100
committerShauren <shauren.trinity@gmail.com>2016-02-20 13:08:03 +0100
commitb2e03a744813a17bf0c01c9ef010e65cac078420 (patch)
tree6513a0674f2a07c89b05b089e8d52e8accde68cb /src
parentd4184065b6d51b250ab08f35e88868cef631ed4b (diff)
Core/Networking: Rewrite networking threading model
Each network thread has its own io_service - this means that all operations on a given socket except queueing packets run from a single thread, removing the need for locking Sending packets now writes to a lockfree intermediate queue directly, encryption is applied in network thread if it was required at the time of sending the packet (cherry picked from commit 97a79af4701621ec04b88c8b548dbc35d120e99e)
Diffstat (limited to 'src')
-rw-r--r--src/common/Threading/MPSCQueue.h83
-rw-r--r--src/server/authserver/Server/AuthSession.cpp5
-rw-r--r--src/server/authserver/Server/AuthSocketMgr.h15
-rw-r--r--src/server/game/Server/WorldSocket.cpp98
-rw-r--r--src/server/game/Server/WorldSocket.h4
-rw-r--r--src/server/game/Server/WorldSocketMgr.cpp12
-rw-r--r--src/server/game/Server/WorldSocketMgr.h2
-rw-r--r--src/server/shared/Networking/AsyncAcceptor.h40
-rw-r--r--src/server/shared/Networking/MessageBuffer.h2
-rw-r--r--src/server/shared/Networking/NetworkThread.h84
-rw-r--r--src/server/shared/Networking/Socket.h74
-rw-r--r--src/server/shared/Networking/SocketMgr.h37
12 files changed, 277 insertions, 179 deletions
diff --git a/src/common/Threading/MPSCQueue.h b/src/common/Threading/MPSCQueue.h
new file mode 100644
index 00000000000..09648b844be
--- /dev/null
+++ b/src/common/Threading/MPSCQueue.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef MPSCQueue_h__
+#define MPSCQueue_h__
+
+#include <atomic>
+#include <utility>
+
+// C++ implementation of Dmitry Vyukov's lock free MPSC queue
+// http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue
+template<typename T>
+class MPSCQueue
+{
+public:
+ MPSCQueue() : _head(new Node()), _tail(_head.load(std::memory_order_relaxed))
+ {
+ Node* front = _head.load(std::memory_order_relaxed);
+ front->Next.store(nullptr, std::memory_order_relaxed);
+ }
+
+ ~MPSCQueue()
+ {
+ T* output;
+ while (this->Dequeue(output))
+ ;
+
+ Node* front = _head.load(std::memory_order_relaxed);
+ delete front;
+ }
+
+ void Enqueue(T* input)
+ {
+ Node* node = new Node(input);
+ Node* prevHead = _head.exchange(node, std::memory_order_acq_rel);
+ prevHead->Next.store(node, std::memory_order_release);
+ }
+
+ bool Dequeue(T*& result)
+ {
+ Node* tail = _tail.load(std::memory_order_relaxed);
+ Node* next = tail->Next.load(std::memory_order_acquire);
+ if (!next)
+ return false;
+
+ result = next->Data;
+ _tail.store(next, std::memory_order_release);
+ delete tail;
+ return true;
+ }
+
+private:
+ struct Node
+ {
+ Node() = default;
+ explicit Node(T* data) : Data(data) { Next.store(nullptr, std::memory_order_relaxed); }
+
+ T* Data;
+ std::atomic<Node*> Next;
+ };
+
+ std::atomic<Node*> _head;
+ std::atomic<Node*> _tail;
+
+ MPSCQueue(MPSCQueue const&) = delete;
+ MPSCQueue& operator=(MPSCQueue const&) = delete;
+};
+
+#endif // MPSCQueue_h__
diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp
index 519cd1f19f7..57e5d6682f2 100644
--- a/src/server/authserver/Server/AuthSession.cpp
+++ b/src/server/authserver/Server/AuthSession.cpp
@@ -274,10 +274,7 @@ void AuthSession::SendPacket(ByteBuffer& packet)
{
MessageBuffer buffer;
buffer.Write(packet.contents(), packet.size());
-
- std::unique_lock<std::mutex> guard(_writeLock);
-
- QueuePacket(std::move(buffer), guard);
+ QueuePacket(std::move(buffer));
}
}
diff --git a/src/server/authserver/Server/AuthSocketMgr.h b/src/server/authserver/Server/AuthSocketMgr.h
index fa96502663f..a16b7d405b9 100644
--- a/src/server/authserver/Server/AuthSocketMgr.h
+++ b/src/server/authserver/Server/AuthSocketMgr.h
@@ -21,8 +21,6 @@
#include "SocketMgr.h"
#include "AuthSession.h"
-void OnSocketAccept(tcp::socket&& sock);
-
class AuthSocketMgr : public SocketMgr<AuthSession>
{
typedef SocketMgr<AuthSession> BaseSocketMgr;
@@ -39,7 +37,7 @@ public:
if (!BaseSocketMgr::StartNetwork(service, bindIp, port))
return false;
- _acceptor->AsyncAcceptManaged(&OnSocketAccept);
+ _acceptor->AsyncAcceptWithCallback<&AuthSocketMgr::OnSocketAccept>();
return true;
}
@@ -48,14 +46,13 @@ protected:
{
return new NetworkThread<AuthSession>[1];
}
+
+ static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
+ {
+ Instance().OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
+ }
};
#define sAuthSocketMgr AuthSocketMgr::Instance()
-void OnSocketAccept(tcp::socket&& sock)
-{
- sAuthSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock));
-}
-
-
#endif // AuthSocketMgr_h__
diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp
index a2d357cbc4d..36029113055 100644
--- a/src/server/game/Server/WorldSocket.cpp
+++ b/src/server/game/Server/WorldSocket.cpp
@@ -25,6 +25,17 @@
#include <memory>
+class EncryptablePacket : public WorldPacket
+{
+public:
+ EncryptablePacket(WorldPacket const& packet, bool encrypt) : WorldPacket(packet), _encrypt(encrypt) { }
+
+ bool NeedsEncryption() const { return _encrypt; }
+
+private:
+ bool _encrypt;
+};
+
using boost::asio::ip::tcp;
WorldSocket::WorldSocket(tcp::socket&& socket)
@@ -40,11 +51,8 @@ void WorldSocket::Start()
stmt->setString(0, ip_address);
stmt->setUInt32(1, inet_addr(ip_address.c_str()));
- {
- std::lock_guard<std::mutex> guard(_queryLock);
- _queryCallback = io_service().wrap(std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1));
- _queryFuture = LoginDatabase.AsyncQuery(stmt);
- }
+ _queryCallback = std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1);
+ _queryFuture = LoginDatabase.AsyncQuery(stmt);
}
void WorldSocket::CheckIpCallback(PreparedQueryResult result)
@@ -78,17 +86,50 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
bool WorldSocket::Update()
{
+ EncryptablePacket* queued;
+ MessageBuffer buffer;
+ while (_bufferQueue.Dequeue(queued))
+ {
+ ServerPktHeader header(queued->size() + 2, queued->GetOpcode());
+ if (queued->NeedsEncryption())
+ _authCrypt.EncryptSend(header.header, header.getHeaderLength());
+
+ if (buffer.GetRemainingSpace() < queued->size() + header.getHeaderLength())
+ {
+ QueuePacket(std::move(buffer));
+ buffer.Resize(4096);
+ }
+
+ if (buffer.GetRemainingSpace() >= queued->size() + header.getHeaderLength())
+ {
+ buffer.Write(header.header, header.getHeaderLength());
+ if (!queued->empty())
+ buffer.Write(queued->contents(), queued->size());
+ }
+ else // single packet larger than 4096 bytes
+ {
+ MessageBuffer packetBuffer(queued->size() + header.getHeaderLength());
+ packetBuffer.Write(header.header, header.getHeaderLength());
+ if (!queued->empty())
+ packetBuffer.Write(queued->contents(), queued->size());
+
+ QueuePacket(std::move(packetBuffer));
+ }
+
+ delete queued;
+ }
+
+ if (buffer.GetActiveSize() > 0)
+ QueuePacket(std::move(buffer));
+
if (!BaseSocket::Update())
return false;
+ if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
{
- std::lock_guard<std::mutex> guard(_queryLock);
- if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
- {
- auto callback = std::move(_queryCallback);
- _queryCallback = nullptr;
- callback(_queryFuture.get());
- }
+ auto callback = _queryCallback;
+ _queryCallback = nullptr;
+ callback(_queryFuture.get());
}
return true;
@@ -351,29 +392,7 @@ void WorldSocket::SendPacket(WorldPacket const& packet)
if (sPacketLog->CanLogPacket())
sPacketLog->LogPacket(packet, SERVER_TO_CLIENT, GetRemoteIpAddress(), GetRemotePort());
- ServerPktHeader header(packet.size() + 2, packet.GetOpcode());
-
- std::unique_lock<std::mutex> guard(_writeLock);
-
- _authCrypt.EncryptSend(header.header, header.getHeaderLength());
-
-#ifndef TC_SOCKET_USE_IOCP
- if (_writeQueue.empty() && _writeBuffer.GetRemainingSpace() >= header.getHeaderLength() + packet.size())
- {
- _writeBuffer.Write(header.header, header.getHeaderLength());
- if (!packet.empty())
- _writeBuffer.Write(packet.contents(), packet.size());
- }
- else
-#endif
- {
- MessageBuffer buffer(header.getHeaderLength() + packet.size());
- buffer.Write(header.header, header.getHeaderLength());
- if (!packet.empty())
- buffer.Write(packet.contents(), packet.size());
-
- QueuePacket(std::move(buffer), guard);
- }
+ _bufferQueue.Enqueue(new EncryptablePacket(packet, _authCrypt.IsInitialized()));
}
void WorldSocket::HandleAuthSession(WorldPacket& recvPacket)
@@ -398,11 +417,8 @@ void WorldSocket::HandleAuthSession(WorldPacket& recvPacket)
stmt->setInt32(0, int32(realmID));
stmt->setString(1, authSession->Account);
- {
- std::lock_guard<std::mutex> guard(_queryLock);
- _queryCallback = io_service().wrap(std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1));
- _queryFuture = LoginDatabase.AsyncQuery(stmt);
- }
+ _queryCallback = std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1);
+ _queryFuture = LoginDatabase.AsyncQuery(stmt);
}
void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<AuthSession> authSession, PreparedQueryResult result)
@@ -559,7 +575,7 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<AuthSession> authSes
if (wardenActive)
_worldSession->InitWarden(&account.SessionKey, account.OS);
- _queryCallback = io_service().wrap(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1));
+ _queryCallback = std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1);
_queryFuture = _worldSession->LoadPermissionsAsync();
AsyncRead();
}
diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h
index 9e5b35992a6..08a5b185cf1 100644
--- a/src/server/game/Server/WorldSocket.h
+++ b/src/server/game/Server/WorldSocket.h
@@ -26,11 +26,13 @@
#include "Util.h"
#include "WorldPacket.h"
#include "WorldSession.h"
+#include "MPSCQueue.h"
#include <chrono>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/buffer.hpp>
using boost::asio::ip::tcp;
+class EncryptablePacket;
#pragma pack(push, 1)
@@ -104,8 +106,8 @@ private:
MessageBuffer _headerBuffer;
MessageBuffer _packetBuffer;
+ MPSCQueue<EncryptablePacket> _bufferQueue;
- std::mutex _queryLock;
PreparedQueryResultFuture _queryFuture;
std::function<void(PreparedQueryResult&&)> _queryCallback;
std::string _ipCountry;
diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp
index 529396b3966..e8f8c59f4af 100644
--- a/src/server/game/Server/WorldSocketMgr.cpp
+++ b/src/server/game/Server/WorldSocketMgr.cpp
@@ -24,9 +24,9 @@
#include <boost/system/error_code.hpp>
-static void OnSocketAccept(tcp::socket&& sock)
+static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
{
- sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock));
+ sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
}
class WorldSocketThread : public NetworkThread<WorldSocket>
@@ -67,7 +67,9 @@ bool WorldSocketMgr::StartNetwork(boost::asio::io_service& service, std::string
BaseSocketMgr::StartNetwork(service, bindIp, port);
- _acceptor->AsyncAcceptManaged(&OnSocketAccept);
+ _acceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this));
+
+ _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
sScriptMgr->OnNetworkStart();
return true;
@@ -80,7 +82,7 @@ void WorldSocketMgr::StopNetwork()
sScriptMgr->OnNetworkStop();
}
-void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock)
+void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
{
// set some options here
if (_socketSendBufferSize >= 0)
@@ -108,7 +110,7 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock)
//sock->m_OutBufferSize = static_cast<size_t> (m_SockOutUBuff);
- BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock));
+ BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
}
NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h
index 92a28d0c135..38e2e7abb69 100644
--- a/src/server/game/Server/WorldSocketMgr.h
+++ b/src/server/game/Server/WorldSocketMgr.h
@@ -47,7 +47,7 @@ public:
/// Stops all network threads, It will wait for all running threads .
void StopNetwork() override;
- void OnSocketOpen(tcp::socket&& sock) override;
+ void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) override;
protected:
WorldSocketMgr();
diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h
index 260e1c8ea11..0f3fd9a145b 100644
--- a/src/server/shared/Networking/AsyncAcceptor.h
+++ b/src/server/shared/Networking/AsyncAcceptor.h
@@ -20,34 +20,39 @@
#include "Log.h"
#include <boost/asio.hpp>
+#include <functional>
using boost::asio::ip::tcp;
class AsyncAcceptor
{
public:
- typedef void(*ManagerAcceptHandler)(tcp::socket&& newSocket);
+ typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex);
AsyncAcceptor(boost::asio::io_service& ioService, std::string const& bindIp, uint16 port) :
_acceptor(ioService, tcp::endpoint(boost::asio::ip::address::from_string(bindIp), port)),
- _socket(ioService)
+ _socket(ioService), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this))
{
}
- template <class T>
+ template<class T>
void AsyncAccept();
- void AsyncAcceptManaged(ManagerAcceptHandler mgrHandler)
+ template<AcceptCallback acceptCallback>
+ void AsyncAcceptWithCallback()
{
- _acceptor.async_accept(_socket, [this, mgrHandler](boost::system::error_code error)
+ tcp::socket* socket;
+ uint32 threadIndex;
+ std::tie(socket, threadIndex) = _socketFactory();
+ _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
{
if (!error)
{
try
{
- _socket.non_blocking(true);
+ socket->non_blocking(true);
- mgrHandler(std::move(_socket));
+ acceptCallback(std::move(*socket), threadIndex);
}
catch (boost::system::system_error const& err)
{
@@ -55,13 +60,29 @@ public:
}
}
- AsyncAcceptManaged(mgrHandler);
+ if (!_closed)
+ this->AsyncAcceptWithCallback<acceptCallback>();
});
}
+ void Close()
+ {
+ if (_closed.exchange(true))
+ return;
+
+ boost::system::error_code err;
+ _acceptor.close(err);
+ }
+
+ void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; }
+
private:
+ std::pair<tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
+
tcp::acceptor _acceptor;
tcp::socket _socket;
+ std::atomic<bool> _closed;
+ std::function<std::pair<tcp::socket*, uint32>()> _socketFactory;
};
template<class T>
@@ -83,7 +104,8 @@ void AsyncAcceptor::AsyncAccept()
}
// lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
- this->AsyncAccept<T>();
+ if (!_closed)
+ this->AsyncAccept<T>();
});
}
diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h
index 189a56f18b6..d68bee181b1 100644
--- a/src/server/shared/Networking/MessageBuffer.h
+++ b/src/server/shared/Networking/MessageBuffer.h
@@ -105,7 +105,7 @@ public:
return std::move(_storage);
}
- MessageBuffer& operator=(MessageBuffer& right)
+ MessageBuffer& operator=(MessageBuffer const& right)
{
if (this != &right)
{
diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h
index ac216838bce..5eb2fcb2f6a 100644
--- a/src/server/shared/Networking/NetworkThread.h
+++ b/src/server/shared/Networking/NetworkThread.h
@@ -22,6 +22,8 @@
#include "Errors.h"
#include "Log.h"
#include "Timer.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <boost/asio/deadline_timer.hpp>
#include <atomic>
#include <chrono>
#include <memory>
@@ -29,11 +31,14 @@
#include <set>
#include <thread>
+using boost::asio::ip::tcp;
+
template<class SocketType>
class NetworkThread
{
public:
- NetworkThread() : _connections(0), _stopped(false), _thread(nullptr)
+ NetworkThread() : _connections(0), _stopped(false), _thread(nullptr),
+ _acceptSocket(_io_service), _updateTimer(_io_service)
{
}
@@ -50,6 +55,7 @@ public:
void Stop()
{
_stopped = true;
+ _io_service.stop();
}
bool Start()
@@ -80,10 +86,12 @@ public:
std::lock_guard<std::mutex> lock(_newSocketsLock);
++_connections;
- _newSockets.insert(sock);
+ _newSockets.push_back(sock);
SocketAdded(sock);
}
+ tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
+
protected:
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
@@ -95,16 +103,15 @@ protected:
if (_newSockets.empty())
return;
- for (typename SocketSet::const_iterator i = _newSockets.begin(); i != _newSockets.end(); ++i)
+ for (std::shared_ptr<SocketType> sock : _newSockets)
{
- if (!(*i)->IsOpen())
+ if (!sock->IsOpen())
{
- SocketRemoved(*i);
-
+ SocketRemoved(sock);
--_connections;
}
else
- _Sockets.insert(*i);
+ _sockets.push_back(sock);
}
_newSockets.clear();
@@ -114,53 +121,58 @@ protected:
{
TC_LOG_DEBUG("misc", "Network Thread Starting");
- typename SocketSet::iterator i, t;
+ _updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
+ _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
+ _io_service.run();
- uint32 sleepTime = 10;
- uint32 tickStart = 0, diff = 0;
- while (!_stopped)
- {
- std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
+ TC_LOG_DEBUG("misc", "Network Thread exits");
+ _newSockets.clear();
+ _sockets.clear();
+ }
+
+ void Update()
+ {
+ if (_stopped)
+ return;
- tickStart = getMSTime();
+ _updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
+ _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
- AddNewSockets();
+ AddNewSockets();
- for (i = _Sockets.begin(); i != _Sockets.end();)
+ _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock)
+ {
+ if (!sock->Update())
{
- if (!(*i)->Update())
- {
- if ((*i)->IsOpen())
- (*i)->CloseSocket();
-
- SocketRemoved(*i);
-
- --_connections;
- _Sockets.erase(i++);
- }
- else
- ++i;
- }
+ if (sock->IsOpen())
+ sock->CloseSocket();
- diff = GetMSTimeDiffToNow(tickStart);
- sleepTime = diff > 10 ? 0 : 10 - diff;
- }
+ SocketRemoved(sock);
- TC_LOG_DEBUG("misc", "Network Thread exits");
+ --_connections;
+ return true;
+ }
+
+ return false;
+ }), _sockets.end());
}
private:
- typedef std::set<std::shared_ptr<SocketType> > SocketSet;
+ typedef std::vector<std::shared_ptr<SocketType>> SocketContainer;
std::atomic<int32> _connections;
std::atomic<bool> _stopped;
std::thread* _thread;
- SocketSet _Sockets;
+ SocketContainer _sockets;
std::mutex _newSocketsLock;
- SocketSet _newSockets;
+ SocketContainer _newSockets;
+
+ boost::asio::io_service _io_service;
+ tcp::socket _acceptSocket;
+ boost::asio::deadline_timer _updateTimer;
};
#endif // NetworkThread_h__
diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h
index a2f57b5029e..d1ba7f49aa4 100644
--- a/src/server/shared/Networking/Socket.h
+++ b/src/server/shared/Networking/Socket.h
@@ -21,15 +21,11 @@
#include "MessageBuffer.h"
#include "Log.h"
#include <atomic>
-#include <vector>
-#include <mutex>
#include <queue>
#include <memory>
#include <functional>
#include <type_traits>
#include <boost/asio/ip/tcp.hpp>
-#include <boost/asio/write.hpp>
-#include <boost/asio/read.hpp>
using boost::asio::ip::tcp;
@@ -63,14 +59,10 @@ public:
return false;
#ifndef TC_SOCKET_USE_IOCP
- std::unique_lock<std::mutex> guard(_writeLock);
- if (!guard)
+ if (_isWritingAsync || _writeQueue.empty())
return true;
- if (_isWritingAsync || (!_writeBuffer.GetActiveSize() && _writeQueue.empty()))
- return true;
-
- for (; WriteHandler(guard);)
+ for (; HandleQueue();)
;
#endif
@@ -98,14 +90,12 @@ public:
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}
- void QueuePacket(MessageBuffer&& buffer, std::unique_lock<std::mutex>& guard)
+ void QueuePacket(MessageBuffer&& buffer)
{
_writeQueue.push(std::move(buffer));
#ifdef TC_SOCKET_USE_IOCP
- AsyncProcessQueue(guard);
-#else
- (void)guard;
+ AsyncProcessQueue();
#endif
}
@@ -135,7 +125,7 @@ protected:
virtual void ReadHandler() = 0;
- bool AsyncProcessQueue(std::unique_lock<std::mutex>&)
+ bool AsyncProcessQueue()
{
if (_isWritingAsync)
return false;
@@ -154,14 +144,6 @@ protected:
return false;
}
- std::mutex _writeLock;
- std::queue<MessageBuffer> _writeQueue;
-#ifndef TC_SOCKET_USE_IOCP
- MessageBuffer _writeBuffer;
-#endif
-
- boost::asio::io_service& io_service() { return _socket.get_io_service(); }
-
private:
void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes)
{
@@ -181,15 +163,13 @@ private:
{
if (!error)
{
- std::unique_lock<std::mutex> deleteGuard(_writeLock);
-
_isWritingAsync = false;
_writeQueue.front().ReadCompleted(transferedBytes);
if (!_writeQueue.front().GetActiveSize())
_writeQueue.pop();
if (!_writeQueue.empty())
- AsyncProcessQueue(deleteGuard);
+ AsyncProcessQueue();
else if (_closing)
CloseSocket();
}
@@ -201,48 +181,15 @@ private:
void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/)
{
- std::unique_lock<std::mutex> guard(_writeLock);
_isWritingAsync = false;
- WriteHandler(guard);
+ HandleQueue();
}
- bool WriteHandler(std::unique_lock<std::mutex>& guard)
+ bool HandleQueue()
{
if (!IsOpen())
return false;
- std::size_t bytesToSend = _writeBuffer.GetActiveSize();
-
- if (bytesToSend == 0)
- return HandleQueue(guard);
-
- boost::system::error_code error;
- std::size_t bytesWritten = _socket.write_some(boost::asio::buffer(_writeBuffer.GetReadPointer(), bytesToSend), error);
-
- if (error)
- {
- if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
- return AsyncProcessQueue(guard);
-
- return false;
- }
- else if (bytesWritten == 0)
- return false;
- else if (bytesWritten < bytesToSend)
- {
- _writeBuffer.ReadCompleted(bytesWritten);
- _writeBuffer.Normalize();
- return AsyncProcessQueue(guard);
- }
-
- // now bytesWritten == bytesToSend
- _writeBuffer.Reset();
-
- return HandleQueue(guard);
- }
-
- bool HandleQueue(std::unique_lock<std::mutex>& guard)
- {
if (_writeQueue.empty())
return false;
@@ -256,7 +203,7 @@ private:
if (error)
{
if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
- return AsyncProcessQueue(guard);
+ return AsyncProcessQueue();
_writeQueue.pop();
return false;
@@ -269,7 +216,7 @@ private:
else if (bytesSent < bytesToSend) // now n > 0
{
queuedMessage.ReadCompleted(bytesSent);
- return AsyncProcessQueue(guard);
+ return AsyncProcessQueue();
}
_writeQueue.pop();
@@ -284,6 +231,7 @@ private:
uint16 _remotePort;
MessageBuffer _readBuffer;
+ std::queue<MessageBuffer> _writeQueue;
std::atomic<bool> _closed;
std::atomic<bool> _closing;
diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h
index ce5bc2d8fc2..b14aac4ca47 100644
--- a/src/server/shared/Networking/SocketMgr.h
+++ b/src/server/shared/Networking/SocketMgr.h
@@ -33,7 +33,7 @@ class SocketMgr
public:
virtual ~SocketMgr()
{
- delete[] _threads;
+ ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
}
virtual bool StartNetwork(boost::asio::io_service& service, std::string const& bindIp, uint16 port)
@@ -68,11 +68,19 @@ public:
virtual void StopNetwork()
{
+ _acceptor->Close();
+
if (_threadCount != 0)
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Stop();
Wait();
+
+ delete _acceptor;
+ _acceptor = nullptr;
+ delete[] _threads;
+ _threads = nullptr;
+ _threadCount = 0;
}
void Wait()
@@ -82,20 +90,14 @@ public:
_threads[i].Wait();
}
- virtual void OnSocketOpen(tcp::socket&& sock)
+ virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
{
- size_t min = 0;
-
- for (int32 i = 1; i < _threadCount; ++i)
- if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
- min = i;
-
try
{
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
newSocket->Start();
- _threads[min].AddSocket(newSocket);
+ _threads[threadIndex].AddSocket(newSocket);
}
catch (boost::system::system_error const& err)
{
@@ -105,6 +107,23 @@ public:
int32 GetNetworkThreadCount() const { return _threadCount; }
+ uint32 SelectThreadWithMinConnections() const
+ {
+ uint32 min = 0;
+
+ for (int32 i = 1; i < _threadCount; ++i)
+ if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
+ min = i;
+
+ return min;
+ }
+
+ std::pair<tcp::socket*, uint32> GetSocketForAccept()
+ {
+ uint32 threadIndex = SelectThreadWithMinConnections();
+ return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
+ }
+
protected:
SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(1)
{