Core/Network: Refactor AsyncAcceptor to use async_accept overload producing sockets through argument instead of having to preallocate it

* Also improve main() cleanup to fully process all queued async operations (including their cancellations)
This commit is contained in:
Shauren
2026-01-12 20:59:19 +01:00
parent c2c5c70fb1
commit 585e170ad6
17 changed files with 187 additions and 190 deletions

View File

@@ -83,11 +83,13 @@ void ServiceStatusWatcher(std::weak_ptr<Trinity::Asio::DeadlineTimer> serviceSta
bool StartDB();
void StopDB();
void SignalHandler(std::weak_ptr<Trinity::Asio::IoContext> ioContextRef, boost::system::error_code const& error, int signalNumber);
void SignalHandler(boost::system::error_code const& error, int signalNumber);
void KeepDatabaseAliveHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error);
void BanExpiryHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> banExpiryCheckTimerRef, int32 banExpiryCheckInterval, boost::system::error_code const& error);
variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, std::string& winServiceAction);
std::atomic<bool> Stopped;
int main(int argc, char** argv)
{
signal(SIGABRT, &Trinity::AbortHandler);
@@ -203,6 +205,14 @@ int main(int argc, char** argv)
std::shared_ptr<Trinity::Asio::IoContext> ioContext = std::make_shared<Trinity::Asio::IoContext>();
auto ioContextWork = boost::asio::make_work_guard(ioContext->get_executor());
std::thread ioContextThread(&Trinity::Asio::IoContext::run, ioContext.get());
auto threadJoinHandle = Trinity::make_unique_ptr_with_deleter<[](std::thread* t) { t->join(); }>(&ioContextThread);
auto ioContextWorkGuardHandle = Trinity::make_unique_ptr_with_deleter<[](auto* wg) { wg->reset(); }>(&ioContextWork);
Trinity::Net::ScanLocalNetworks();
std::string httpBindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0");
@@ -236,7 +246,7 @@ int main(int argc, char** argv)
std::string bindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0");
if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport))
if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport, 1))
{
TC_LOG_ERROR("server.bnetserver", "Failed to initialize network");
return 1;
@@ -245,14 +255,11 @@ int main(int argc, char** argv)
auto sSessionMgrHandle = Trinity::make_unique_ptr_with_deleter<&Battlenet::SessionManager::StopNetwork>(&sSessionMgr);
// Set signal handlers
boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM);
boost::asio::basic_signal_set<Trinity::Asio::IoContext::Executor> signals(*ioContext, SIGINT, SIGTERM);
#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
signals.add(SIGBREAK);
#endif
signals.async_wait([ioContextRef = std::weak_ptr(ioContext)](boost::system::error_code const& error, int signalNumber) mutable
{
SignalHandler(std::move(ioContextRef), error, signalNumber);
});
signals.async_wait(SignalHandler);
// Set process priority according to configuration settings
SetProcessPriority("server.bnetserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false));
@@ -287,16 +294,11 @@ int main(int argc, char** argv)
}
#endif
// Start the io service worker loop
ioContext->run();
banExpiryCheckTimer->cancel();
dbPingTimer->cancel();
while (!Stopped)
std::this_thread::sleep_for(50ms);
TC_LOG_INFO("server.bnetserver", "Halting process...");
signals.cancel();
return 0;
}
@@ -325,11 +327,10 @@ void StopDB()
MySQL::Library_End();
}
void SignalHandler(std::weak_ptr<Trinity::Asio::IoContext> ioContextRef, boost::system::error_code const& error, int /*signalNumber*/)
void SignalHandler(boost::system::error_code const& error, int /*signalNumber*/)
{
if (!error)
if (std::shared_ptr<Trinity::Asio::IoContext> ioContext = ioContextRef.lock())
ioContext->stop();
Stopped = true;
}
void KeepDatabaseAliveHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error)

View File

@@ -30,7 +30,7 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState
std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> Srp;
};
class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
class LoginHttpSession final : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
{
public:
static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID=";

View File

@@ -23,7 +23,6 @@
#include "CryptoRandom.h"
#include "DatabaseEnv.h"
#include "IpNetwork.h"
#include "IteratorPair.h"
#include "ProtobufJSON.h"
#include "Resolver.h"
#include "SslContext.h"
@@ -40,6 +39,33 @@ LoginRESTService& LoginRESTService::Instance()
bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount)
{
Trinity::Net::Resolver resolver(ioContext);
_externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1");
std::ranges::transform(resolver.ResolveAll(_externalHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.empty())
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname);
return false;
}
_localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1");
_firstLocalAddressIndex = _addresses.size();
std::ranges::transform(resolver.ResolveAll(_localHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.size() == _firstLocalAddressIndex)
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname);
return false;
}
if (!HttpService::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
@@ -75,36 +101,8 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
return HandlePostRefreshLoginTicket(std::move(session), context);
});
_bindIP = bindIp;
_port = port;
Trinity::Net::Resolver resolver(ioContext);
_externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1");
std::ranges::transform(resolver.ResolveAll(_externalHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.empty())
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname);
return false;
}
_localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1");
_firstLocalAddressIndex = _addresses.size();
std::ranges::transform(resolver.ResolveAll(_localHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.size() == _firstLocalAddressIndex)
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname);
return false;
}
// set up form inputs
JSON::Login::FormInput* input;
_formInputs.set_type(JSON::Login::LOGIN_FORM);
@@ -129,10 +127,6 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
MigrateLegacyPasswordHashes();
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
{
OnSocketOpen(std::move(sock), threadIndex);
});
return true;
}

View File

@@ -79,7 +79,6 @@ private:
void MigrateLegacyPasswordHashes() const;
JSON::Login::FormInputs _formInputs;
std::string _bindIP;
uint16 _port;
std::string _externalHostname;
std::string _localHostname;

View File

@@ -17,23 +17,6 @@
#include "SessionManager.h"
bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
{
OnSocketOpen(std::move(sock), threadIndex);
});
return true;
}
Battlenet::SessionNetworkThread* Battlenet::SessionManager::CreateThreads() const
{
return new SessionNetworkThread[GetNetworkThreadCount()];
}
Battlenet::SessionManager& Battlenet::SessionManager::Instance()
{
static SessionManager instance;

View File

@@ -23,21 +23,21 @@
namespace Battlenet
{
class SessionNetworkThread final : public Trinity::Net::NetworkThread<Session>
class SessionNetworkThread final : public Trinity::Net::NetworkThread<Session, SessionNetworkThread>
{
};
class SessionManager final : public Trinity::Net::SocketMgr<Session, SessionNetworkThread>
struct SessionManagerTraits
{
using BaseSocketMgr = SocketMgr;
using Self = class SessionManager;
using SocketType = Session;
using ThreadType = SessionNetworkThread;
};
class SessionManager final : public Trinity::Net::SocketMgr<SessionManagerTraits>
{
public:
static SessionManager& Instance();
bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override;
protected:
SessionNetworkThread* CreateThreads() const override;
};
}

View File

@@ -31,7 +31,7 @@ void WorldSocketThread::SocketRemoved(std::shared_ptr<WorldSocket>const& sock)
sScriptMgr->OnSocketClose(sock);
}
WorldSocketMgr::WorldSocketMgr() : BaseSocketMgr(), _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true)
WorldSocketMgr::WorldSocketMgr() : _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true)
{
}
@@ -61,26 +61,21 @@ bool WorldSocketMgr::StartNetwork(Trinity::Asio::IoContext& ioContext, std::stri
return false;
}
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
if (!SocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
{
OnSocketOpen(std::move(sock), threadIndex);
});
sScriptMgr->OnNetworkStart();
return true;
}
void WorldSocketMgr::StopNetwork()
{
BaseSocketMgr::StopNetwork();
SocketMgr::StopNetwork();
sScriptMgr->OnNetworkStop();
}
void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock)
{
// set some options here
if (_socketSystemSendBufferSize >= 0)
@@ -106,10 +101,5 @@ void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint3
}
}
BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex);
}
WorldSocketThread* WorldSocketMgr::CreateThreads() const
{
return new WorldSocketThread[GetNetworkThreadCount()];
SocketMgr::OnSocketOpen(std::move(sock));
}

View File

@@ -21,7 +21,7 @@
#include "SocketMgr.h"
#include "WorldSocket.h"
class WorldSocketThread final : public Trinity::Net::NetworkThread<WorldSocket>
class WorldSocketThread final : public Trinity::Net::NetworkThread<WorldSocket, WorldSocketThread>
{
public:
void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override;
@@ -29,11 +29,18 @@ public:
void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override;
};
/// Manages all sockets connected to peers and network threads
class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr<WorldSocket, WorldSocketThread>
{
using BaseSocketMgr = SocketMgr;
class WorldSocketMgr;
struct WorldSocketMgrTraits
{
using Self = WorldSocketMgr;
using SocketType = WorldSocket;
using ThreadType = WorldSocketThread;
};
/// Manages all sockets connected to peers and network threads
class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr<WorldSocketMgrTraits>
{
public:
~WorldSocketMgr();
@@ -45,15 +52,13 @@ public:
/// Stops all network threads, It will wait for all running threads .
void StopNetwork() override;
void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override;
void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock) override;
std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; }
protected:
WorldSocketMgr();
WorldSocketThread* CreateThreads() const override;
private:
int32 _socketSystemSendBufferSize;
int32 _socketApplicationSendBufferSize;

View File

@@ -268,7 +268,7 @@ int main(int argc, char** argv)
}
// Set signal handlers (this must be done before starting IoContext threads, because otherwise they would unblock and exit)
boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM);
boost::asio::basic_signal_set<Trinity::Asio::IoContext::Executor> signals(*ioContext, SIGINT, SIGTERM);
#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
signals.add(SIGBREAK);
#endif
@@ -284,7 +284,7 @@ int main(int argc, char** argv)
for (int i = 0; i < numThreads; ++i)
threadPool->PostWork([ioContext]() { ioContext->run(); });
auto ioContextStopHandle = Trinity::make_unique_ptr_with_deleter<&Trinity::Asio::IoContext::stop>(ioContext.get());
auto signalsCancelHandle = Trinity::make_unique_ptr_with_deleter<[](auto* s) { boost::system::error_code ec; s->cancel(ec); }>(&signals);
// Set process priority according to configuration settings
SetProcessPriority("server.worldserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false));
@@ -445,10 +445,6 @@ int main(int argc, char** argv)
WorldPackets::Auth::ConnectTo::ShutdownEncryption();
WorldPackets::Auth::EnterEncryptedMode::ShutdownEncryption();
ioContextStopHandle.reset();
threadPool.reset();
sLog->SetSynchronous();
sScriptMgr->OnShutdown();
@@ -636,14 +632,14 @@ std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio
if (!acceptor->Bind())
{
TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor");
return nullptr;
acceptor = nullptr;
return acceptor;
}
acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/)
{
std::make_shared<RASession>(std::move(sock))->Start();
acceptor->AsyncAccept(
[&] { return &ioContext; },
[](Trinity::Net::IoContextTcpSocket&& sock) { std::make_shared<RASession>(std::move(sock))->Start(); });
});
return acceptor;
}