aboutsummaryrefslogtreecommitdiff
path: root/src/server/shared/Networking
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2014-09-09 19:19:25 +0200
committerShauren <shauren.trinity@gmail.com>2014-09-09 19:19:25 +0200
commite0ce4528c5ffd43f651f88821723311541e9e461 (patch)
treec5f3906d17114120b2ea2b382cea39db07137d18 /src/server/shared/Networking
parenta2ba49afa428ed9297f98bf8a5e00f6f7a6f4c3a (diff)
Core/NetworkIO: Use reactor style sending on linux to reduce locking overhead
Diffstat (limited to 'src/server/shared/Networking')
-rw-r--r--src/server/shared/Networking/AsyncAcceptor.h52
-rw-r--r--src/server/shared/Networking/MessageBuffer.h67
-rw-r--r--src/server/shared/Networking/NetworkThread.h166
-rw-r--r--src/server/shared/Networking/Socket.h255
-rw-r--r--src/server/shared/Networking/SocketMgr.h111
5 files changed, 516 insertions, 135 deletions
diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h
index 64665c2b198..a8b688e6b26 100644
--- a/src/server/shared/Networking/AsyncAcceptor.h
+++ b/src/server/shared/Networking/AsyncAcceptor.h
@@ -23,37 +23,32 @@
using boost::asio::ip::tcp;
-template <class T>
class AsyncAcceptor
{
public:
- AsyncAcceptor(boost::asio::io_service& ioService, std::string bindIp, int port) :
- _acceptor(ioService, tcp::endpoint(boost::asio::ip::address::from_string(bindIp), port)),
- _socket(ioService)
- {
- AsyncAccept();
- };
+ typedef void(*ManagerAcceptHandler)(tcp::socket&& newSocket);
- AsyncAcceptor(boost::asio::io_service& ioService, std::string bindIp, int port, bool tcpNoDelay) :
+ AsyncAcceptor(boost::asio::io_service& ioService, std::string const& bindIp, uint16 port) :
_acceptor(ioService, tcp::endpoint(boost::asio::ip::address::from_string(bindIp), port)),
_socket(ioService)
{
- _acceptor.set_option(boost::asio::ip::tcp::no_delay(tcpNoDelay));
+ boost::system::error_code error;
+ _acceptor.non_blocking(true, error);
+ }
- AsyncAccept();
- };
+ template <class T>
+ void AsyncAccept();
-private:
- void AsyncAccept()
+ void AsyncAcceptManaged(ManagerAcceptHandler mgrHandler)
{
- _acceptor.async_accept(_socket, [this](boost::system::error_code error)
+ _acceptor.async_accept(_socket, [this, mgrHandler](boost::system::error_code error)
{
if (!error)
{
try
{
// this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class
- std::make_shared<T>(std::move(this->_socket))->Start();
+ mgrHandler(std::move(this->_socket));
}
catch (boost::system::system_error const& err)
{
@@ -61,13 +56,36 @@ private:
}
}
- // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
- this->AsyncAccept();
+ AsyncAcceptManaged(mgrHandler);
});
}
+private:
tcp::acceptor _acceptor;
tcp::socket _socket;
};
+template<class T>
+void AsyncAcceptor::AsyncAccept()
+{
+ _acceptor.async_accept(_socket, [this](boost::system::error_code error)
+ {
+ if (!error)
+ {
+ try
+ {
+ // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class
+ std::make_shared<T>(std::move(this->_socket))->Start();
+ }
+ catch (boost::system::system_error const& err)
+ {
+ TC_LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what());
+ }
+ }
+
+ // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
+ this->AsyncAccept<T>();
+ });
+}
+
#endif /* __ASYNCACCEPT_H_ */
diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h
index c7f8ba31a71..2115bea3f47 100644
--- a/src/server/shared/Networking/MessageBuffer.h
+++ b/src/server/shared/Networking/MessageBuffer.h
@@ -26,42 +26,74 @@ class MessageBuffer
typedef std::vector<uint8>::size_type size_type;
public:
- MessageBuffer() : _wpos(0), _storage() { }
+ MessageBuffer() : _wpos(0), _rpos(0), _storage()
+ {
+ _storage.resize(4096);
+ }
- MessageBuffer(MessageBuffer const& right) : _wpos(right._wpos), _storage(right._storage) { }
+ explicit MessageBuffer(std::size_t initialSize) : _wpos(0), _rpos(0), _storage()
+ {
+ _storage.resize(initialSize);
+ }
+
+ MessageBuffer(MessageBuffer const& right) : _wpos(right._wpos), _rpos(right._rpos), _storage(right._storage)
+ {
+ }
- MessageBuffer(MessageBuffer&& right) : _wpos(right._wpos), _storage(right.Move()) { }
+ MessageBuffer(MessageBuffer&& right) : _wpos(right._wpos), _rpos(right._rpos), _storage(right.Move()) { }
void Reset()
{
- _storage.clear();
_wpos = 0;
+ _rpos = 0;
}
- bool IsMessageReady() const { return _wpos == _storage.size(); }
+ void Resize(size_type bytes)
+ {
+ _storage.resize(bytes);
+ }
- size_type GetSize() const { return _storage.size(); }
+ uint8* GetBasePointer() { return _storage.data(); }
- size_type GetReadyDataSize() const { return _wpos; }
+ uint8* GetReadPointer() { return &_storage[_rpos]; }
- size_type GetMissingSize() const { return _storage.size() - _wpos; }
+ uint8* GetWritePointer() { return &_storage[_wpos]; }
- uint8* Data() { return _storage.data(); }
+ void ReadCompleted(size_type bytes) { _rpos += bytes; }
- void Grow(size_type bytes)
- {
- _storage.resize(_storage.size() + bytes);
- }
+ void WriteCompleted(size_type bytes) { _wpos += bytes; }
- uint8* GetWritePointer() { return &_storage[_wpos]; }
+ size_type GetActiveSize() const { return _wpos - _rpos; }
- void WriteCompleted(size_type bytes) { _wpos += bytes; }
+ size_type GetRemainingSpace() const { return _storage.size() - _wpos; }
+
+ size_type GetBufferSize() const { return _storage.size(); }
+
+ // Discards inactive data
+ void Normalize()
+ {
+ if (_rpos)
+ {
+ if (_rpos != _wpos)
+ memmove(GetBasePointer(), GetReadPointer(), GetActiveSize());
+ _wpos -= _rpos;
+ _rpos = 0;
+ }
+ }
- void ResetWritePointer() { _wpos = 0; }
+ void Write(void* data, std::size_t size)
+ {
+ if (size)
+ {
+ memcpy(GetWritePointer(), data, size);
+ WriteCompleted(size);
+ }
+ }
std::vector<uint8>&& Move()
{
_wpos = 0;
+ _rpos = 0;
return std::move(_storage);
}
@@ -70,6 +102,7 @@ public:
if (this != &right)
{
_wpos = right._wpos;
+ _rpos = right._rpos;
_storage = right._storage;
}
@@ -81,6 +114,7 @@ public:
if (this != &right)
{
_wpos = right._wpos;
+ _rpos = right._rpos;
_storage = right.Move();
}
@@ -89,6 +123,7 @@ public:
private:
size_type _wpos;
+ size_type _rpos;
std::vector<uint8> _storage;
};
diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h
new file mode 100644
index 00000000000..701d0d97f36
--- /dev/null
+++ b/src/server/shared/Networking/NetworkThread.h
@@ -0,0 +1,166 @@
+/*
+ * Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/>
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef NetworkThread_h__
+#define NetworkThread_h__
+
+#include "Define.h"
+#include "Errors.h"
+#include "Log.h"
+#include "Timer.h"
+#include <atomic>
+#include <chrono>
+#include <memory>
+#include <mutex>
+#include <set>
+#include <thread>
+
+template<class SocketType>
+class NetworkThread
+{
+public:
+ NetworkThread() : _connections(0), _stopped(false), _thread(nullptr)
+ {
+ }
+
+ virtual ~NetworkThread()
+ {
+ Stop();
+ if (_thread)
+ {
+ Wait();
+ delete _thread;
+ }
+ }
+
+ void Stop()
+ {
+ _stopped = true;
+ }
+
+ bool Start()
+ {
+ if (_thread)
+ return false;
+
+ _thread = new std::thread(&NetworkThread::Run, this);
+ return true;
+ }
+
+ void Wait()
+ {
+ ASSERT(_thread);
+
+ _thread->join();
+ delete _thread;
+ _thread = nullptr;
+ }
+
+ int32 GetConnectionCount() const
+ {
+ return _connections;
+ }
+
+ virtual void AddSocket(std::shared_ptr<SocketType> sock)
+ {
+ std::lock_guard<std::mutex> lock(_newSocketsLock);
+
+ ++_connections;
+ _newSockets.insert(sock);
+ SocketAdded(sock);
+ }
+
+protected:
+ virtual void SocketAdded(std::shared_ptr<SocketType> sock) { }
+ virtual void SocketRemoved(std::shared_ptr<SocketType> sock) { }
+
+ void AddNewSockets()
+ {
+ std::lock_guard<std::mutex> lock(_newSocketsLock);
+
+ if (_newSockets.empty())
+ return;
+
+ for (SocketSet::const_iterator i = _newSockets.begin(); i != _newSockets.end(); ++i)
+ {
+ if (!(*i)->IsOpen())
+ {
+ SocketRemoved(*i);
+
+ --_connections;
+ }
+ else
+ _Sockets.insert(*i);
+ }
+
+ _newSockets.clear();
+ }
+
+ void Run()
+ {
+ TC_LOG_DEBUG("misc", "Network Thread Starting");
+
+ SocketSet::iterator i, t;
+
+ uint32 sleepTime = 10;
+ uint32 tickStart = 0, diff = 0;
+ while (!_stopped)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
+
+ tickStart = getMSTime();
+
+ AddNewSockets();
+
+ for (i = _Sockets.begin(); i != _Sockets.end();)
+ {
+ if (!(*i)->Update())
+ {
+ if ((*i)->IsOpen())
+ (*i)->CloseSocket();
+
+ SocketRemoved(*i);
+
+ --_connections;
+ _Sockets.erase(i++);
+ }
+ else
+ ++i;
+ }
+
+ diff = GetMSTimeDiffToNow(tickStart);
+ sleepTime = diff > 10 ? 0 : 10 - diff;
+ }
+
+ TC_LOG_DEBUG("misc", "Network Thread exits");
+ }
+
+private:
+ typedef std::set<std::shared_ptr<SocketType> > SocketSet;
+
+ std::atomic<int32> _connections;
+ std::atomic<bool> _stopped;
+
+ std::thread* _thread;
+
+ SocketSet _Sockets;
+
+ std::mutex _newSocketsLock;
+ SocketSet _newSockets;
+};
+
+#endif // NetworkThread_h__
diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h
index 3bd30bd731b..17f48343485 100644
--- a/src/server/shared/Networking/Socket.h
+++ b/src/server/shared/Networking/Socket.h
@@ -35,32 +35,40 @@ using boost::asio::ip::tcp;
#define READ_BLOCK_SIZE 4096
-template<class T, class PacketType>
+template<class T>
class Socket : public std::enable_shared_from_this<T>
{
- typedef typename std::conditional<std::is_pointer<PacketType>::value, PacketType, PacketType const&>::type WritePacketType;
-
public:
- Socket(tcp::socket&& socket, std::size_t headerSize) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
- _remotePort(_socket.remote_endpoint().port()), _readHeaderBuffer(), _readDataBuffer(), _closed(false), _closing(false)
+ explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
+ _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
{
- _readHeaderBuffer.Grow(headerSize);
+ _readBuffer.Resize(READ_BLOCK_SIZE);
}
virtual ~Socket()
{
boost::system::error_code error;
_socket.close(error);
-
- while (!_writeQueue.empty())
- {
- DeletePacket(_writeQueue.front());
- _writeQueue.pop();
- }
}
virtual void Start() = 0;
+ virtual bool Update()
+ {
+ if (!IsOpen())
+ return false;
+
+#ifndef BOOST_ASIO_HAS_IOCP
+ if (_isWritingAsync || (!_writeBuffer.GetActiveSize() && _writeQueue.empty()))
+ return true;
+
+ for (; WriteHandler(boost::system::error_code(), 0);)
+ ;
+#endif
+
+ return true;
+ }
+
boost::asio::ip::address GetRemoteIpAddress() const
{
return _remoteAddress;
@@ -71,31 +79,14 @@ public:
return _remotePort;
}
- void AsyncReadHeader()
+ void AsyncRead()
{
if (!IsOpen())
return;
- _readHeaderBuffer.ResetWritePointer();
- _readDataBuffer.Reset();
-
- AsyncReadMissingHeaderData();
- }
-
- void AsyncReadData(std::size_t size)
- {
- if (!IsOpen())
- return;
-
- if (!size)
- {
- // if this is a packet with 0 length body just invoke handler directly
- ReadDataHandler();
- return;
- }
-
- _readDataBuffer.Grow(size);
- AsyncReadMissingData();
+ _readBuffer.Normalize();
+ _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), READ_BLOCK_SIZE),
+ std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
}
void ReadData(std::size_t size)
@@ -105,13 +96,11 @@ public:
boost::system::error_code error;
- _readDataBuffer.Grow(size);
-
- std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readDataBuffer.GetWritePointer(), size), error);
+ std::size_t bytesRead = boost::asio::read(_socket, boost::asio::buffer(_readBuffer.GetWritePointer(), size), error);
- _readDataBuffer.WriteCompleted(bytesRead);
+ _readBuffer.WriteCompleted(bytesRead);
- if (error || !_readDataBuffer.IsMessageReady())
+ if (error || bytesRead != size)
{
TC_LOG_DEBUG("network", "Socket::ReadData: %s errored with: %i (%s)", GetRemoteIpAddress().to_string().c_str(), error.value(),
error.message().c_str());
@@ -120,15 +109,19 @@ public:
}
}
- void AsyncWrite(WritePacketType data)
+ void QueuePacket(MessageBuffer&& buffer, std::unique_lock<std::mutex>& guard)
{
- boost::asio::async_write(_socket, boost::asio::buffer(data), std::bind(&Socket<T, PacketType>::WriteHandler, this->shared_from_this(),
- std::placeholders::_1, std::placeholders::_2));
+
+ _writeQueue.push(std::move(buffer));
+
+#ifdef BOOST_ASIO_HAS_IOCP
+ AsyncProcessQueue(guard);
+#endif
}
bool IsOpen() const { return !_closed && !_closing; }
- virtual void CloseSocket()
+ void CloseSocket()
{
if (_closed.exchange(true))
return;
@@ -143,39 +136,37 @@ public:
/// Marks the socket for closing after write buffer becomes empty
void DelayedCloseSocket() { _closing = true; }
- virtual bool IsHeaderReady() const { return _readHeaderBuffer.IsMessageReady(); }
- virtual bool IsDataReady() const { return _readDataBuffer.IsMessageReady(); }
-
- uint8* GetHeaderBuffer() { return _readHeaderBuffer.Data(); }
- uint8* GetDataBuffer() { return _readDataBuffer.Data(); }
-
- size_t GetHeaderSize() const { return _readHeaderBuffer.GetReadyDataSize(); }
- size_t GetDataSize() const { return _readDataBuffer.GetReadyDataSize(); }
-
- MessageBuffer&& MoveHeader() { return std::move(_readHeaderBuffer); }
- MessageBuffer&& MoveData() { return std::move(_readDataBuffer); }
+ MessageBuffer& GetReadBuffer() { return _readBuffer; }
protected:
- virtual void ReadHeaderHandler() = 0;
- virtual void ReadDataHandler() = 0;
-
- std::mutex _writeLock;
- std::queue<PacketType> _writeQueue;
+ virtual void ReadHandler() = 0;
-private:
- void AsyncReadMissingHeaderData()
+ bool AsyncProcessQueue(std::unique_lock<std::mutex>&)
{
- _socket.async_read_some(boost::asio::buffer(_readHeaderBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readHeaderBuffer.GetMissingSize())),
- std::bind(&Socket<T, PacketType>::ReadHeaderHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ if (_isWritingAsync)
+ return true;
+
+ _isWritingAsync = true;
+
+#ifdef BOOST_ASIO_HAS_IOCP
+ MessageBuffer& buffer = _writeQueue.front();
+ _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler,
+ this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+#else
+ _socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket<T>::WriteHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+#endif
+
+ return true;
}
- void AsyncReadMissingData()
- {
- _socket.async_read_some(boost::asio::buffer(_readDataBuffer.GetWritePointer(), std::min<std::size_t>(READ_BLOCK_SIZE, _readDataBuffer.GetMissingSize())),
- std::bind(&Socket<T, PacketType>::ReadDataHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
- }
+ std::mutex _writeLock;
+ std::queue<MessageBuffer> _writeQueue;
+#ifndef BOOST_ASIO_HAS_IOCP
+ MessageBuffer _writeBuffer;
+#endif
- void ReadHeaderHandlerInternal(boost::system::error_code error, size_t transferredBytes)
+private:
+ void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes)
{
if (error)
{
@@ -183,70 +174,130 @@ private:
return;
}
- _readHeaderBuffer.WriteCompleted(transferredBytes);
- if (!IsHeaderReady())
+ _readBuffer.WriteCompleted(transferredBytes);
+ ReadHandler();
+ }
+
+#ifdef BOOST_ASIO_HAS_IOCP
+
+ void WriteHandler(boost::system::error_code error, std::size_t transferedBytes)
+ {
+ if (!error)
{
- // incomplete, read more
- AsyncReadMissingHeaderData();
- return;
- }
+ std::unique_lock<std::mutex> deleteGuard(_writeLock);
+
+ _isWritingAsync = false;
+ _writeQueue.front().ReadCompleted(transferedBytes);
+ if (!_writeQueue.front().GetActiveSize())
+ _writeQueue.pop();
- ReadHeaderHandler();
+ if (!_writeQueue.empty())
+ AsyncProcessQueue(deleteGuard);
+ else if (_closing)
+ CloseSocket();
+ }
+ else
+ CloseSocket();
}
- void ReadDataHandlerInternal(boost::system::error_code error, size_t transferredBytes)
+#else
+
+ bool WriteHandler(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/)
{
+ std::unique_lock<std::mutex> guard(_writeLock, std::try_to_lock);
+ if (!guard)
+ return false;
+
+ if (!IsOpen())
+ return false;
+
+ std::size_t bytesToSend = _writeBuffer.GetActiveSize();
+
+ if (bytesToSend == 0)
+ return HandleQueue(guard);
+
+ boost::system::error_code error;
+ std::size_t bytesWritten = _socket.write_some(boost::asio::buffer(_writeBuffer.GetReadPointer(), bytesToSend), error);
+
if (error)
{
- CloseSocket();
- return;
- }
+ if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
+ return AsyncProcessQueue(guard);
- _readDataBuffer.WriteCompleted(transferredBytes);
- if (!IsDataReady())
+ return false;
+ }
+ else if (bytesWritten == 0)
+ return false;
+ else if (bytesWritten < bytesToSend) //now n > 0
{
- // incomplete, read more
- AsyncReadMissingData();
- return;
+ _writeBuffer.ReadCompleted(bytesWritten);
+ _writeBuffer.Normalize();
+ return AsyncProcessQueue(guard);
}
- ReadDataHandler();
+ // now bytesWritten == bytesToSend
+ _writeBuffer.Reset();
+
+ return HandleQueue(guard);
}
- void WriteHandler(boost::system::error_code error, size_t /*transferedBytes*/)
+ bool HandleQueue(std::unique_lock<std::mutex>& guard)
{
- if (!error)
+ if (_writeQueue.empty())
+ {
+ _isWritingAsync = false;
+ return false;
+ }
+
+ MessageBuffer& queuedMessage = _writeQueue.front();
+
+ std::size_t bytesToSend = queuedMessage.GetActiveSize();
+
+ boost::system::error_code error;
+ std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);
+
+ if (error)
{
- std::lock_guard<std::mutex> deleteGuard(_writeLock);
+ if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
+ return AsyncProcessQueue(guard);
- DeletePacket(_writeQueue.front());
_writeQueue.pop();
+ return false;
+ }
+ else if (bytesSent == 0)
+ {
+ _writeQueue.pop();
+ return false;
+ }
+ else if (bytesSent < bytesToSend) // now n > 0
+ {
+ queuedMessage.ReadCompleted(bytesSent);
+ return AsyncProcessQueue(guard);
+ }
- if (!_writeQueue.empty())
- AsyncWrite(_writeQueue.front());
- else if (_closing)
- CloseSocket();
+ _writeQueue.pop();
+ if (_writeQueue.empty())
+ {
+ _isWritingAsync = false;
+ return false;
}
- else
- CloseSocket();
- }
- template<typename Q = PacketType>
- typename std::enable_if<std::is_pointer<Q>::value>::type DeletePacket(PacketType& packet) { delete packet; }
+ return true;
+ }
- template<typename Q = PacketType>
- typename std::enable_if<!std::is_pointer<Q>::value>::type DeletePacket(PacketType const& /*packet*/) { }
+#endif
tcp::socket _socket;
boost::asio::ip::address _remoteAddress;
uint16 _remotePort;
- MessageBuffer _readHeaderBuffer;
- MessageBuffer _readDataBuffer;
+ MessageBuffer _readBuffer;
std::atomic<bool> _closed;
std::atomic<bool> _closing;
+
+ bool _isWritingAsync;
};
#endif // __SOCKET_H__
diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h
new file mode 100644
index 00000000000..ed638ab89f3
--- /dev/null
+++ b/src/server/shared/Networking/SocketMgr.h
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2008-2014 TrinityCore <http://www.trinitycore.org/>
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef SocketMgr_h__
+#define SocketMgr_h__
+
+#include "AsyncAcceptor.h"
+#include "Config.h"
+#include "Errors.h"
+#include "NetworkThread.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <memory>
+
+using boost::asio::ip::tcp;
+
+template<class SocketType>
+class SocketMgr
+{
+public:
+ virtual ~SocketMgr()
+ {
+ delete[] _threads;
+ }
+
+ virtual bool StartNetwork(boost::asio::io_service& service, std::string const& bindIp, uint16 port)
+ {
+ _threadCount = sConfigMgr->GetIntDefault("Network.Threads", 1);
+
+ if (_threadCount <= 0)
+ {
+ TC_LOG_ERROR("misc", "Network.Threads is wrong in your config file");
+ return false;
+ }
+
+ _acceptor = new AsyncAcceptor(service, bindIp, port);
+ _threads = CreateThreads();
+
+ ASSERT(_threads);
+
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Start();
+
+ return true;
+ }
+
+ virtual void StopNetwork()
+ {
+ if (_threadCount != 0)
+ for (size_t i = 0; i < _threadCount; ++i)
+ _threads[i].Stop();
+
+ Wait();
+ }
+
+ void Wait()
+ {
+ if (_threadCount != 0)
+ for (size_t i = 0; i < _threadCount; ++i)
+ _threads[i].Wait();
+ }
+
+ virtual void OnSocketOpen(tcp::socket&& sock)
+ {
+ size_t min = 0;
+
+ for (size_t i = 1; i < _threadCount; ++i)
+ if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
+ min = i;
+
+ try
+ {
+ std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
+ newSocket->Start();
+
+ _threads[min].AddSocket(newSocket);
+ }
+ catch (boost::system::system_error const& err)
+ {
+ TC_LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what());
+ }
+ }
+
+ int32 GetNetworkThreadCount() const { return _threadCount; }
+
+protected:
+ SocketMgr() : _threads(nullptr), _threadCount(1)
+ {
+ }
+
+ virtual NetworkThread<SocketType>* CreateThreads() const = 0;
+
+ AsyncAcceptor* _acceptor;
+ NetworkThread<SocketType>* _threads;
+ int32 _threadCount;
+};
+
+#endif // SocketMgr_h__