diff --git a/src/common/Asio/AsioHacksFwd.h b/src/common/Asio/AsioHacksFwd.h index 06f5e531c20..6adc28e8fa7 100644 --- a/src/common/Asio/AsioHacksFwd.h +++ b/src/common/Asio/AsioHacksFwd.h @@ -35,6 +35,12 @@ namespace boost template struct time_traits; + template + struct wait_traits; + + template + class basic_socket_iostream; + namespace ip { class address; diff --git a/src/common/Asio/IoContext.h b/src/common/Asio/IoContext.h index 4b469d3baa2..82fcdc03979 100644 --- a/src/common/Asio/IoContext.h +++ b/src/common/Asio/IoContext.h @@ -15,8 +15,8 @@ * with this program. If not, see . */ -#ifndef IoContext_h__ -#define IoContext_h__ +#ifndef TRINITYCORE_IO_CONTEXT_H +#define TRINITYCORE_IO_CONTEXT_H #include #include @@ -56,20 +56,20 @@ namespace Trinity return boost::asio::post(ioContext, std::forward(t)); } - template - inline decltype(auto) post(boost::asio::io_context::executor_type& executor, T&& t) - { - return boost::asio::post(executor, std::forward(t)); - } - using boost::asio::bind_executor; + template + inline decltype(auto) post(boost::asio::io_context::executor_type const& executor, T&& t) + { + return boost::asio::post(executor.context(), bind_executor(executor, std::forward(t))); + } + template inline decltype(auto) get_io_context(T&& ioObject) { - return ioObject.get_executor().context(); + return std::forward(ioObject).get_executor().context(); } } } -#endif // IoContext_h__ +#endif // TRINITYCORE_IO_CONTEXT_H diff --git a/src/common/Threading/ThreadPool.h b/src/common/Threading/ThreadPool.h index c99bdf1af4d..6e2155f0972 100644 --- a/src/common/Threading/ThreadPool.h +++ b/src/common/Threading/ThreadPool.h @@ -27,6 +27,8 @@ namespace Trinity class ThreadPool { public: + using executor_type = boost::asio::thread_pool::executor_type; + explicit ThreadPool(std::size_t numThreads = std::thread::hardware_concurrency()) : _impl(numThreads) { } template diff --git a/src/common/network/AsyncAcceptor.h b/src/common/network/AsyncAcceptor.h index aaaa9410584..cbd8027c9c0 100644 --- a/src/common/network/AsyncAcceptor.h +++ b/src/common/network/AsyncAcceptor.h @@ -18,6 +18,7 @@ #ifndef TRINITYCORE_ASYNC_ACCEPTOR_H #define TRINITYCORE_ASYNC_ACCEPTOR_H +#include "Concepts.h" #include "IoContext.h" #include "IpAddress.h" #include "Log.h" @@ -25,40 +26,40 @@ #include #include #include -#include #define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections namespace Trinity::Net { template -concept AcceptCallback = std::invocable; +concept AcceptCallback = std::invocable; + +template +concept SelectIoContextForNewSocketFn = Trinity::invocable_r; class AsyncAcceptor { public: AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : - _acceptor(ioContext), _endpoint(make_address(bindIp), port), - _socket(ioContext), _closed(false), _socketFactory([this] { return DefaultSocketFactory(); }) + _acceptor(ioContext), _endpoint(make_address(bindIp), port), _closed(false) { } - template - void AsyncAccept(Callback&& acceptCallback) + template + void AsyncAccept(SelectIoContextForNewSocket&& selectIoContext, Callback&& acceptCallback) { - auto [tmpSocket, tmpThreadIndex] = _socketFactory(); - // TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures) - IoContextTcpSocket* socket = tmpSocket; - uint32 threadIndex = tmpThreadIndex; - _acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward(acceptCallback)](boost::system::error_code const& error) mutable + Asio::IoContext* context = selectIoContext(); + _acceptor.async_accept(context->get_executor(), [this, + selectIoContext = std::forward(selectIoContext), + acceptCallback = std::forward(acceptCallback)](boost::system::error_code const& error, IoContextTcpSocket&& socket) mutable { if (!error) { try { - socket->non_blocking(true); + socket.non_blocking(true); - acceptCallback(std::move(*socket), threadIndex); + acceptCallback(std::move(socket)); } catch (boost::system::system_error const& err) { @@ -67,7 +68,7 @@ public: } if (!_closed) - this->AsyncAccept(std::move(acceptCallback)); + this->AsyncAccept(std::move(selectIoContext), std::move(acceptCallback)); }); } @@ -93,7 +94,14 @@ public: // v6_only is enabled on some *BSD distributions by default // we want to allow both v4 and v6 connections to the same listener if (_endpoint.protocol() == boost::asio::ip::tcp::v6()) - _acceptor.set_option(boost::asio::ip::v6_only(false)); + { + _acceptor.set_option(boost::asio::ip::v6_only(false), errorCode); + if (errorCode) + { + TC_LOG_INFO("network", "Could not disable v6_only option {}", errorCode.message()); + return false; + } + } _acceptor.bind(_endpoint, errorCode); if (errorCode) @@ -121,16 +129,10 @@ public: _acceptor.close(err); } - void SetSocketFactory(std::function()> func) { _socketFactory = std::move(func); } - private: - std::pair DefaultSocketFactory() { return std::make_pair(&_socket, 0); } - boost::asio::basic_socket_acceptor _acceptor; boost::asio::ip::tcp::endpoint _endpoint; - IoContextTcpSocket _socket; std::atomic _closed; - std::function()> _socketFactory; }; } diff --git a/src/common/network/Http/HttpService.h b/src/common/network/Http/HttpService.h index bca573f431d..b82f418f6a0 100644 --- a/src/common/network/Http/HttpService.h +++ b/src/common/network/Http/HttpService.h @@ -113,12 +113,12 @@ template concept HttpRequestHandler = invocable_r, RequestContext&>; template -class HttpNetworkThread final : public NetworkThread +class HttpService; + +template +class HttpNetworkThread final : public NetworkThread> { public: - explicit HttpNetworkThread(SessionService* service) : _service(service) { } - -protected: void SocketRemoved(std::shared_ptr const& session) override { if (Optional id = session->GetSessionId()) @@ -126,13 +126,24 @@ protected: } private: + friend class HttpService; SessionService* _service = nullptr; }; template -class HttpService : public SocketMgr>, public DispatcherService, public SessionService +struct HttpServiceTraits { - using BaseSocketMgr = SocketMgr>; + using Self = HttpService; + using SocketType = SessionImpl; + using ThreadType = HttpNetworkThread; +}; + +template +class HttpService : public SocketMgr>, public DispatcherService, public SessionService +{ + using BaseSocketMgr = SocketMgr>; + + friend BaseSocketMgr; public: HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.") @@ -176,11 +187,11 @@ public: } protected: - HttpNetworkThread* CreateThreads() const final + std::unique_ptr[]> CreateThreads() const override { - HttpNetworkThread* threads = static_cast*>(::operator new(sizeof(HttpNetworkThread) * this->GetNetworkThreadCount())); + std::unique_ptr[]> threads = std::make_unique[]>(this->GetNetworkThreadCount()); for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i) - new (&threads[i]) HttpNetworkThread(const_cast(this)); + threads[i]._service = const_cast(this); return threads; } diff --git a/src/common/network/NetworkThread.h b/src/common/network/NetworkThread.h index b810abaa48a..b0c84aa777c 100644 --- a/src/common/network/NetworkThread.h +++ b/src/common/network/NetworkThread.h @@ -24,8 +24,6 @@ #include "Errors.h" #include "IoContext.h" #include "Log.h" -#include "Socket.h" -#include #include #include #include @@ -33,12 +31,12 @@ namespace Trinity::Net { -template +template class NetworkThread { public: NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1), - _acceptSocket(_ioContext), _updateTimer(_ioContext) + _updateTimer(_ioContext) { } @@ -82,15 +80,15 @@ public: return _connections; } - void AddSocket(std::shared_ptr sock) + void AddSocket(std::shared_ptr&& sock) { std::scoped_lock lock(_newSocketsLock); ++_connections; - SocketAdded(_newSockets.emplace_back(std::move(sock))); + static_cast(this)->SocketAdded(_newSockets.emplace_back(std::move(sock))); } - Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; } + Trinity::Asio::IoContext* GetIoContext() { return &_ioContext; } protected: virtual void SocketAdded(std::shared_ptr const& /*sock*/) { } @@ -107,7 +105,7 @@ protected: { if (!sock->IsOpen()) { - SocketRemoved(sock); + static_cast(this)->SocketRemoved(sock); --_connections; } else @@ -147,7 +145,7 @@ protected: if (sock->IsOpen()) sock->CloseSocket(); - this->SocketRemoved(sock); + static_cast(this)->SocketRemoved(sock); --this->_connections; return true; @@ -171,7 +169,6 @@ private: SocketContainer _newSockets; Trinity::Asio::IoContext _ioContext; - Trinity::Net::IoContextTcpSocket _acceptSocket; Trinity::Asio::DeadlineTimer _updateTimer; }; } diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h index 5aca0368b90..3bc5c1a8745 100644 --- a/src/common/network/Socket.h +++ b/src/common/network/Socket.h @@ -128,7 +128,7 @@ class Socket : public std::enable_shared_from_this> public: template explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward(args)...), - _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open) + _remoteEndpoint(_socket.remote_endpoint()), _openState(OpenState_Open) { } @@ -185,18 +185,17 @@ public: boost::asio::ip::address const& GetRemoteIpAddress() const { - return _remoteAddress; + return _remoteEndpoint.Address; } uint16 GetRemotePort() const { - return _remotePort; + return _remoteEndpoint.Port; } void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint) { - _remoteAddress = endpoint.address(); - _remotePort = endpoint.port(); + _remoteEndpoint = endpoint; } template Callback> @@ -227,7 +226,7 @@ public: void CloseSocket() { - if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0) + if (_openState.exchange(OpenState_Closed) == OpenState_Closed) return; boost::system::error_code shutdownError; @@ -242,7 +241,8 @@ public: /// Marks the socket for closing after write buffer becomes empty void DelayedCloseSocket() { - if (_openState.fetch_or(OpenState_Closing) != 0) + uint8 oldState = OpenState_Open; + if (!_openState.compare_exchange_strong(oldState, OpenState_Closing)) return; if (_writeQueue.empty()) @@ -380,8 +380,14 @@ private: Stream _socket; - boost::asio::ip::address _remoteAddress; - uint16 _remotePort = 0; + struct Endpoint + { + Endpoint() : Address(), Port(0) { } + explicit(false) Endpoint(boost::asio::ip::tcp_endpoint const& endpoint) : Address(endpoint.address()), Port(endpoint.port()) { } + + boost::asio::ip::address Address; + uint16 Port; + } _remoteEndpoint; MessageBuffer _readBuffer = MessageBuffer(0x1000); std::queue _writeQueue; diff --git a/src/common/network/SocketMgr.h b/src/common/network/SocketMgr.h index fd17be06811..e7acf24d8f3 100644 --- a/src/common/network/SocketMgr.h +++ b/src/common/network/SocketMgr.h @@ -27,13 +27,14 @@ namespace Trinity::Net { -template +template class SocketMgr { - static_assert(std::is_base_of_v, ThreadType>); - static_assert(std::is_final_v); - public: + using Self = typename Traits::Self; + using SocketType = typename Traits::SocketType; + using ThreadType = typename Traits::ThreadType; + SocketMgr(SocketMgr const&) = delete; SocketMgr(SocketMgr&&) = delete; SocketMgr& operator=(SocketMgr const&) = delete; @@ -67,14 +68,16 @@ public: _acceptor = std::move(acceptor); _threadCount = threadCount; - _threads.reset(CreateThreads()); + _threads = static_cast(this)->CreateThreads(); ASSERT(_threads); for (int32 i = 0; i < _threadCount; ++i) _threads[i].Start(); - _acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); }); + _acceptor->AsyncAccept( + [this]{ return SelectThreadWithMinConnections(); }, + [this](IoContextTcpSocket&& sock) { static_cast(this)->OnSocketOpen(std::move(sock)); }); return true; } @@ -99,14 +102,19 @@ public: _threads[i].Wait(); } - virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex) + virtual void OnSocketOpen(IoContextTcpSocket&& sock) { try { + int32 threadIndex = 0; + for (; threadIndex < _threadCount; ++threadIndex) + if (_threads[threadIndex].GetIoContext()->get_executor() == sock.get_executor()) + break; + std::shared_ptr newSocket = std::make_shared(std::move(sock)); newSocket->Start(); - _threads[threadIndex].AddSocket(newSocket); + _threads[threadIndex].AddSocket(std::move(newSocket)); } catch (boost::system::system_error const& err) { @@ -116,21 +124,15 @@ public: int32 GetNetworkThreadCount() const { return _threadCount; } - uint32 SelectThreadWithMinConnections() const + Asio::IoContext* SelectThreadWithMinConnections() const { - uint32 min = 0; + ThreadType* min = &_threads[0]; - for (int32 i = 1; i < _threadCount; ++i) - if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount()) + for (ThreadType* i = min + 1; i != _threads.get() + _threadCount; ++i) + if (i->GetConnectionCount() < min->GetConnectionCount()) min = i; - return min; - } - - std::pair GetSocketForAccept() - { - uint32 threadIndex = SelectThreadWithMinConnections(); - return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex); + return min->GetIoContext(); } protected: @@ -138,7 +140,10 @@ protected: { } - virtual ThreadType* CreateThreads() const = 0; + virtual std::unique_ptr CreateThreads() const + { + return std::make_unique(GetNetworkThreadCount()); + } std::unique_ptr _acceptor; std::unique_ptr _threads; diff --git a/src/server/bnetserver/Main.cpp b/src/server/bnetserver/Main.cpp index 91e742a2b93..bc9004a6ba0 100644 --- a/src/server/bnetserver/Main.cpp +++ b/src/server/bnetserver/Main.cpp @@ -83,11 +83,13 @@ void ServiceStatusWatcher(std::weak_ptr serviceSta bool StartDB(); void StopDB(); -void SignalHandler(std::weak_ptr ioContextRef, boost::system::error_code const& error, int signalNumber); +void SignalHandler(boost::system::error_code const& error, int signalNumber); void KeepDatabaseAliveHandler(std::weak_ptr dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error); void BanExpiryHandler(std::weak_ptr banExpiryCheckTimerRef, int32 banExpiryCheckInterval, boost::system::error_code const& error); variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, std::string& winServiceAction); +std::atomic Stopped; + int main(int argc, char** argv) { signal(SIGABRT, &Trinity::AbortHandler); @@ -203,6 +205,14 @@ int main(int argc, char** argv) std::shared_ptr ioContext = std::make_shared(); + auto ioContextWork = boost::asio::make_work_guard(ioContext->get_executor()); + + std::thread ioContextThread(&Trinity::Asio::IoContext::run, ioContext.get()); + + auto threadJoinHandle = Trinity::make_unique_ptr_with_deleter<[](std::thread* t) { t->join(); }>(&ioContextThread); + + auto ioContextWorkGuardHandle = Trinity::make_unique_ptr_with_deleter<[](auto* wg) { wg->reset(); }>(&ioContextWork); + Trinity::Net::ScanLocalNetworks(); std::string httpBindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0"); @@ -236,7 +246,7 @@ int main(int argc, char** argv) std::string bindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0"); - if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport)) + if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport, 1)) { TC_LOG_ERROR("server.bnetserver", "Failed to initialize network"); return 1; @@ -245,14 +255,11 @@ int main(int argc, char** argv) auto sSessionMgrHandle = Trinity::make_unique_ptr_with_deleter<&Battlenet::SessionManager::StopNetwork>(&sSessionMgr); // Set signal handlers - boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM); + boost::asio::basic_signal_set signals(*ioContext, SIGINT, SIGTERM); #if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS signals.add(SIGBREAK); #endif - signals.async_wait([ioContextRef = std::weak_ptr(ioContext)](boost::system::error_code const& error, int signalNumber) mutable - { - SignalHandler(std::move(ioContextRef), error, signalNumber); - }); + signals.async_wait(SignalHandler); // Set process priority according to configuration settings SetProcessPriority("server.bnetserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false)); @@ -287,16 +294,11 @@ int main(int argc, char** argv) } #endif - // Start the io service worker loop - ioContext->run(); - - banExpiryCheckTimer->cancel(); - dbPingTimer->cancel(); + while (!Stopped) + std::this_thread::sleep_for(50ms); TC_LOG_INFO("server.bnetserver", "Halting process..."); - signals.cancel(); - return 0; } @@ -325,11 +327,10 @@ void StopDB() MySQL::Library_End(); } -void SignalHandler(std::weak_ptr ioContextRef, boost::system::error_code const& error, int /*signalNumber*/) +void SignalHandler(boost::system::error_code const& error, int /*signalNumber*/) { if (!error) - if (std::shared_ptr ioContext = ioContextRef.lock()) - ioContext->stop(); + Stopped = true; } void KeepDatabaseAliveHandler(std::weak_ptr dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error) diff --git a/src/server/bnetserver/REST/LoginHttpSession.h b/src/server/bnetserver/REST/LoginHttpSession.h index 9690b02f14a..4c2fa13dde2 100644 --- a/src/server/bnetserver/REST/LoginHttpSession.h +++ b/src/server/bnetserver/REST/LoginHttpSession.h @@ -30,7 +30,7 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState std::unique_ptr Srp; }; -class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this +class LoginHttpSession final : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this { public: static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID="; diff --git a/src/server/bnetserver/REST/LoginRESTService.cpp b/src/server/bnetserver/REST/LoginRESTService.cpp index b75d8d93ba8..2390a25308c 100644 --- a/src/server/bnetserver/REST/LoginRESTService.cpp +++ b/src/server/bnetserver/REST/LoginRESTService.cpp @@ -23,7 +23,6 @@ #include "CryptoRandom.h" #include "DatabaseEnv.h" #include "IpNetwork.h" -#include "IteratorPair.h" #include "ProtobufJSON.h" #include "Resolver.h" #include "SslContext.h" @@ -40,6 +39,33 @@ LoginRESTService& LoginRESTService::Instance() bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount) { + Trinity::Net::Resolver resolver(ioContext); + + _externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1"); + + std::ranges::transform(resolver.ResolveAll(_externalHostname, ""), + std::back_inserter(_addresses), + [](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); }); + + if (_addresses.empty()) + { + TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname); + return false; + } + + _localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1"); + _firstLocalAddressIndex = _addresses.size(); + + std::ranges::transform(resolver.ResolveAll(_localHostname, ""), + std::back_inserter(_addresses), + [](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); }); + + if (_addresses.size() == _firstLocalAddressIndex) + { + TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname); + return false; + } + if (!HttpService::StartNetwork(ioContext, bindIp, port, threadCount)) return false; @@ -75,36 +101,8 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st return HandlePostRefreshLoginTicket(std::move(session), context); }); - _bindIP = bindIp; _port = port; - Trinity::Net::Resolver resolver(ioContext); - - _externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1"); - - std::ranges::transform(resolver.ResolveAll(_externalHostname, ""), - std::back_inserter(_addresses), - [](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); }); - - if (_addresses.empty()) - { - TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname); - return false; - } - - _localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1"); - _firstLocalAddressIndex = _addresses.size(); - - std::ranges::transform(resolver.ResolveAll(_localHostname, ""), - std::back_inserter(_addresses), - [](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); }); - - if (_addresses.size() == _firstLocalAddressIndex) - { - TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname); - return false; - } - // set up form inputs JSON::Login::FormInput* input; _formInputs.set_type(JSON::Login::LOGIN_FORM); @@ -129,10 +127,6 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st MigrateLegacyPasswordHashes(); - _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) - { - OnSocketOpen(std::move(sock), threadIndex); - }); return true; } diff --git a/src/server/bnetserver/REST/LoginRESTService.h b/src/server/bnetserver/REST/LoginRESTService.h index aefa6fc2296..f7d53bc1c2a 100644 --- a/src/server/bnetserver/REST/LoginRESTService.h +++ b/src/server/bnetserver/REST/LoginRESTService.h @@ -79,7 +79,6 @@ private: void MigrateLegacyPasswordHashes() const; JSON::Login::FormInputs _formInputs; - std::string _bindIP; uint16 _port; std::string _externalHostname; std::string _localHostname; diff --git a/src/server/bnetserver/Server/SessionManager.cpp b/src/server/bnetserver/Server/SessionManager.cpp index b1dac140084..28575308e44 100644 --- a/src/server/bnetserver/Server/SessionManager.cpp +++ b/src/server/bnetserver/Server/SessionManager.cpp @@ -17,23 +17,6 @@ #include "SessionManager.h" -bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) -{ - if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) - return false; - - _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) - { - OnSocketOpen(std::move(sock), threadIndex); - }); - return true; -} - -Battlenet::SessionNetworkThread* Battlenet::SessionManager::CreateThreads() const -{ - return new SessionNetworkThread[GetNetworkThreadCount()]; -} - Battlenet::SessionManager& Battlenet::SessionManager::Instance() { static SessionManager instance; diff --git a/src/server/bnetserver/Server/SessionManager.h b/src/server/bnetserver/Server/SessionManager.h index aa2ecdbf9de..8eb426ec49d 100644 --- a/src/server/bnetserver/Server/SessionManager.h +++ b/src/server/bnetserver/Server/SessionManager.h @@ -23,21 +23,21 @@ namespace Battlenet { - class SessionNetworkThread final : public Trinity::Net::NetworkThread + class SessionNetworkThread final : public Trinity::Net::NetworkThread { }; - class SessionManager final : public Trinity::Net::SocketMgr + struct SessionManagerTraits { - using BaseSocketMgr = SocketMgr; + using Self = class SessionManager; + using SocketType = Session; + using ThreadType = SessionNetworkThread; + }; + class SessionManager final : public Trinity::Net::SocketMgr + { public: static SessionManager& Instance(); - - bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override; - - protected: - SessionNetworkThread* CreateThreads() const override; }; } diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp index ee2f2f478a2..9a4779a16fe 100644 --- a/src/server/game/Server/WorldSocketMgr.cpp +++ b/src/server/game/Server/WorldSocketMgr.cpp @@ -31,7 +31,7 @@ void WorldSocketThread::SocketRemoved(std::shared_ptrconst& sock) sScriptMgr->OnSocketClose(sock); } -WorldSocketMgr::WorldSocketMgr() : BaseSocketMgr(), _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true) +WorldSocketMgr::WorldSocketMgr() : _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true) { } @@ -61,26 +61,21 @@ bool WorldSocketMgr::StartNetwork(Trinity::Asio::IoContext& ioContext, std::stri return false; } - if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) + if (!SocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) return false; - _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) - { - OnSocketOpen(std::move(sock), threadIndex); - }); - sScriptMgr->OnNetworkStart(); return true; } void WorldSocketMgr::StopNetwork() { - BaseSocketMgr::StopNetwork(); + SocketMgr::StopNetwork(); sScriptMgr->OnNetworkStop(); } -void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) +void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock) { // set some options here if (_socketSystemSendBufferSize >= 0) @@ -106,10 +101,5 @@ void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint3 } } - BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex); -} - -WorldSocketThread* WorldSocketMgr::CreateThreads() const -{ - return new WorldSocketThread[GetNetworkThreadCount()]; + SocketMgr::OnSocketOpen(std::move(sock)); } diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h index f9aa81df768..3f0f81530c0 100644 --- a/src/server/game/Server/WorldSocketMgr.h +++ b/src/server/game/Server/WorldSocketMgr.h @@ -21,7 +21,7 @@ #include "SocketMgr.h" #include "WorldSocket.h" -class WorldSocketThread final : public Trinity::Net::NetworkThread +class WorldSocketThread final : public Trinity::Net::NetworkThread { public: void SocketAdded(std::shared_ptr const& sock) override; @@ -29,11 +29,18 @@ public: void SocketRemoved(std::shared_ptrconst& sock) override; }; -/// Manages all sockets connected to peers and network threads -class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr -{ - using BaseSocketMgr = SocketMgr; +class WorldSocketMgr; +struct WorldSocketMgrTraits +{ + using Self = WorldSocketMgr; + using SocketType = WorldSocket; + using ThreadType = WorldSocketThread; +}; + +/// Manages all sockets connected to peers and network threads +class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr +{ public: ~WorldSocketMgr(); @@ -45,15 +52,13 @@ public: /// Stops all network threads, It will wait for all running threads . void StopNetwork() override; - void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override; + void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock) override; std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; } protected: WorldSocketMgr(); - WorldSocketThread* CreateThreads() const override; - private: int32 _socketSystemSendBufferSize; int32 _socketApplicationSendBufferSize; diff --git a/src/server/worldserver/Main.cpp b/src/server/worldserver/Main.cpp index 256069f3a06..cb6903080fa 100644 --- a/src/server/worldserver/Main.cpp +++ b/src/server/worldserver/Main.cpp @@ -268,7 +268,7 @@ int main(int argc, char** argv) } // Set signal handlers (this must be done before starting IoContext threads, because otherwise they would unblock and exit) - boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM); + boost::asio::basic_signal_set signals(*ioContext, SIGINT, SIGTERM); #if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS signals.add(SIGBREAK); #endif @@ -284,7 +284,7 @@ int main(int argc, char** argv) for (int i = 0; i < numThreads; ++i) threadPool->PostWork([ioContext]() { ioContext->run(); }); - auto ioContextStopHandle = Trinity::make_unique_ptr_with_deleter<&Trinity::Asio::IoContext::stop>(ioContext.get()); + auto signalsCancelHandle = Trinity::make_unique_ptr_with_deleter<[](auto* s) { boost::system::error_code ec; s->cancel(ec); }>(&signals); // Set process priority according to configuration settings SetProcessPriority("server.worldserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false)); @@ -445,10 +445,6 @@ int main(int argc, char** argv) WorldPackets::Auth::ConnectTo::ShutdownEncryption(); WorldPackets::Auth::EnterEncryptedMode::ShutdownEncryption(); - ioContextStopHandle.reset(); - - threadPool.reset(); - sLog->SetSynchronous(); sScriptMgr->OnShutdown(); @@ -636,14 +632,14 @@ std::unique_ptr StartRaSocketAcceptor(Trinity::Asio if (!acceptor->Bind()) { TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor"); - return nullptr; + acceptor = nullptr; + return acceptor; } - acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/) - { - std::make_shared(std::move(sock))->Start(); + acceptor->AsyncAccept( + [&] { return &ioContext; }, + [](Trinity::Net::IoContextTcpSocket&& sock) { std::make_shared(std::move(sock))->Start(); }); - }); return acceptor; }