diff options
author | Kargatum <dowlandtop@yandex.com> | 2021-05-27 21:09:31 +0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-27 16:09:31 +0200 |
commit | c1e96064e9d8d1c7bcfe870fc5a47afbbd79e24c (patch) | |
tree | 485dd9b48c53c939a6e076d2b4ce2222b26a4f52 | |
parent | 2ae84e2faf6b3884ce949299c1914e2d3e438a9d (diff) |
feat(Core/Common): add Asio network threading (#6063)
-rw-r--r-- | src/common/Asio/AsioHacksFwd.h | 44 | ||||
-rw-r--r-- | src/common/Asio/DeadlineTimer.h | 30 | ||||
-rw-r--r-- | src/common/Asio/IoContext.h | 65 | ||||
-rw-r--r-- | src/common/Asio/IpAddress.h | 31 | ||||
-rw-r--r-- | src/common/Asio/IpNetwork.h | 57 | ||||
-rw-r--r-- | src/common/Asio/Resolver.h | 51 | ||||
-rw-r--r-- | src/common/Asio/Strand.h | 38 | ||||
-rw-r--r-- | src/server/shared/Network/AsyncAcceptor.h | 146 | ||||
-rw-r--r-- | src/server/shared/Network/NetworkThread.h | 166 | ||||
-rw-r--r-- | src/server/shared/Network/Socket.h | 276 | ||||
-rw-r--r-- | src/server/shared/Network/SocketMgr.h | 131 |
11 files changed, 1035 insertions, 0 deletions
diff --git a/src/common/Asio/AsioHacksFwd.h b/src/common/Asio/AsioHacksFwd.h new file mode 100644 index 0000000000..05b7960386 --- /dev/null +++ b/src/common/Asio/AsioHacksFwd.h @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef AsioHacksFwd_h__ +#define AsioHacksFwd_h__ + +#include <boost/version.hpp> + +/** + Collection of forward declarations to improve compile time + */ +namespace boost::posix_time +{ + class ptime; +} + +namespace boost::asio +{ + template <typename Time> + struct time_traits; +} + +namespace boost::asio::ip +{ + class address; + class tcp; + + template <typename InternetProtocol> + class basic_endpoint; + + typedef basic_endpoint<tcp> tcp_endpoint; +} + +namespace acore::Asio +{ + class DeadlineTimer; + class IoContext; + class Resolver; + class Strand; +} + +#endif // AsioHacksFwd_h__ diff --git a/src/common/Asio/DeadlineTimer.h b/src/common/Asio/DeadlineTimer.h new file mode 100644 index 0000000000..cf993ee7b8 --- /dev/null +++ b/src/common/Asio/DeadlineTimer.h @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef DeadlineTimer_h__ +#define DeadlineTimer_h__ + +#include <boost/asio/deadline_timer.hpp> + +#if BOOST_VERSION >= 107000 +#define BasicDeadlineTimerThirdTemplateArg , boost::asio::io_context::executor_type +#elif BOOST_VERSION >= 106600 +#define BasicDeadlineTimerThirdTemplateArg +#else +#define BasicDeadlineTimerThirdTemplateArg , boost::asio::deadline_timer_service<boost::posix_time::ptime, boost::asio::time_traits<boost::posix_time::ptime>> +#endif + +#define DeadlineTimerBase boost::asio::basic_deadline_timer<boost::posix_time::ptime, boost::asio::time_traits<boost::posix_time::ptime> BasicDeadlineTimerThirdTemplateArg> + +namespace acore::Asio +{ + class DeadlineTimer : public DeadlineTimerBase + { + public: + using DeadlineTimerBase::basic_deadline_timer; + }; +} + +#endif // DeadlineTimer_h__ diff --git a/src/common/Asio/IoContext.h b/src/common/Asio/IoContext.h new file mode 100644 index 0000000000..e08ef97a51 --- /dev/null +++ b/src/common/Asio/IoContext.h @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef IoContext_h__ +#define IoContext_h__ + +#include <boost/version.hpp> + +#if BOOST_VERSION >= 106600 +#include <boost/asio/io_context.hpp> +#include <boost/asio/post.hpp> +#define IoContextBaseNamespace boost::asio +#define IoContextBase io_context +#else +#include <boost/asio/io_service.hpp> +#define IoContextBaseNamespace boost::asio +#define IoContextBase io_service +#endif + +namespace acore::Asio +{ + class IoContext + { + public: + IoContext() : _impl() { } + explicit IoContext(int concurrency_hint) : _impl(concurrency_hint) { } + + operator IoContextBaseNamespace::IoContextBase&() { return _impl; } + operator IoContextBaseNamespace::IoContextBase const&() const { return _impl; } + + std::size_t run() { return _impl.run(); } + void stop() { _impl.stop(); } + +#if BOOST_VERSION >= 106600 + boost::asio::io_context::executor_type get_executor() noexcept { return _impl.get_executor(); } +#endif + + private: + IoContextBaseNamespace::IoContextBase _impl; + }; + + template<typename T> + inline decltype(auto) post(IoContextBaseNamespace::IoContextBase& ioContext, T&& t) + { +#if BOOST_VERSION >= 106600 + return boost::asio::post(ioContext, std::forward<T>(t)); +#else + return ioContext.post(std::forward<T>(t)); +#endif + } + + template<typename T> + inline decltype(auto) get_io_context(T&& ioObject) + { +#if BOOST_VERSION >= 106600 + return ioObject.get_executor().context(); +#else + return ioObject.get_io_service(); +#endif + } +} + +#endif // IoContext_h__ diff --git a/src/common/Asio/IpAddress.h b/src/common/Asio/IpAddress.h new file mode 100644 index 0000000000..3545b48c66 --- /dev/null +++ b/src/common/Asio/IpAddress.h @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef IpAddress_h__ +#define IpAddress_h__ + +#include "Define.h" +#include <boost/asio/ip/address.hpp> + +namespace acore::Net +{ +#if BOOST_VERSION >= 106600 + using boost::asio::ip::make_address; + using boost::asio::ip::make_address_v4; + inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_uint(); } +#else + inline boost::asio::ip::address make_address(char const* str) { return boost::asio::ip::address::from_string(str); } + inline boost::asio::ip::address make_address(char const* str, boost::system::error_code& ec) { return boost::asio::ip::address::from_string(str, ec); } + inline boost::asio::ip::address make_address(std::string const& str) { return boost::asio::ip::address::from_string(str); } + inline boost::asio::ip::address make_address(std::string const& str, boost::system::error_code& ec) { return boost::asio::ip::address::from_string(str, ec); } + inline boost::asio::ip::address_v4 make_address_v4(char const* str) { return boost::asio::ip::address_v4::from_string(str); } + inline boost::asio::ip::address_v4 make_address_v4(char const* str, boost::system::error_code& ec) { return boost::asio::ip::address_v4::from_string(str, ec); } + inline boost::asio::ip::address_v4 make_address_v4(std::string const& str) { return boost::asio::ip::address_v4::from_string(str); } + inline boost::asio::ip::address_v4 make_address_v4(std::string const& str, boost::system::error_code& ec) { return boost::asio::ip::address_v4::from_string(str, ec); } + inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_ulong(); } +#endif +} + +#endif // IpAddress_h__ diff --git a/src/common/Asio/IpNetwork.h b/src/common/Asio/IpNetwork.h new file mode 100644 index 0000000000..6c1ed9eacb --- /dev/null +++ b/src/common/Asio/IpNetwork.h @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef IpNetwork_h__ +#define IpNetwork_h__ + +#include "Define.h" +#include "IpAddress.h" +#include <boost/version.hpp> + +#if BOOST_VERSION >= 106600 +#include <boost/asio/ip/network_v4.hpp> +#include <boost/asio/ip/network_v6.hpp> +#endif + +namespace acore::Net +{ + inline bool IsInNetwork(boost::asio::ip::address_v4 const& networkAddress, boost::asio::ip::address_v4 const& mask, boost::asio::ip::address_v4 const& clientAddress) + { +#if BOOST_VERSION >= 106600 + boost::asio::ip::network_v4 network = boost::asio::ip::make_network_v4(networkAddress, mask); + boost::asio::ip::address_v4_range hosts = network.hosts(); + return hosts.find(clientAddress) != hosts.end(); +#else + return (clientAddress.to_ulong() & mask.to_ulong()) == (networkAddress.to_ulong() & mask.to_ulong()); +#endif + } + + inline boost::asio::ip::address_v4 GetDefaultNetmaskV4(boost::asio::ip::address_v4 const& networkAddress) + { + if ((address_to_uint(networkAddress) & 0x80000000) == 0) + return boost::asio::ip::address_v4(0xFF000000); + if ((address_to_uint(networkAddress) & 0xC0000000) == 0x80000000) + return boost::asio::ip::address_v4(0xFFFF0000); + if ((address_to_uint(networkAddress) & 0xE0000000) == 0xC0000000) + return boost::asio::ip::address_v4(0xFFFFFF00); + return boost::asio::ip::address_v4(0xFFFFFFFF); + } + + inline bool IsInNetwork(boost::asio::ip::address_v6 const& networkAddress, uint16 prefixLength, boost::asio::ip::address_v6 const& clientAddress) + { +#if BOOST_VERSION >= 106600 + boost::asio::ip::network_v6 network = boost::asio::ip::make_network_v6(networkAddress, prefixLength); + boost::asio::ip::address_v6_range hosts = network.hosts(); + return hosts.find(clientAddress) != hosts.end(); +#else + (void)networkAddress; + (void)prefixLength; + (void)clientAddress; + return false; +#endif + } +} + +#endif // IpNetwork_h__ diff --git a/src/common/Asio/Resolver.h b/src/common/Asio/Resolver.h new file mode 100644 index 0000000000..459e6c7503 --- /dev/null +++ b/src/common/Asio/Resolver.h @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef Resolver_h__ +#define Resolver_h__ + +#include "IoContext.h" +#include "Optional.h" +#include <boost/asio/ip/tcp.hpp> +#include <string> + +namespace acore::Asio +{ + /** + Hack to make it possible to forward declare resolver (one of its template arguments is a typedef to something super long and using nested classes) + */ + class Resolver + { + public: + explicit Resolver(IoContext& ioContext) : _impl(ioContext) { } + + Optional<boost::asio::ip::tcp::endpoint> Resolve(boost::asio::ip::tcp const& protocol, std::string const& host, std::string const& service) + { + boost::system::error_code ec; +#if BOOST_VERSION >= 106600 + boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching; + boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(protocol, host, service, flagsResolver, ec); + if (results.begin() == results.end() || ec) + return {}; + + return results.begin()->endpoint(); +#else + boost::asio::ip::resolver_query_base::flags flagsQuery = boost::asio::ip::tcp::resolver::query::all_matching; + boost::asio::ip::tcp::resolver::query query(std::move(protocol), std::move(host), std::move(service), flagsQuery); + boost::asio::ip::tcp::resolver::iterator itr = _impl.resolve(query, ec); + boost::asio::ip::tcp::resolver::iterator end; + if (itr == end || ec) + return {}; + + return itr->endpoint(); +#endif + } + + private: + boost::asio::ip::tcp::resolver _impl; + }; +} + +#endif // Resolver_h__ diff --git a/src/common/Asio/Strand.h b/src/common/Asio/Strand.h new file mode 100644 index 0000000000..351492d245 --- /dev/null +++ b/src/common/Asio/Strand.h @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef Strand_h__ +#define Strand_h__ + +#include "IoContext.h" +#include <boost/asio/strand.hpp> + +#if BOOST_VERSION >= 106600 +#include <boost/asio/bind_executor.hpp> +#endif + +namespace acore::Asio +{ + /** + Hack to make it possible to forward declare strand (which is a inner class) + */ + class Strand : public IoContextBaseNamespace::IoContextBase::strand + { + public: + Strand(IoContext& ioContext) : IoContextBaseNamespace::IoContextBase::strand(ioContext) { } + }; + +#if BOOST_VERSION >= 106600 + using boost::asio::bind_executor; +#else + template<typename T> + inline decltype(auto) bind_executor(Strand& strand, T&& t) + { + return strand.wrap(std::forward<T>(t)); + } +#endif +} + +#endif // Strand_h__ diff --git a/src/server/shared/Network/AsyncAcceptor.h b/src/server/shared/Network/AsyncAcceptor.h new file mode 100644 index 0000000000..c618fb9e9e --- /dev/null +++ b/src/server/shared/Network/AsyncAcceptor.h @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef __ASYNCACCEPT_H_ +#define __ASYNCACCEPT_H_ + +#include "IoContext.h" +#include "IpAddress.h" +#include "Log.h" +#include <boost/asio/ip/tcp.hpp> +#include <functional> +#include <atomic> + +using boost::asio::ip::tcp; + +#if BOOST_VERSION >= 106600 +#define WARHEAD_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections +#else +#define WARHEAD_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_connections +#endif + +class AsyncAcceptor +{ +public: + typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex); + + AsyncAcceptor(acore::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : + _acceptor(ioContext), _endpoint(acore::Net::make_address(bindIp), port), + _socket(ioContext), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this)) + { + } + + template<class T> + void AsyncAccept(); + + template<AcceptCallback acceptCallback> + void AsyncAcceptWithCallback() + { + tcp::socket* socket; + uint32 threadIndex; + std::tie(socket, threadIndex) = _socketFactory(); + _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error) + { + if (!error) + { + try + { + socket->non_blocking(true); + + acceptCallback(std::move(*socket), threadIndex); + } + catch (boost::system::system_error const& err) + { + LOG_INFO("network", "Failed to initialize client's socket %s", err.what()); + } + } + + if (!_closed) + this->AsyncAcceptWithCallback<acceptCallback>(); + }); + } + + bool Bind() + { + boost::system::error_code errorCode; + _acceptor.open(_endpoint.protocol(), errorCode); + if (errorCode) + { + LOG_INFO("network", "Failed to open acceptor %s", errorCode.message().c_str()); + return false; + } + +#if WARHEAD_PLATFORM != WARHEAD_PLATFORM_WINDOWS + _acceptor.set_option(boost::asio::ip::tcp::acceptor::reuse_address(true), errorCode); + if (errorCode) + { + LOG_INFO("network", "Failed to set reuse_address option on acceptor %s", errorCode.message().c_str()); + return false; + } +#endif + + _acceptor.bind(_endpoint, errorCode); + if (errorCode) + { + LOG_INFO("network", "Could not bind to %s:%u %s", _endpoint.address().to_string().c_str(), _endpoint.port(), errorCode.message().c_str()); + return false; + } + + _acceptor.listen(WARHEAD_MAX_LISTEN_CONNECTIONS, errorCode); + if (errorCode) + { + LOG_INFO("network", "Failed to start listening on %s:%u %s", _endpoint.address().to_string().c_str(), _endpoint.port(), errorCode.message().c_str()); + return false; + } + + return true; + } + + void Close() + { + if (_closed.exchange(true)) + return; + + boost::system::error_code err; + _acceptor.close(err); + } + + void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; } + +private: + std::pair<tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } + + tcp::acceptor _acceptor; + tcp::endpoint _endpoint; + tcp::socket _socket; + std::atomic<bool> _closed; + std::function<std::pair<tcp::socket*, uint32>()> _socketFactory; +}; + +template<class T> +void AsyncAcceptor::AsyncAccept() +{ + _acceptor.async_accept(_socket, [this](boost::system::error_code error) + { + if (!error) + { + try + { + // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class + std::make_shared<T>(std::move(this->_socket))->Start(); + } + catch (boost::system::system_error const& err) + { + LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what()); + } + } + + // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face + if (!_closed) + this->AsyncAccept<T>(); + }); +} + +#endif /* __ASYNCACCEPT_H_ */ diff --git a/src/server/shared/Network/NetworkThread.h b/src/server/shared/Network/NetworkThread.h new file mode 100644 index 0000000000..48c6b9f163 --- /dev/null +++ b/src/server/shared/Network/NetworkThread.h @@ -0,0 +1,166 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef NetworkThread_h__ +#define NetworkThread_h__ + +#include "Define.h" +#include "DeadlineTimer.h" +#include "Errors.h" +#include "IoContext.h" +#include "Log.h" +#include "Timer.h" +#include <boost/asio/ip/tcp.hpp> +#include <atomic> +#include <chrono> +#include <memory> +#include <mutex> +#include <set> +#include <thread> + +using boost::asio::ip::tcp; + +template<class SocketType> +class NetworkThread +{ +public: + NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1), + _acceptSocket(_ioContext), _updateTimer(_ioContext) { } + + virtual ~NetworkThread() + { + Stop(); + + if (_thread) + { + Wait(); + delete _thread; + } + } + + void Stop() + { + _stopped = true; + _ioContext.stop(); + } + + bool Start() + { + if (_thread) + return false; + + _thread = new std::thread(&NetworkThread::Run, this); + return true; + } + + void Wait() + { + ASSERT(_thread); + + _thread->join(); + delete _thread; + _thread = nullptr; + } + + int32 GetConnectionCount() const + { + return _connections; + } + + virtual void AddSocket(std::shared_ptr<SocketType> sock) + { + std::lock_guard<std::mutex> lock(_newSocketsLock); + + ++_connections; + _newSockets.push_back(sock); + SocketAdded(sock); + } + + tcp::socket* GetSocketForAccept() { return &_acceptSocket; } + +protected: + virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { } + virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { } + + void AddNewSockets() + { + std::lock_guard<std::mutex> lock(_newSocketsLock); + + if (_newSockets.empty()) + return; + + for (std::shared_ptr<SocketType> sock : _newSockets) + { + if (!sock->IsOpen()) + { + SocketRemoved(sock); + --_connections; + } + else + _sockets.push_back(sock); + } + + _newSockets.clear(); + } + + void Run() + { + LOG_DEBUG("misc", "Network Thread Starting"); + + _updateTimer.expires_from_now(boost::posix_time::milliseconds(10)); + _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this)); + _ioContext.run(); + + LOG_DEBUG("misc", "Network Thread exits"); + _newSockets.clear(); + _sockets.clear(); + } + + void Update() + { + if (_stopped) + return; + + _updateTimer.expires_from_now(boost::posix_time::milliseconds(10)); + _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this)); + + AddNewSockets(); + + _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock) + { + if (!sock->Update()) + { + if (sock->IsOpen()) + sock->CloseSocket(); + + this->SocketRemoved(sock); + + --this->_connections; + return true; + } + + return false; + }), _sockets.end()); + } + +private: + typedef std::vector<std::shared_ptr<SocketType>> SocketContainer; + + std::atomic<int32> _connections; + std::atomic<bool> _stopped; + + std::thread* _thread; + + SocketContainer _sockets; + + std::mutex _newSocketsLock; + SocketContainer _newSockets; + + acore::Asio::IoContext _ioContext; + tcp::socket _acceptSocket; + acore::Asio::DeadlineTimer _updateTimer; +}; + +#endif // NetworkThread_h__ diff --git a/src/server/shared/Network/Socket.h b/src/server/shared/Network/Socket.h new file mode 100644 index 0000000000..c3c50a18e2 --- /dev/null +++ b/src/server/shared/Network/Socket.h @@ -0,0 +1,276 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef __SOCKET_H__ +#define __SOCKET_H__ + +#include "MessageBuffer.h" +#include "Log.h" +#include <atomic> +#include <queue> +#include <memory> +#include <functional> +#include <type_traits> +#include <boost/asio/ip/tcp.hpp> + +using boost::asio::ip::tcp; + +#define READ_BLOCK_SIZE 4096 +#ifdef BOOST_ASIO_HAS_IOCP +#define AC_SOCKET_USE_IOCP +#endif + +template<class T> +class Socket : public std::enable_shared_from_this<T> +{ +public: + explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), + _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false) + { + _readBuffer.Resize(READ_BLOCK_SIZE); + } + + virtual ~Socket() + { + _closed = true; + boost::system::error_code error; + _socket.close(error); + } + + virtual void Start() = 0; + + virtual bool Update() + { + if (_closed) + { + return false; + } + +#ifndef AC_SOCKET_USE_IOCP + if (_isWritingAsync || (_writeQueue.empty() && !_closing)) + { + return true; + } + + for (; HandleQueue();) + ; +#endif + + return true; + } + + boost::asio::ip::address GetRemoteIpAddress() const + { + return _remoteAddress; + } + + uint16 GetRemotePort() const + { + return _remotePort; + } + + void AsyncRead() + { + if (!IsOpen()) + { + return; + } + + _readBuffer.Normalize(); + _readBuffer.EnsureFreeSpace(); + + _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), + std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + } + + void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t)) + { + if (!IsOpen()) + { + return; + } + + _readBuffer.Normalize(); + _readBuffer.EnsureFreeSpace(); + + _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), + std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + } + + void QueuePacket(MessageBuffer&& buffer) + { + _writeQueue.push(std::move(buffer)); + +#ifdef AC_SOCKET_USE_IOCP + AsyncProcessQueue(); +#endif + } + + bool IsOpen() const { return !_closed && !_closing; } + + void CloseSocket() + { + if (_closed.exchange(true)) + return; + + boost::system::error_code shutdownError; + _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError); + + if (shutdownError) + LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(), + shutdownError.value(), shutdownError.message().c_str()); + + OnClose(); + } + + /// Marks the socket for closing after write buffer becomes empty + void DelayedCloseSocket() { _closing = true; } + + MessageBuffer& GetReadBuffer() { return _readBuffer; } + +protected: + virtual void OnClose() { } + virtual void ReadHandler() = 0; + + bool AsyncProcessQueue() + { + if (_isWritingAsync) + return false; + + _isWritingAsync = true; + +#ifdef AC_SOCKET_USE_IOCP + MessageBuffer& buffer = _writeQueue.front(); + _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler, + this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); +#else + _socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket<T>::WriteHandlerWrapper, + this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); +#endif + return false; + } + + void SetNoDelay(bool enable) + { + boost::system::error_code err; + _socket.set_option(tcp::no_delay(enable), err); + + if (err) + LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for %s - %d (%s)", + GetRemoteIpAddress().to_string().c_str(), err.value(), err.message().c_str()); + } + +private: + void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes) + { + if (error) + { + CloseSocket(); + return; + } + + _readBuffer.WriteCompleted(transferredBytes); + ReadHandler(); + } + +#ifdef AC_SOCKET_USE_IOCP + void WriteHandler(boost::system::error_code error, std::size_t transferedBytes) + { + if (!error) + { + _isWritingAsync = false; + _writeQueue.front().ReadCompleted(transferedBytes); + + if (!_writeQueue.front().GetActiveSize()) + _writeQueue.pop(); + + if (!_writeQueue.empty()) + AsyncProcessQueue(); + else if (_closing) + CloseSocket(); + } + else + CloseSocket(); + } + +#else + + void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/) + { + _isWritingAsync = false; + HandleQueue(); + } + + bool HandleQueue() + { + if (_writeQueue.empty()) + return false; + + MessageBuffer& queuedMessage = _writeQueue.front(); + + std::size_t bytesToSend = queuedMessage.GetActiveSize(); + + boost::system::error_code error; + std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error); + + if (error) + { + if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) + { + return AsyncProcessQueue(); + } + + _writeQueue.pop(); + + if (_closing && _writeQueue.empty()) + { + CloseSocket(); + } + + return false; + } + else if (bytesSent == 0) + { + _writeQueue.pop(); + + if (_closing && _writeQueue.empty()) + { + CloseSocket(); + } + + return false; + } + else if (bytesSent < bytesToSend) // now n > 0 + { + queuedMessage.ReadCompleted(bytesSent); + return AsyncProcessQueue(); + } + + _writeQueue.pop(); + + if (_closing && _writeQueue.empty()) + { + CloseSocket(); + } + + return !_writeQueue.empty(); + } +#endif + + tcp::socket _socket; + + boost::asio::ip::address _remoteAddress; + uint16 _remotePort; + + MessageBuffer _readBuffer; + std::queue<MessageBuffer> _writeQueue; + + std::atomic<bool> _closed; + std::atomic<bool> _closing; + + bool _isWritingAsync; +}; + +#endif // __SOCKET_H__ diff --git a/src/server/shared/Network/SocketMgr.h b/src/server/shared/Network/SocketMgr.h new file mode 100644 index 0000000000..11b1a2f72c --- /dev/null +++ b/src/server/shared/Network/SocketMgr.h @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3 + * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore> + */ + +#ifndef SocketMgr_h__ +#define SocketMgr_h__ + +#include "AsyncAcceptor.h" +#include "Errors.h" +#include "NetworkThread.h" +#include <boost/asio/ip/tcp.hpp> +#include <memory> + +using boost::asio::ip::tcp; + +template<class SocketType> +class SocketMgr +{ +public: + virtual ~SocketMgr() + { + ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction"); + } + + virtual bool StartNetwork(acore::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) + { + ASSERT(threadCount > 0); + + AsyncAcceptor* acceptor = nullptr; + try + { + acceptor = new AsyncAcceptor(ioContext, bindIp, port); + } + catch (boost::system::system_error const& err) + { + LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork (%s:%u): %s", bindIp.c_str(), port, err.what()); + return false; + } + + if (!acceptor->Bind()) + { + LOG_ERROR("network", "StartNetwork failed to bind socket acceptor"); + delete acceptor; + return false; + } + + _acceptor = acceptor; + _threadCount = threadCount; + _threads = CreateThreads(); + + ASSERT(_threads); + + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Start(); + + _acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); }); + + return true; + } + + virtual void StopNetwork() + { + _acceptor->Close(); + + if (_threadCount != 0) + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Stop(); + + Wait(); + + delete _acceptor; + _acceptor = nullptr; + delete[] _threads; + _threads = nullptr; + _threadCount = 0; + } + + void Wait() + { + if (_threadCount != 0) + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Wait(); + } + + virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) + { + try + { + std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock)); + newSocket->Start(); + + _threads[threadIndex].AddSocket(newSocket); + } + catch (boost::system::system_error const& err) + { + LOG_WARN("network", "Failed to retrieve client's remote address %s", err.what()); + } + } + + int32 GetNetworkThreadCount() const { return _threadCount; } + + uint32 SelectThreadWithMinConnections() const + { + uint32 min = 0; + + for (int32 i = 1; i < _threadCount; ++i) + if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount()) + min = i; + + return min; + } + + std::pair<tcp::socket*, uint32> GetSocketForAccept() + { + uint32 threadIndex = SelectThreadWithMinConnections(); + return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex); + } + +protected: + SocketMgr() : + _acceptor(nullptr), _threads(nullptr), _threadCount(0) { } + + virtual NetworkThread<SocketType>* CreateThreads() const = 0; + + AsyncAcceptor* _acceptor; + NetworkThread<SocketType>* _threads; + int32 _threadCount; +}; + +#endif // SocketMgr_h__ |