summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKargatum <dowlandtop@yandex.com>2021-05-27 21:09:31 +0700
committerGitHub <noreply@github.com>2021-05-27 16:09:31 +0200
commitc1e96064e9d8d1c7bcfe870fc5a47afbbd79e24c (patch)
tree485dd9b48c53c939a6e076d2b4ce2222b26a4f52
parent2ae84e2faf6b3884ce949299c1914e2d3e438a9d (diff)
feat(Core/Common): add Asio network threading (#6063)
-rw-r--r--src/common/Asio/AsioHacksFwd.h44
-rw-r--r--src/common/Asio/DeadlineTimer.h30
-rw-r--r--src/common/Asio/IoContext.h65
-rw-r--r--src/common/Asio/IpAddress.h31
-rw-r--r--src/common/Asio/IpNetwork.h57
-rw-r--r--src/common/Asio/Resolver.h51
-rw-r--r--src/common/Asio/Strand.h38
-rw-r--r--src/server/shared/Network/AsyncAcceptor.h146
-rw-r--r--src/server/shared/Network/NetworkThread.h166
-rw-r--r--src/server/shared/Network/Socket.h276
-rw-r--r--src/server/shared/Network/SocketMgr.h131
11 files changed, 1035 insertions, 0 deletions
diff --git a/src/common/Asio/AsioHacksFwd.h b/src/common/Asio/AsioHacksFwd.h
new file mode 100644
index 0000000000..05b7960386
--- /dev/null
+++ b/src/common/Asio/AsioHacksFwd.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef AsioHacksFwd_h__
+#define AsioHacksFwd_h__
+
+#include <boost/version.hpp>
+
+/**
+ Collection of forward declarations to improve compile time
+ */
+namespace boost::posix_time
+{
+ class ptime;
+}
+
+namespace boost::asio
+{
+ template <typename Time>
+ struct time_traits;
+}
+
+namespace boost::asio::ip
+{
+ class address;
+ class tcp;
+
+ template <typename InternetProtocol>
+ class basic_endpoint;
+
+ typedef basic_endpoint<tcp> tcp_endpoint;
+}
+
+namespace acore::Asio
+{
+ class DeadlineTimer;
+ class IoContext;
+ class Resolver;
+ class Strand;
+}
+
+#endif // AsioHacksFwd_h__
diff --git a/src/common/Asio/DeadlineTimer.h b/src/common/Asio/DeadlineTimer.h
new file mode 100644
index 0000000000..cf993ee7b8
--- /dev/null
+++ b/src/common/Asio/DeadlineTimer.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef DeadlineTimer_h__
+#define DeadlineTimer_h__
+
+#include <boost/asio/deadline_timer.hpp>
+
+#if BOOST_VERSION >= 107000
+#define BasicDeadlineTimerThirdTemplateArg , boost::asio::io_context::executor_type
+#elif BOOST_VERSION >= 106600
+#define BasicDeadlineTimerThirdTemplateArg
+#else
+#define BasicDeadlineTimerThirdTemplateArg , boost::asio::deadline_timer_service<boost::posix_time::ptime, boost::asio::time_traits<boost::posix_time::ptime>>
+#endif
+
+#define DeadlineTimerBase boost::asio::basic_deadline_timer<boost::posix_time::ptime, boost::asio::time_traits<boost::posix_time::ptime> BasicDeadlineTimerThirdTemplateArg>
+
+namespace acore::Asio
+{
+ class DeadlineTimer : public DeadlineTimerBase
+ {
+ public:
+ using DeadlineTimerBase::basic_deadline_timer;
+ };
+}
+
+#endif // DeadlineTimer_h__
diff --git a/src/common/Asio/IoContext.h b/src/common/Asio/IoContext.h
new file mode 100644
index 0000000000..e08ef97a51
--- /dev/null
+++ b/src/common/Asio/IoContext.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef IoContext_h__
+#define IoContext_h__
+
+#include <boost/version.hpp>
+
+#if BOOST_VERSION >= 106600
+#include <boost/asio/io_context.hpp>
+#include <boost/asio/post.hpp>
+#define IoContextBaseNamespace boost::asio
+#define IoContextBase io_context
+#else
+#include <boost/asio/io_service.hpp>
+#define IoContextBaseNamespace boost::asio
+#define IoContextBase io_service
+#endif
+
+namespace acore::Asio
+{
+ class IoContext
+ {
+ public:
+ IoContext() : _impl() { }
+ explicit IoContext(int concurrency_hint) : _impl(concurrency_hint) { }
+
+ operator IoContextBaseNamespace::IoContextBase&() { return _impl; }
+ operator IoContextBaseNamespace::IoContextBase const&() const { return _impl; }
+
+ std::size_t run() { return _impl.run(); }
+ void stop() { _impl.stop(); }
+
+#if BOOST_VERSION >= 106600
+ boost::asio::io_context::executor_type get_executor() noexcept { return _impl.get_executor(); }
+#endif
+
+ private:
+ IoContextBaseNamespace::IoContextBase _impl;
+ };
+
+ template<typename T>
+ inline decltype(auto) post(IoContextBaseNamespace::IoContextBase& ioContext, T&& t)
+ {
+#if BOOST_VERSION >= 106600
+ return boost::asio::post(ioContext, std::forward<T>(t));
+#else
+ return ioContext.post(std::forward<T>(t));
+#endif
+ }
+
+ template<typename T>
+ inline decltype(auto) get_io_context(T&& ioObject)
+ {
+#if BOOST_VERSION >= 106600
+ return ioObject.get_executor().context();
+#else
+ return ioObject.get_io_service();
+#endif
+ }
+}
+
+#endif // IoContext_h__
diff --git a/src/common/Asio/IpAddress.h b/src/common/Asio/IpAddress.h
new file mode 100644
index 0000000000..3545b48c66
--- /dev/null
+++ b/src/common/Asio/IpAddress.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef IpAddress_h__
+#define IpAddress_h__
+
+#include "Define.h"
+#include <boost/asio/ip/address.hpp>
+
+namespace acore::Net
+{
+#if BOOST_VERSION >= 106600
+ using boost::asio::ip::make_address;
+ using boost::asio::ip::make_address_v4;
+ inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_uint(); }
+#else
+ inline boost::asio::ip::address make_address(char const* str) { return boost::asio::ip::address::from_string(str); }
+ inline boost::asio::ip::address make_address(char const* str, boost::system::error_code& ec) { return boost::asio::ip::address::from_string(str, ec); }
+ inline boost::asio::ip::address make_address(std::string const& str) { return boost::asio::ip::address::from_string(str); }
+ inline boost::asio::ip::address make_address(std::string const& str, boost::system::error_code& ec) { return boost::asio::ip::address::from_string(str, ec); }
+ inline boost::asio::ip::address_v4 make_address_v4(char const* str) { return boost::asio::ip::address_v4::from_string(str); }
+ inline boost::asio::ip::address_v4 make_address_v4(char const* str, boost::system::error_code& ec) { return boost::asio::ip::address_v4::from_string(str, ec); }
+ inline boost::asio::ip::address_v4 make_address_v4(std::string const& str) { return boost::asio::ip::address_v4::from_string(str); }
+ inline boost::asio::ip::address_v4 make_address_v4(std::string const& str, boost::system::error_code& ec) { return boost::asio::ip::address_v4::from_string(str, ec); }
+ inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_ulong(); }
+#endif
+}
+
+#endif // IpAddress_h__
diff --git a/src/common/Asio/IpNetwork.h b/src/common/Asio/IpNetwork.h
new file mode 100644
index 0000000000..6c1ed9eacb
--- /dev/null
+++ b/src/common/Asio/IpNetwork.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef IpNetwork_h__
+#define IpNetwork_h__
+
+#include "Define.h"
+#include "IpAddress.h"
+#include <boost/version.hpp>
+
+#if BOOST_VERSION >= 106600
+#include <boost/asio/ip/network_v4.hpp>
+#include <boost/asio/ip/network_v6.hpp>
+#endif
+
+namespace acore::Net
+{
+ inline bool IsInNetwork(boost::asio::ip::address_v4 const& networkAddress, boost::asio::ip::address_v4 const& mask, boost::asio::ip::address_v4 const& clientAddress)
+ {
+#if BOOST_VERSION >= 106600
+ boost::asio::ip::network_v4 network = boost::asio::ip::make_network_v4(networkAddress, mask);
+ boost::asio::ip::address_v4_range hosts = network.hosts();
+ return hosts.find(clientAddress) != hosts.end();
+#else
+ return (clientAddress.to_ulong() & mask.to_ulong()) == (networkAddress.to_ulong() & mask.to_ulong());
+#endif
+ }
+
+ inline boost::asio::ip::address_v4 GetDefaultNetmaskV4(boost::asio::ip::address_v4 const& networkAddress)
+ {
+ if ((address_to_uint(networkAddress) & 0x80000000) == 0)
+ return boost::asio::ip::address_v4(0xFF000000);
+ if ((address_to_uint(networkAddress) & 0xC0000000) == 0x80000000)
+ return boost::asio::ip::address_v4(0xFFFF0000);
+ if ((address_to_uint(networkAddress) & 0xE0000000) == 0xC0000000)
+ return boost::asio::ip::address_v4(0xFFFFFF00);
+ return boost::asio::ip::address_v4(0xFFFFFFFF);
+ }
+
+ inline bool IsInNetwork(boost::asio::ip::address_v6 const& networkAddress, uint16 prefixLength, boost::asio::ip::address_v6 const& clientAddress)
+ {
+#if BOOST_VERSION >= 106600
+ boost::asio::ip::network_v6 network = boost::asio::ip::make_network_v6(networkAddress, prefixLength);
+ boost::asio::ip::address_v6_range hosts = network.hosts();
+ return hosts.find(clientAddress) != hosts.end();
+#else
+ (void)networkAddress;
+ (void)prefixLength;
+ (void)clientAddress;
+ return false;
+#endif
+ }
+}
+
+#endif // IpNetwork_h__
diff --git a/src/common/Asio/Resolver.h b/src/common/Asio/Resolver.h
new file mode 100644
index 0000000000..459e6c7503
--- /dev/null
+++ b/src/common/Asio/Resolver.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef Resolver_h__
+#define Resolver_h__
+
+#include "IoContext.h"
+#include "Optional.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <string>
+
+namespace acore::Asio
+{
+ /**
+ Hack to make it possible to forward declare resolver (one of its template arguments is a typedef to something super long and using nested classes)
+ */
+ class Resolver
+ {
+ public:
+ explicit Resolver(IoContext& ioContext) : _impl(ioContext) { }
+
+ Optional<boost::asio::ip::tcp::endpoint> Resolve(boost::asio::ip::tcp const& protocol, std::string const& host, std::string const& service)
+ {
+ boost::system::error_code ec;
+#if BOOST_VERSION >= 106600
+ boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching;
+ boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(protocol, host, service, flagsResolver, ec);
+ if (results.begin() == results.end() || ec)
+ return {};
+
+ return results.begin()->endpoint();
+#else
+ boost::asio::ip::resolver_query_base::flags flagsQuery = boost::asio::ip::tcp::resolver::query::all_matching;
+ boost::asio::ip::tcp::resolver::query query(std::move(protocol), std::move(host), std::move(service), flagsQuery);
+ boost::asio::ip::tcp::resolver::iterator itr = _impl.resolve(query, ec);
+ boost::asio::ip::tcp::resolver::iterator end;
+ if (itr == end || ec)
+ return {};
+
+ return itr->endpoint();
+#endif
+ }
+
+ private:
+ boost::asio::ip::tcp::resolver _impl;
+ };
+}
+
+#endif // Resolver_h__
diff --git a/src/common/Asio/Strand.h b/src/common/Asio/Strand.h
new file mode 100644
index 0000000000..351492d245
--- /dev/null
+++ b/src/common/Asio/Strand.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef Strand_h__
+#define Strand_h__
+
+#include "IoContext.h"
+#include <boost/asio/strand.hpp>
+
+#if BOOST_VERSION >= 106600
+#include <boost/asio/bind_executor.hpp>
+#endif
+
+namespace acore::Asio
+{
+ /**
+ Hack to make it possible to forward declare strand (which is a inner class)
+ */
+ class Strand : public IoContextBaseNamespace::IoContextBase::strand
+ {
+ public:
+ Strand(IoContext& ioContext) : IoContextBaseNamespace::IoContextBase::strand(ioContext) { }
+ };
+
+#if BOOST_VERSION >= 106600
+ using boost::asio::bind_executor;
+#else
+ template<typename T>
+ inline decltype(auto) bind_executor(Strand& strand, T&& t)
+ {
+ return strand.wrap(std::forward<T>(t));
+ }
+#endif
+}
+
+#endif // Strand_h__
diff --git a/src/server/shared/Network/AsyncAcceptor.h b/src/server/shared/Network/AsyncAcceptor.h
new file mode 100644
index 0000000000..c618fb9e9e
--- /dev/null
+++ b/src/server/shared/Network/AsyncAcceptor.h
@@ -0,0 +1,146 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef __ASYNCACCEPT_H_
+#define __ASYNCACCEPT_H_
+
+#include "IoContext.h"
+#include "IpAddress.h"
+#include "Log.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <functional>
+#include <atomic>
+
+using boost::asio::ip::tcp;
+
+#if BOOST_VERSION >= 106600
+#define WARHEAD_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
+#else
+#define WARHEAD_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_connections
+#endif
+
+class AsyncAcceptor
+{
+public:
+ typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex);
+
+ AsyncAcceptor(acore::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
+ _acceptor(ioContext), _endpoint(acore::Net::make_address(bindIp), port),
+ _socket(ioContext), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this))
+ {
+ }
+
+ template<class T>
+ void AsyncAccept();
+
+ template<AcceptCallback acceptCallback>
+ void AsyncAcceptWithCallback()
+ {
+ tcp::socket* socket;
+ uint32 threadIndex;
+ std::tie(socket, threadIndex) = _socketFactory();
+ _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
+ {
+ if (!error)
+ {
+ try
+ {
+ socket->non_blocking(true);
+
+ acceptCallback(std::move(*socket), threadIndex);
+ }
+ catch (boost::system::system_error const& err)
+ {
+ LOG_INFO("network", "Failed to initialize client's socket %s", err.what());
+ }
+ }
+
+ if (!_closed)
+ this->AsyncAcceptWithCallback<acceptCallback>();
+ });
+ }
+
+ bool Bind()
+ {
+ boost::system::error_code errorCode;
+ _acceptor.open(_endpoint.protocol(), errorCode);
+ if (errorCode)
+ {
+ LOG_INFO("network", "Failed to open acceptor %s", errorCode.message().c_str());
+ return false;
+ }
+
+#if WARHEAD_PLATFORM != WARHEAD_PLATFORM_WINDOWS
+ _acceptor.set_option(boost::asio::ip::tcp::acceptor::reuse_address(true), errorCode);
+ if (errorCode)
+ {
+ LOG_INFO("network", "Failed to set reuse_address option on acceptor %s", errorCode.message().c_str());
+ return false;
+ }
+#endif
+
+ _acceptor.bind(_endpoint, errorCode);
+ if (errorCode)
+ {
+ LOG_INFO("network", "Could not bind to %s:%u %s", _endpoint.address().to_string().c_str(), _endpoint.port(), errorCode.message().c_str());
+ return false;
+ }
+
+ _acceptor.listen(WARHEAD_MAX_LISTEN_CONNECTIONS, errorCode);
+ if (errorCode)
+ {
+ LOG_INFO("network", "Failed to start listening on %s:%u %s", _endpoint.address().to_string().c_str(), _endpoint.port(), errorCode.message().c_str());
+ return false;
+ }
+
+ return true;
+ }
+
+ void Close()
+ {
+ if (_closed.exchange(true))
+ return;
+
+ boost::system::error_code err;
+ _acceptor.close(err);
+ }
+
+ void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; }
+
+private:
+ std::pair<tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
+
+ tcp::acceptor _acceptor;
+ tcp::endpoint _endpoint;
+ tcp::socket _socket;
+ std::atomic<bool> _closed;
+ std::function<std::pair<tcp::socket*, uint32>()> _socketFactory;
+};
+
+template<class T>
+void AsyncAcceptor::AsyncAccept()
+{
+ _acceptor.async_accept(_socket, [this](boost::system::error_code error)
+ {
+ if (!error)
+ {
+ try
+ {
+ // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class
+ std::make_shared<T>(std::move(this->_socket))->Start();
+ }
+ catch (boost::system::system_error const& err)
+ {
+ LOG_INFO("network", "Failed to retrieve client's remote address %s", err.what());
+ }
+ }
+
+ // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
+ if (!_closed)
+ this->AsyncAccept<T>();
+ });
+}
+
+#endif /* __ASYNCACCEPT_H_ */
diff --git a/src/server/shared/Network/NetworkThread.h b/src/server/shared/Network/NetworkThread.h
new file mode 100644
index 0000000000..48c6b9f163
--- /dev/null
+++ b/src/server/shared/Network/NetworkThread.h
@@ -0,0 +1,166 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef NetworkThread_h__
+#define NetworkThread_h__
+
+#include "Define.h"
+#include "DeadlineTimer.h"
+#include "Errors.h"
+#include "IoContext.h"
+#include "Log.h"
+#include "Timer.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <atomic>
+#include <chrono>
+#include <memory>
+#include <mutex>
+#include <set>
+#include <thread>
+
+using boost::asio::ip::tcp;
+
+template<class SocketType>
+class NetworkThread
+{
+public:
+ NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1),
+ _acceptSocket(_ioContext), _updateTimer(_ioContext) { }
+
+ virtual ~NetworkThread()
+ {
+ Stop();
+
+ if (_thread)
+ {
+ Wait();
+ delete _thread;
+ }
+ }
+
+ void Stop()
+ {
+ _stopped = true;
+ _ioContext.stop();
+ }
+
+ bool Start()
+ {
+ if (_thread)
+ return false;
+
+ _thread = new std::thread(&NetworkThread::Run, this);
+ return true;
+ }
+
+ void Wait()
+ {
+ ASSERT(_thread);
+
+ _thread->join();
+ delete _thread;
+ _thread = nullptr;
+ }
+
+ int32 GetConnectionCount() const
+ {
+ return _connections;
+ }
+
+ virtual void AddSocket(std::shared_ptr<SocketType> sock)
+ {
+ std::lock_guard<std::mutex> lock(_newSocketsLock);
+
+ ++_connections;
+ _newSockets.push_back(sock);
+ SocketAdded(sock);
+ }
+
+ tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
+
+protected:
+ virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
+ virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
+
+ void AddNewSockets()
+ {
+ std::lock_guard<std::mutex> lock(_newSocketsLock);
+
+ if (_newSockets.empty())
+ return;
+
+ for (std::shared_ptr<SocketType> sock : _newSockets)
+ {
+ if (!sock->IsOpen())
+ {
+ SocketRemoved(sock);
+ --_connections;
+ }
+ else
+ _sockets.push_back(sock);
+ }
+
+ _newSockets.clear();
+ }
+
+ void Run()
+ {
+ LOG_DEBUG("misc", "Network Thread Starting");
+
+ _updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
+ _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
+ _ioContext.run();
+
+ LOG_DEBUG("misc", "Network Thread exits");
+ _newSockets.clear();
+ _sockets.clear();
+ }
+
+ void Update()
+ {
+ if (_stopped)
+ return;
+
+ _updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
+ _updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
+
+ AddNewSockets();
+
+ _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock)
+ {
+ if (!sock->Update())
+ {
+ if (sock->IsOpen())
+ sock->CloseSocket();
+
+ this->SocketRemoved(sock);
+
+ --this->_connections;
+ return true;
+ }
+
+ return false;
+ }), _sockets.end());
+ }
+
+private:
+ typedef std::vector<std::shared_ptr<SocketType>> SocketContainer;
+
+ std::atomic<int32> _connections;
+ std::atomic<bool> _stopped;
+
+ std::thread* _thread;
+
+ SocketContainer _sockets;
+
+ std::mutex _newSocketsLock;
+ SocketContainer _newSockets;
+
+ acore::Asio::IoContext _ioContext;
+ tcp::socket _acceptSocket;
+ acore::Asio::DeadlineTimer _updateTimer;
+};
+
+#endif // NetworkThread_h__
diff --git a/src/server/shared/Network/Socket.h b/src/server/shared/Network/Socket.h
new file mode 100644
index 0000000000..c3c50a18e2
--- /dev/null
+++ b/src/server/shared/Network/Socket.h
@@ -0,0 +1,276 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef __SOCKET_H__
+#define __SOCKET_H__
+
+#include "MessageBuffer.h"
+#include "Log.h"
+#include <atomic>
+#include <queue>
+#include <memory>
+#include <functional>
+#include <type_traits>
+#include <boost/asio/ip/tcp.hpp>
+
+using boost::asio::ip::tcp;
+
+#define READ_BLOCK_SIZE 4096
+#ifdef BOOST_ASIO_HAS_IOCP
+#define AC_SOCKET_USE_IOCP
+#endif
+
+template<class T>
+class Socket : public std::enable_shared_from_this<T>
+{
+public:
+ explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
+ _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
+ {
+ _readBuffer.Resize(READ_BLOCK_SIZE);
+ }
+
+ virtual ~Socket()
+ {
+ _closed = true;
+ boost::system::error_code error;
+ _socket.close(error);
+ }
+
+ virtual void Start() = 0;
+
+ virtual bool Update()
+ {
+ if (_closed)
+ {
+ return false;
+ }
+
+#ifndef AC_SOCKET_USE_IOCP
+ if (_isWritingAsync || (_writeQueue.empty() && !_closing))
+ {
+ return true;
+ }
+
+ for (; HandleQueue();)
+ ;
+#endif
+
+ return true;
+ }
+
+ boost::asio::ip::address GetRemoteIpAddress() const
+ {
+ return _remoteAddress;
+ }
+
+ uint16 GetRemotePort() const
+ {
+ return _remotePort;
+ }
+
+ void AsyncRead()
+ {
+ if (!IsOpen())
+ {
+ return;
+ }
+
+ _readBuffer.Normalize();
+ _readBuffer.EnsureFreeSpace();
+
+ _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
+ std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ }
+
+ void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
+ {
+ if (!IsOpen())
+ {
+ return;
+ }
+
+ _readBuffer.Normalize();
+ _readBuffer.EnsureFreeSpace();
+
+ _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
+ std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+ }
+
+ void QueuePacket(MessageBuffer&& buffer)
+ {
+ _writeQueue.push(std::move(buffer));
+
+#ifdef AC_SOCKET_USE_IOCP
+ AsyncProcessQueue();
+#endif
+ }
+
+ bool IsOpen() const { return !_closed && !_closing; }
+
+ void CloseSocket()
+ {
+ if (_closed.exchange(true))
+ return;
+
+ boost::system::error_code shutdownError;
+ _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
+
+ if (shutdownError)
+ LOG_DEBUG("network", "Socket::CloseSocket: %s errored when shutting down socket: %i (%s)", GetRemoteIpAddress().to_string().c_str(),
+ shutdownError.value(), shutdownError.message().c_str());
+
+ OnClose();
+ }
+
+ /// Marks the socket for closing after write buffer becomes empty
+ void DelayedCloseSocket() { _closing = true; }
+
+ MessageBuffer& GetReadBuffer() { return _readBuffer; }
+
+protected:
+ virtual void OnClose() { }
+ virtual void ReadHandler() = 0;
+
+ bool AsyncProcessQueue()
+ {
+ if (_isWritingAsync)
+ return false;
+
+ _isWritingAsync = true;
+
+#ifdef AC_SOCKET_USE_IOCP
+ MessageBuffer& buffer = _writeQueue.front();
+ _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler,
+ this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+#else
+ _socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket<T>::WriteHandlerWrapper,
+ this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
+#endif
+ return false;
+ }
+
+ void SetNoDelay(bool enable)
+ {
+ boost::system::error_code err;
+ _socket.set_option(tcp::no_delay(enable), err);
+
+ if (err)
+ LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for %s - %d (%s)",
+ GetRemoteIpAddress().to_string().c_str(), err.value(), err.message().c_str());
+ }
+
+private:
+ void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes)
+ {
+ if (error)
+ {
+ CloseSocket();
+ return;
+ }
+
+ _readBuffer.WriteCompleted(transferredBytes);
+ ReadHandler();
+ }
+
+#ifdef AC_SOCKET_USE_IOCP
+ void WriteHandler(boost::system::error_code error, std::size_t transferedBytes)
+ {
+ if (!error)
+ {
+ _isWritingAsync = false;
+ _writeQueue.front().ReadCompleted(transferedBytes);
+
+ if (!_writeQueue.front().GetActiveSize())
+ _writeQueue.pop();
+
+ if (!_writeQueue.empty())
+ AsyncProcessQueue();
+ else if (_closing)
+ CloseSocket();
+ }
+ else
+ CloseSocket();
+ }
+
+#else
+
+ void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/)
+ {
+ _isWritingAsync = false;
+ HandleQueue();
+ }
+
+ bool HandleQueue()
+ {
+ if (_writeQueue.empty())
+ return false;
+
+ MessageBuffer& queuedMessage = _writeQueue.front();
+
+ std::size_t bytesToSend = queuedMessage.GetActiveSize();
+
+ boost::system::error_code error;
+ std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);
+
+ if (error)
+ {
+ if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
+ {
+ return AsyncProcessQueue();
+ }
+
+ _writeQueue.pop();
+
+ if (_closing && _writeQueue.empty())
+ {
+ CloseSocket();
+ }
+
+ return false;
+ }
+ else if (bytesSent == 0)
+ {
+ _writeQueue.pop();
+
+ if (_closing && _writeQueue.empty())
+ {
+ CloseSocket();
+ }
+
+ return false;
+ }
+ else if (bytesSent < bytesToSend) // now n > 0
+ {
+ queuedMessage.ReadCompleted(bytesSent);
+ return AsyncProcessQueue();
+ }
+
+ _writeQueue.pop();
+
+ if (_closing && _writeQueue.empty())
+ {
+ CloseSocket();
+ }
+
+ return !_writeQueue.empty();
+ }
+#endif
+
+ tcp::socket _socket;
+
+ boost::asio::ip::address _remoteAddress;
+ uint16 _remotePort;
+
+ MessageBuffer _readBuffer;
+ std::queue<MessageBuffer> _writeQueue;
+
+ std::atomic<bool> _closed;
+ std::atomic<bool> _closing;
+
+ bool _isWritingAsync;
+};
+
+#endif // __SOCKET_H__
diff --git a/src/server/shared/Network/SocketMgr.h b/src/server/shared/Network/SocketMgr.h
new file mode 100644
index 0000000000..11b1a2f72c
--- /dev/null
+++ b/src/server/shared/Network/SocketMgr.h
@@ -0,0 +1,131 @@
+/*
+ * Copyright (C) 2016+ AzerothCore <www.azerothcore.org>, released under GNU AGPL v3 license: https://github.com/azerothcore/azerothcore-wotlk/blob/master/LICENSE-AGPL3
+ * Copyright (C) 2021+ WarheadCore <https://github.com/WarheadCore>
+ */
+
+#ifndef SocketMgr_h__
+#define SocketMgr_h__
+
+#include "AsyncAcceptor.h"
+#include "Errors.h"
+#include "NetworkThread.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <memory>
+
+using boost::asio::ip::tcp;
+
+template<class SocketType>
+class SocketMgr
+{
+public:
+ virtual ~SocketMgr()
+ {
+ ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
+ }
+
+ virtual bool StartNetwork(acore::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
+ {
+ ASSERT(threadCount > 0);
+
+ AsyncAcceptor* acceptor = nullptr;
+ try
+ {
+ acceptor = new AsyncAcceptor(ioContext, bindIp, port);
+ }
+ catch (boost::system::system_error const& err)
+ {
+ LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork (%s:%u): %s", bindIp.c_str(), port, err.what());
+ return false;
+ }
+
+ if (!acceptor->Bind())
+ {
+ LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
+ delete acceptor;
+ return false;
+ }
+
+ _acceptor = acceptor;
+ _threadCount = threadCount;
+ _threads = CreateThreads();
+
+ ASSERT(_threads);
+
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Start();
+
+ _acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
+
+ return true;
+ }
+
+ virtual void StopNetwork()
+ {
+ _acceptor->Close();
+
+ if (_threadCount != 0)
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Stop();
+
+ Wait();
+
+ delete _acceptor;
+ _acceptor = nullptr;
+ delete[] _threads;
+ _threads = nullptr;
+ _threadCount = 0;
+ }
+
+ void Wait()
+ {
+ if (_threadCount != 0)
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Wait();
+ }
+
+ virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
+ {
+ try
+ {
+ std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
+ newSocket->Start();
+
+ _threads[threadIndex].AddSocket(newSocket);
+ }
+ catch (boost::system::system_error const& err)
+ {
+ LOG_WARN("network", "Failed to retrieve client's remote address %s", err.what());
+ }
+ }
+
+ int32 GetNetworkThreadCount() const { return _threadCount; }
+
+ uint32 SelectThreadWithMinConnections() const
+ {
+ uint32 min = 0;
+
+ for (int32 i = 1; i < _threadCount; ++i)
+ if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
+ min = i;
+
+ return min;
+ }
+
+ std::pair<tcp::socket*, uint32> GetSocketForAccept()
+ {
+ uint32 threadIndex = SelectThreadWithMinConnections();
+ return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
+ }
+
+protected:
+ SocketMgr() :
+ _acceptor(nullptr), _threads(nullptr), _threadCount(0) { }
+
+ virtual NetworkThread<SocketType>* CreateThreads() const = 0;
+
+ AsyncAcceptor* _acceptor;
+ NetworkThread<SocketType>* _threads;
+ int32 _threadCount;
+};
+
+#endif // SocketMgr_h__