Core/Network: Refactor AsyncAcceptor to use async_accept overload producing sockets through argument instead of having to preallocate it

* Also improve main() cleanup to fully process all queued async operations (including their cancellations)
This commit is contained in:
Shauren
2026-01-12 20:59:19 +01:00
parent c2c5c70fb1
commit 585e170ad6
17 changed files with 187 additions and 190 deletions

View File

@@ -35,6 +35,12 @@ namespace boost
template <typename Time>
struct time_traits;
template <typename Clock>
struct wait_traits;
template <typename Protocol, typename Clock, typename WaitTraits>
class basic_socket_iostream;
namespace ip
{
class address;

View File

@@ -15,8 +15,8 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef IoContext_h__
#define IoContext_h__
#ifndef TRINITYCORE_IO_CONTEXT_H
#define TRINITYCORE_IO_CONTEXT_H
#include <boost/asio/bind_executor.hpp>
#include <boost/asio/io_context.hpp>
@@ -56,20 +56,20 @@ namespace Trinity
return boost::asio::post(ioContext, std::forward<T>(t));
}
template<typename T>
inline decltype(auto) post(boost::asio::io_context::executor_type& executor, T&& t)
{
return boost::asio::post(executor, std::forward<T>(t));
}
using boost::asio::bind_executor;
template<typename T>
inline decltype(auto) post(boost::asio::io_context::executor_type const& executor, T&& t)
{
return boost::asio::post(executor.context(), bind_executor(executor, std::forward<T>(t)));
}
template<typename T>
inline decltype(auto) get_io_context(T&& ioObject)
{
return ioObject.get_executor().context();
return std::forward<T>(ioObject).get_executor().context();
}
}
}
#endif // IoContext_h__
#endif // TRINITYCORE_IO_CONTEXT_H

View File

@@ -27,6 +27,8 @@ namespace Trinity
class ThreadPool
{
public:
using executor_type = boost::asio::thread_pool::executor_type;
explicit ThreadPool(std::size_t numThreads = std::thread::hardware_concurrency()) : _impl(numThreads) { }
template<typename T>

View File

@@ -18,6 +18,7 @@
#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H
#define TRINITYCORE_ASYNC_ACCEPTOR_H
#include "Concepts.h"
#include "IoContext.h"
#include "IpAddress.h"
#include "Log.h"
@@ -25,40 +26,40 @@
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ip/v6_only.hpp>
#include <atomic>
#include <functional>
#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
namespace Trinity::Net
{
template <typename Callable>
concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>;
concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&>;
template <typename Callable>
concept SelectIoContextForNewSocketFn = Trinity::invocable_r<Callable, Asio::IoContext*>;
class AsyncAcceptor
{
public:
AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
_acceptor(ioContext), _endpoint(make_address(bindIp), port),
_socket(ioContext), _closed(false), _socketFactory([this] { return DefaultSocketFactory(); })
_acceptor(ioContext), _endpoint(make_address(bindIp), port), _closed(false)
{
}
template <AcceptCallback Callback>
void AsyncAccept(Callback&& acceptCallback)
template <SelectIoContextForNewSocketFn SelectIoContextForNewSocket, AcceptCallback Callback>
void AsyncAccept(SelectIoContextForNewSocket&& selectIoContext, Callback&& acceptCallback)
{
auto [tmpSocket, tmpThreadIndex] = _socketFactory();
// TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures)
IoContextTcpSocket* socket = tmpSocket;
uint32 threadIndex = tmpThreadIndex;
_acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable
Asio::IoContext* context = selectIoContext();
_acceptor.async_accept(context->get_executor(), [this,
selectIoContext = std::forward<SelectIoContextForNewSocket>(selectIoContext),
acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error, IoContextTcpSocket&& socket) mutable
{
if (!error)
{
try
{
socket->non_blocking(true);
socket.non_blocking(true);
acceptCallback(std::move(*socket), threadIndex);
acceptCallback(std::move(socket));
}
catch (boost::system::system_error const& err)
{
@@ -67,7 +68,7 @@ public:
}
if (!_closed)
this->AsyncAccept(std::move(acceptCallback));
this->AsyncAccept(std::move(selectIoContext), std::move(acceptCallback));
});
}
@@ -93,7 +94,14 @@ public:
// v6_only is enabled on some *BSD distributions by default
// we want to allow both v4 and v6 connections to the same listener
if (_endpoint.protocol() == boost::asio::ip::tcp::v6())
_acceptor.set_option(boost::asio::ip::v6_only(false));
{
_acceptor.set_option(boost::asio::ip::v6_only(false), errorCode);
if (errorCode)
{
TC_LOG_INFO("network", "Could not disable v6_only option {}", errorCode.message());
return false;
}
}
_acceptor.bind(_endpoint, errorCode);
if (errorCode)
@@ -121,16 +129,10 @@ public:
_acceptor.close(err);
}
void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
private:
std::pair<IoContextTcpSocket*, uint32> DefaultSocketFactory() { return std::make_pair(&_socket, 0); }
boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
boost::asio::ip::tcp::endpoint _endpoint;
IoContextTcpSocket _socket;
std::atomic<bool> _closed;
std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
};
}

View File

@@ -113,12 +113,12 @@ template<typename Callable, typename SessionImpl>
concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>;
template<typename SessionImpl>
class HttpNetworkThread final : public NetworkThread<SessionImpl>
class HttpService;
template<typename SessionImpl>
class HttpNetworkThread final : public NetworkThread<SessionImpl, HttpNetworkThread<SessionImpl>>
{
public:
explicit HttpNetworkThread(SessionService* service) : _service(service) { }
protected:
void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
{
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
@@ -126,13 +126,24 @@ protected:
}
private:
friend class HttpService<SessionImpl>;
SessionService* _service = nullptr;
};
template<typename SessionImpl>
class HttpService : public SocketMgr<SessionImpl, HttpNetworkThread<SessionImpl>>, public DispatcherService, public SessionService
struct HttpServiceTraits
{
using BaseSocketMgr = SocketMgr<SessionImpl, HttpNetworkThread<SessionImpl>>;
using Self = HttpService<SessionImpl>;
using SocketType = SessionImpl;
using ThreadType = HttpNetworkThread<SessionImpl>;
};
template<typename SessionImpl>
class HttpService : public SocketMgr<HttpServiceTraits<SessionImpl>>, public DispatcherService, public SessionService
{
using BaseSocketMgr = SocketMgr<HttpServiceTraits<SessionImpl>>;
friend BaseSocketMgr;
public:
HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.")
@@ -176,11 +187,11 @@ public:
}
protected:
HttpNetworkThread<SessionImpl>* CreateThreads() const final
std::unique_ptr<HttpNetworkThread<SessionImpl>[]> CreateThreads() const override
{
HttpNetworkThread<SessionImpl>* threads = static_cast<HttpNetworkThread<SessionImpl>*>(::operator new(sizeof(HttpNetworkThread<SessionImpl>) * this->GetNetworkThreadCount()));
std::unique_ptr<HttpNetworkThread<SessionImpl>[]> threads = std::make_unique<HttpNetworkThread<SessionImpl>[]>(this->GetNetworkThreadCount());
for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i)
new (&threads[i]) HttpNetworkThread<SessionImpl>(const_cast<HttpService*>(this));
threads[i]._service = const_cast<HttpService*>(this);
return threads;
}

View File

@@ -24,8 +24,6 @@
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
#include <mutex>
@@ -33,12 +31,12 @@
namespace Trinity::Net
{
template<class SocketType>
template<class SocketType, class DerivedThread>
class NetworkThread
{
public:
NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1),
_acceptSocket(_ioContext), _updateTimer(_ioContext)
_updateTimer(_ioContext)
{
}
@@ -82,15 +80,15 @@ public:
return _connections;
}
void AddSocket(std::shared_ptr<SocketType> sock)
void AddSocket(std::shared_ptr<SocketType>&& sock)
{
std::scoped_lock lock(_newSocketsLock);
++_connections;
SocketAdded(_newSockets.emplace_back(std::move(sock)));
static_cast<DerivedThread*>(this)->SocketAdded(_newSockets.emplace_back(std::move(sock)));
}
Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
Trinity::Asio::IoContext* GetIoContext() { return &_ioContext; }
protected:
virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
@@ -107,7 +105,7 @@ protected:
{
if (!sock->IsOpen())
{
SocketRemoved(sock);
static_cast<DerivedThread*>(this)->SocketRemoved(sock);
--_connections;
}
else
@@ -147,7 +145,7 @@ protected:
if (sock->IsOpen())
sock->CloseSocket();
this->SocketRemoved(sock);
static_cast<DerivedThread*>(this)->SocketRemoved(sock);
--this->_connections;
return true;
@@ -171,7 +169,6 @@ private:
SocketContainer _newSockets;
Trinity::Asio::IoContext _ioContext;
Trinity::Net::IoContextTcpSocket _acceptSocket;
Trinity::Asio::DeadlineTimer _updateTimer;
};
}

View File

@@ -128,7 +128,7 @@ class Socket : public std::enable_shared_from_this<Socket<Stream>>
public:
template<typename... Args>
explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
_remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
_remoteEndpoint(_socket.remote_endpoint()), _openState(OpenState_Open)
{
}
@@ -185,18 +185,17 @@ public:
boost::asio::ip::address const& GetRemoteIpAddress() const
{
return _remoteAddress;
return _remoteEndpoint.Address;
}
uint16 GetRemotePort() const
{
return _remotePort;
return _remoteEndpoint.Port;
}
void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint)
{
_remoteAddress = endpoint.address();
_remotePort = endpoint.port();
_remoteEndpoint = endpoint;
}
template <invocable_r<SocketReadCallbackResult> Callback>
@@ -227,7 +226,7 @@ public:
void CloseSocket()
{
if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
if (_openState.exchange(OpenState_Closed) == OpenState_Closed)
return;
boost::system::error_code shutdownError;
@@ -242,7 +241,8 @@ public:
/// Marks the socket for closing after write buffer becomes empty
void DelayedCloseSocket()
{
if (_openState.fetch_or(OpenState_Closing) != 0)
uint8 oldState = OpenState_Open;
if (!_openState.compare_exchange_strong(oldState, OpenState_Closing))
return;
if (_writeQueue.empty())
@@ -380,8 +380,14 @@ private:
Stream _socket;
boost::asio::ip::address _remoteAddress;
uint16 _remotePort = 0;
struct Endpoint
{
Endpoint() : Address(), Port(0) { }
explicit(false) Endpoint(boost::asio::ip::tcp_endpoint const& endpoint) : Address(endpoint.address()), Port(endpoint.port()) { }
boost::asio::ip::address Address;
uint16 Port;
} _remoteEndpoint;
MessageBuffer _readBuffer = MessageBuffer(0x1000);
std::queue<MessageBuffer> _writeQueue;

View File

@@ -27,13 +27,14 @@
namespace Trinity::Net
{
template <typename SocketType, typename ThreadType>
template <typename Traits>
class SocketMgr
{
static_assert(std::is_base_of_v<NetworkThread<SocketType>, ThreadType>);
static_assert(std::is_final_v<ThreadType>);
public:
using Self = typename Traits::Self;
using SocketType = typename Traits::SocketType;
using ThreadType = typename Traits::ThreadType;
SocketMgr(SocketMgr const&) = delete;
SocketMgr(SocketMgr&&) = delete;
SocketMgr& operator=(SocketMgr const&) = delete;
@@ -67,14 +68,16 @@ public:
_acceptor = std::move(acceptor);
_threadCount = threadCount;
_threads.reset(CreateThreads());
_threads = static_cast<Self*>(this)->CreateThreads();
ASSERT(_threads);
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Start();
_acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
_acceptor->AsyncAccept(
[this]{ return SelectThreadWithMinConnections(); },
[this](IoContextTcpSocket&& sock) { static_cast<Self*>(this)->OnSocketOpen(std::move(sock)); });
return true;
}
@@ -99,14 +102,19 @@ public:
_threads[i].Wait();
}
virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
virtual void OnSocketOpen(IoContextTcpSocket&& sock)
{
try
{
int32 threadIndex = 0;
for (; threadIndex < _threadCount; ++threadIndex)
if (_threads[threadIndex].GetIoContext()->get_executor() == sock.get_executor())
break;
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
newSocket->Start();
_threads[threadIndex].AddSocket(newSocket);
_threads[threadIndex].AddSocket(std::move(newSocket));
}
catch (boost::system::system_error const& err)
{
@@ -116,21 +124,15 @@ public:
int32 GetNetworkThreadCount() const { return _threadCount; }
uint32 SelectThreadWithMinConnections() const
Asio::IoContext* SelectThreadWithMinConnections() const
{
uint32 min = 0;
ThreadType* min = &_threads[0];
for (int32 i = 1; i < _threadCount; ++i)
if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
for (ThreadType* i = min + 1; i != _threads.get() + _threadCount; ++i)
if (i->GetConnectionCount() < min->GetConnectionCount())
min = i;
return min;
}
std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
{
uint32 threadIndex = SelectThreadWithMinConnections();
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
return min->GetIoContext();
}
protected:
@@ -138,7 +140,10 @@ protected:
{
}
virtual ThreadType* CreateThreads() const = 0;
virtual std::unique_ptr<ThreadType[]> CreateThreads() const
{
return std::make_unique<ThreadType[]>(GetNetworkThreadCount());
}
std::unique_ptr<AsyncAcceptor> _acceptor;
std::unique_ptr<ThreadType[]> _threads;

View File

@@ -83,11 +83,13 @@ void ServiceStatusWatcher(std::weak_ptr<Trinity::Asio::DeadlineTimer> serviceSta
bool StartDB();
void StopDB();
void SignalHandler(std::weak_ptr<Trinity::Asio::IoContext> ioContextRef, boost::system::error_code const& error, int signalNumber);
void SignalHandler(boost::system::error_code const& error, int signalNumber);
void KeepDatabaseAliveHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error);
void BanExpiryHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> banExpiryCheckTimerRef, int32 banExpiryCheckInterval, boost::system::error_code const& error);
variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, std::string& winServiceAction);
std::atomic<bool> Stopped;
int main(int argc, char** argv)
{
signal(SIGABRT, &Trinity::AbortHandler);
@@ -203,6 +205,14 @@ int main(int argc, char** argv)
std::shared_ptr<Trinity::Asio::IoContext> ioContext = std::make_shared<Trinity::Asio::IoContext>();
auto ioContextWork = boost::asio::make_work_guard(ioContext->get_executor());
std::thread ioContextThread(&Trinity::Asio::IoContext::run, ioContext.get());
auto threadJoinHandle = Trinity::make_unique_ptr_with_deleter<[](std::thread* t) { t->join(); }>(&ioContextThread);
auto ioContextWorkGuardHandle = Trinity::make_unique_ptr_with_deleter<[](auto* wg) { wg->reset(); }>(&ioContextWork);
Trinity::Net::ScanLocalNetworks();
std::string httpBindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0");
@@ -236,7 +246,7 @@ int main(int argc, char** argv)
std::string bindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0");
if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport))
if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport, 1))
{
TC_LOG_ERROR("server.bnetserver", "Failed to initialize network");
return 1;
@@ -245,14 +255,11 @@ int main(int argc, char** argv)
auto sSessionMgrHandle = Trinity::make_unique_ptr_with_deleter<&Battlenet::SessionManager::StopNetwork>(&sSessionMgr);
// Set signal handlers
boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM);
boost::asio::basic_signal_set<Trinity::Asio::IoContext::Executor> signals(*ioContext, SIGINT, SIGTERM);
#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
signals.add(SIGBREAK);
#endif
signals.async_wait([ioContextRef = std::weak_ptr(ioContext)](boost::system::error_code const& error, int signalNumber) mutable
{
SignalHandler(std::move(ioContextRef), error, signalNumber);
});
signals.async_wait(SignalHandler);
// Set process priority according to configuration settings
SetProcessPriority("server.bnetserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false));
@@ -287,16 +294,11 @@ int main(int argc, char** argv)
}
#endif
// Start the io service worker loop
ioContext->run();
banExpiryCheckTimer->cancel();
dbPingTimer->cancel();
while (!Stopped)
std::this_thread::sleep_for(50ms);
TC_LOG_INFO("server.bnetserver", "Halting process...");
signals.cancel();
return 0;
}
@@ -325,11 +327,10 @@ void StopDB()
MySQL::Library_End();
}
void SignalHandler(std::weak_ptr<Trinity::Asio::IoContext> ioContextRef, boost::system::error_code const& error, int /*signalNumber*/)
void SignalHandler(boost::system::error_code const& error, int /*signalNumber*/)
{
if (!error)
if (std::shared_ptr<Trinity::Asio::IoContext> ioContext = ioContextRef.lock())
ioContext->stop();
Stopped = true;
}
void KeepDatabaseAliveHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error)

View File

@@ -30,7 +30,7 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState
std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> Srp;
};
class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
class LoginHttpSession final : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
{
public:
static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID=";

View File

@@ -23,7 +23,6 @@
#include "CryptoRandom.h"
#include "DatabaseEnv.h"
#include "IpNetwork.h"
#include "IteratorPair.h"
#include "ProtobufJSON.h"
#include "Resolver.h"
#include "SslContext.h"
@@ -40,6 +39,33 @@ LoginRESTService& LoginRESTService::Instance()
bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount)
{
Trinity::Net::Resolver resolver(ioContext);
_externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1");
std::ranges::transform(resolver.ResolveAll(_externalHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.empty())
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname);
return false;
}
_localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1");
_firstLocalAddressIndex = _addresses.size();
std::ranges::transform(resolver.ResolveAll(_localHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.size() == _firstLocalAddressIndex)
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname);
return false;
}
if (!HttpService::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
@@ -75,36 +101,8 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
return HandlePostRefreshLoginTicket(std::move(session), context);
});
_bindIP = bindIp;
_port = port;
Trinity::Net::Resolver resolver(ioContext);
_externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1");
std::ranges::transform(resolver.ResolveAll(_externalHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.empty())
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname);
return false;
}
_localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1");
_firstLocalAddressIndex = _addresses.size();
std::ranges::transform(resolver.ResolveAll(_localHostname, ""),
std::back_inserter(_addresses),
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
if (_addresses.size() == _firstLocalAddressIndex)
{
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname);
return false;
}
// set up form inputs
JSON::Login::FormInput* input;
_formInputs.set_type(JSON::Login::LOGIN_FORM);
@@ -129,10 +127,6 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
MigrateLegacyPasswordHashes();
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
{
OnSocketOpen(std::move(sock), threadIndex);
});
return true;
}

View File

@@ -79,7 +79,6 @@ private:
void MigrateLegacyPasswordHashes() const;
JSON::Login::FormInputs _formInputs;
std::string _bindIP;
uint16 _port;
std::string _externalHostname;
std::string _localHostname;

View File

@@ -17,23 +17,6 @@
#include "SessionManager.h"
bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
{
OnSocketOpen(std::move(sock), threadIndex);
});
return true;
}
Battlenet::SessionNetworkThread* Battlenet::SessionManager::CreateThreads() const
{
return new SessionNetworkThread[GetNetworkThreadCount()];
}
Battlenet::SessionManager& Battlenet::SessionManager::Instance()
{
static SessionManager instance;

View File

@@ -23,21 +23,21 @@
namespace Battlenet
{
class SessionNetworkThread final : public Trinity::Net::NetworkThread<Session>
class SessionNetworkThread final : public Trinity::Net::NetworkThread<Session, SessionNetworkThread>
{
};
class SessionManager final : public Trinity::Net::SocketMgr<Session, SessionNetworkThread>
struct SessionManagerTraits
{
using BaseSocketMgr = SocketMgr;
using Self = class SessionManager;
using SocketType = Session;
using ThreadType = SessionNetworkThread;
};
class SessionManager final : public Trinity::Net::SocketMgr<SessionManagerTraits>
{
public:
static SessionManager& Instance();
bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override;
protected:
SessionNetworkThread* CreateThreads() const override;
};
}

View File

@@ -31,7 +31,7 @@ void WorldSocketThread::SocketRemoved(std::shared_ptr<WorldSocket>const& sock)
sScriptMgr->OnSocketClose(sock);
}
WorldSocketMgr::WorldSocketMgr() : BaseSocketMgr(), _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true)
WorldSocketMgr::WorldSocketMgr() : _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true)
{
}
@@ -61,26 +61,21 @@ bool WorldSocketMgr::StartNetwork(Trinity::Asio::IoContext& ioContext, std::stri
return false;
}
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
if (!SocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
{
OnSocketOpen(std::move(sock), threadIndex);
});
sScriptMgr->OnNetworkStart();
return true;
}
void WorldSocketMgr::StopNetwork()
{
BaseSocketMgr::StopNetwork();
SocketMgr::StopNetwork();
sScriptMgr->OnNetworkStop();
}
void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock)
{
// set some options here
if (_socketSystemSendBufferSize >= 0)
@@ -106,10 +101,5 @@ void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint3
}
}
BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex);
}
WorldSocketThread* WorldSocketMgr::CreateThreads() const
{
return new WorldSocketThread[GetNetworkThreadCount()];
SocketMgr::OnSocketOpen(std::move(sock));
}

View File

@@ -21,7 +21,7 @@
#include "SocketMgr.h"
#include "WorldSocket.h"
class WorldSocketThread final : public Trinity::Net::NetworkThread<WorldSocket>
class WorldSocketThread final : public Trinity::Net::NetworkThread<WorldSocket, WorldSocketThread>
{
public:
void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override;
@@ -29,11 +29,18 @@ public:
void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override;
};
/// Manages all sockets connected to peers and network threads
class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr<WorldSocket, WorldSocketThread>
{
using BaseSocketMgr = SocketMgr;
class WorldSocketMgr;
struct WorldSocketMgrTraits
{
using Self = WorldSocketMgr;
using SocketType = WorldSocket;
using ThreadType = WorldSocketThread;
};
/// Manages all sockets connected to peers and network threads
class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr<WorldSocketMgrTraits>
{
public:
~WorldSocketMgr();
@@ -45,15 +52,13 @@ public:
/// Stops all network threads, It will wait for all running threads .
void StopNetwork() override;
void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override;
void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock) override;
std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; }
protected:
WorldSocketMgr();
WorldSocketThread* CreateThreads() const override;
private:
int32 _socketSystemSendBufferSize;
int32 _socketApplicationSendBufferSize;

View File

@@ -268,7 +268,7 @@ int main(int argc, char** argv)
}
// Set signal handlers (this must be done before starting IoContext threads, because otherwise they would unblock and exit)
boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM);
boost::asio::basic_signal_set<Trinity::Asio::IoContext::Executor> signals(*ioContext, SIGINT, SIGTERM);
#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
signals.add(SIGBREAK);
#endif
@@ -284,7 +284,7 @@ int main(int argc, char** argv)
for (int i = 0; i < numThreads; ++i)
threadPool->PostWork([ioContext]() { ioContext->run(); });
auto ioContextStopHandle = Trinity::make_unique_ptr_with_deleter<&Trinity::Asio::IoContext::stop>(ioContext.get());
auto signalsCancelHandle = Trinity::make_unique_ptr_with_deleter<[](auto* s) { boost::system::error_code ec; s->cancel(ec); }>(&signals);
// Set process priority according to configuration settings
SetProcessPriority("server.worldserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false));
@@ -445,10 +445,6 @@ int main(int argc, char** argv)
WorldPackets::Auth::ConnectTo::ShutdownEncryption();
WorldPackets::Auth::EnterEncryptedMode::ShutdownEncryption();
ioContextStopHandle.reset();
threadPool.reset();
sLog->SetSynchronous();
sScriptMgr->OnShutdown();
@@ -636,14 +632,14 @@ std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio
if (!acceptor->Bind())
{
TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor");
return nullptr;
acceptor = nullptr;
return acceptor;
}
acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/)
{
std::make_shared<RASession>(std::move(sock))->Start();
acceptor->AsyncAccept(
[&] { return &ioContext; },
[](Trinity::Net::IoContextTcpSocket&& sock) { std::make_shared<RASession>(std::move(sock))->Start(); });
});
return acceptor;
}