From b69a1a71c6b3c604f5eed1d0752f66981a4efc25 Mon Sep 17 00:00:00 2001 From: Shauren Date: Mon, 22 Dec 2025 13:06:28 +0100 Subject: [PATCH] Core/Network: Fix invalid NetworkThread array access for derived classes that have additional data members (only HttpService threads were affected) --- src/common/network/Http/HttpService.h | 46 ++++++++++--------- src/common/network/SocketMgr.h | 9 ++-- src/server/bnetserver/REST/LoginRESTService.h | 2 +- .../bnetserver/Server/SessionManager.cpp | 5 +- src/server/bnetserver/Server/SessionManager.h | 10 ++-- src/server/game/Server/WorldSocketMgr.cpp | 23 ++++------ src/server/game/Server/WorldSocketMgr.h | 16 +++++-- 7 files changed, 62 insertions(+), 49 deletions(-) diff --git a/src/common/network/Http/HttpService.h b/src/common/network/Http/HttpService.h index 1549893576f..bca573f431d 100644 --- a/src/common/network/Http/HttpService.h +++ b/src/common/network/Http/HttpService.h @@ -113,8 +113,27 @@ template concept HttpRequestHandler = invocable_r, RequestContext&>; template -class HttpService : public SocketMgr, public DispatcherService, public SessionService +class HttpNetworkThread final : public NetworkThread { +public: + explicit HttpNetworkThread(SessionService* service) : _service(service) { } + +protected: + void SocketRemoved(std::shared_ptr const& session) override + { + if (Optional id = session->GetSessionId()) + _service->MarkSessionInactive(*id); + } + +private: + SessionService* _service = nullptr; +}; + +template +class HttpService : public SocketMgr>, public DispatcherService, public SessionService +{ + using BaseSocketMgr = SocketMgr>; + public: HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.") { @@ -123,7 +142,7 @@ public: bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override { - if (!SocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) + if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) return false; SessionService::Start(ioContext); @@ -133,7 +152,7 @@ public: void StopNetwork() override { SessionService::Stop(); - SocketMgr::StopNetwork(); + BaseSocketMgr::StopNetwork(); } // http handling @@ -157,26 +176,11 @@ public: } protected: - class Thread : public NetworkThread + HttpNetworkThread* CreateThreads() const final { - protected: - void SocketRemoved(std::shared_ptr const& session) override - { - if (Optional id = session->GetSessionId()) - _service->MarkSessionInactive(*id); - } - - private: - friend HttpService; - - SessionService* _service; - }; - - NetworkThread* CreateThreads() const override - { - Thread* threads = new Thread[this->GetNetworkThreadCount()]; + HttpNetworkThread* threads = static_cast*>(::operator new(sizeof(HttpNetworkThread) * this->GetNetworkThreadCount())); for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i) - threads[i]._service = const_cast(this); + new (&threads[i]) HttpNetworkThread(const_cast(this)); return threads; } diff --git a/src/common/network/SocketMgr.h b/src/common/network/SocketMgr.h index 07252355308..fd17be06811 100644 --- a/src/common/network/SocketMgr.h +++ b/src/common/network/SocketMgr.h @@ -27,9 +27,12 @@ namespace Trinity::Net { -template +template class SocketMgr { + static_assert(std::is_base_of_v, ThreadType>); + static_assert(std::is_final_v); + public: SocketMgr(SocketMgr const&) = delete; SocketMgr(SocketMgr&&) = delete; @@ -135,10 +138,10 @@ protected: { } - virtual NetworkThread* CreateThreads() const = 0; + virtual ThreadType* CreateThreads() const = 0; std::unique_ptr _acceptor; - std::unique_ptr[]> _threads; + std::unique_ptr _threads; int32 _threadCount; }; } diff --git a/src/server/bnetserver/REST/LoginRESTService.h b/src/server/bnetserver/REST/LoginRESTService.h index e9bc68ffdf9..aefa6fc2296 100644 --- a/src/server/bnetserver/REST/LoginRESTService.h +++ b/src/server/bnetserver/REST/LoginRESTService.h @@ -42,7 +42,7 @@ enum class BanMode BAN_ACCOUNT = 1 }; -class LoginRESTService : public Trinity::Net::Http::HttpService +class LoginRESTService final : public Trinity::Net::Http::HttpService { public: using RequestHandlerResult = Trinity::Net::Http::RequestHandlerResult; diff --git a/src/server/bnetserver/Server/SessionManager.cpp b/src/server/bnetserver/Server/SessionManager.cpp index 4c5b532ee60..b1dac140084 100644 --- a/src/server/bnetserver/Server/SessionManager.cpp +++ b/src/server/bnetserver/Server/SessionManager.cpp @@ -16,7 +16,6 @@ */ #include "SessionManager.h" -#include "Util.h" bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) { @@ -30,9 +29,9 @@ bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext return true; } -Trinity::Net::NetworkThread* Battlenet::SessionManager::CreateThreads() const +Battlenet::SessionNetworkThread* Battlenet::SessionManager::CreateThreads() const { - return new Trinity::Net::NetworkThread[GetNetworkThreadCount()]; + return new SessionNetworkThread[GetNetworkThreadCount()]; } Battlenet::SessionManager& Battlenet::SessionManager::Instance() diff --git a/src/server/bnetserver/Server/SessionManager.h b/src/server/bnetserver/Server/SessionManager.h index 528ece8739e..aa2ecdbf9de 100644 --- a/src/server/bnetserver/Server/SessionManager.h +++ b/src/server/bnetserver/Server/SessionManager.h @@ -23,9 +23,13 @@ namespace Battlenet { - class SessionManager : public Trinity::Net::SocketMgr + class SessionNetworkThread final : public Trinity::Net::NetworkThread { - typedef SocketMgr BaseSocketMgr; + }; + + class SessionManager final : public Trinity::Net::SocketMgr + { + using BaseSocketMgr = SocketMgr; public: static SessionManager& Instance(); @@ -33,7 +37,7 @@ namespace Battlenet bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override; protected: - Trinity::Net::NetworkThread* CreateThreads() const override; + SessionNetworkThread* CreateThreads() const override; }; } diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp index 4205b6912f0..ee2f2f478a2 100644 --- a/src/server/game/Server/WorldSocketMgr.cpp +++ b/src/server/game/Server/WorldSocketMgr.cpp @@ -17,24 +17,19 @@ #include "WorldSocketMgr.h" #include "Config.h" -#include "NetworkThread.h" #include "ScriptMgr.h" #include -class WorldSocketThread : public Trinity::Net::NetworkThread +void WorldSocketThread::SocketAdded(std::shared_ptr const& sock) { -public: - void SocketAdded(std::shared_ptr const& sock) override - { - sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize()); - sScriptMgr->OnSocketOpen(sock); - } + sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize()); + sScriptMgr->OnSocketOpen(sock); +} - void SocketRemoved(std::shared_ptrconst& sock) override - { - sScriptMgr->OnSocketClose(sock); - } -}; +void WorldSocketThread::SocketRemoved(std::shared_ptrconst& sock) +{ + sScriptMgr->OnSocketClose(sock); +} WorldSocketMgr::WorldSocketMgr() : BaseSocketMgr(), _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true) { @@ -114,7 +109,7 @@ void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint3 BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex); } -Trinity::Net::NetworkThread* WorldSocketMgr::CreateThreads() const +WorldSocketThread* WorldSocketMgr::CreateThreads() const { return new WorldSocketThread[GetNetworkThreadCount()]; } diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h index 9f905f23cd2..f9aa81df768 100644 --- a/src/server/game/Server/WorldSocketMgr.h +++ b/src/server/game/Server/WorldSocketMgr.h @@ -21,10 +21,18 @@ #include "SocketMgr.h" #include "WorldSocket.h" -/// Manages all sockets connected to peers and network threads -class TC_GAME_API WorldSocketMgr : public Trinity::Net::SocketMgr +class WorldSocketThread final : public Trinity::Net::NetworkThread { - typedef SocketMgr BaseSocketMgr; +public: + void SocketAdded(std::shared_ptr const& sock) override; + + void SocketRemoved(std::shared_ptrconst& sock) override; +}; + +/// Manages all sockets connected to peers and network threads +class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr +{ + using BaseSocketMgr = SocketMgr; public: ~WorldSocketMgr(); @@ -44,7 +52,7 @@ public: protected: WorldSocketMgr(); - Trinity::Net::NetworkThread* CreateThreads() const override; + WorldSocketThread* CreateThreads() const override; private: int32 _socketSystemSendBufferSize;