mirror of
https://github.com/TrinityCore/TrinityCore.git
synced 2026-01-15 23:20:36 +01:00
Core/Network: Refactor AsyncAcceptor to use async_accept overload producing sockets through argument instead of having to preallocate it
* Also improve main() cleanup to fully process all queued async operations (including their cancellations)
This commit is contained in:
@@ -35,6 +35,12 @@ namespace boost
|
||||
template <typename Time>
|
||||
struct time_traits;
|
||||
|
||||
template <typename Clock>
|
||||
struct wait_traits;
|
||||
|
||||
template <typename Protocol, typename Clock, typename WaitTraits>
|
||||
class basic_socket_iostream;
|
||||
|
||||
namespace ip
|
||||
{
|
||||
class address;
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef IoContext_h__
|
||||
#define IoContext_h__
|
||||
#ifndef TRINITYCORE_IO_CONTEXT_H
|
||||
#define TRINITYCORE_IO_CONTEXT_H
|
||||
|
||||
#include <boost/asio/bind_executor.hpp>
|
||||
#include <boost/asio/io_context.hpp>
|
||||
@@ -56,20 +56,20 @@ namespace Trinity
|
||||
return boost::asio::post(ioContext, std::forward<T>(t));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline decltype(auto) post(boost::asio::io_context::executor_type& executor, T&& t)
|
||||
{
|
||||
return boost::asio::post(executor, std::forward<T>(t));
|
||||
}
|
||||
|
||||
using boost::asio::bind_executor;
|
||||
|
||||
template<typename T>
|
||||
inline decltype(auto) post(boost::asio::io_context::executor_type const& executor, T&& t)
|
||||
{
|
||||
return boost::asio::post(executor.context(), bind_executor(executor, std::forward<T>(t)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline decltype(auto) get_io_context(T&& ioObject)
|
||||
{
|
||||
return ioObject.get_executor().context();
|
||||
return std::forward<T>(ioObject).get_executor().context();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // IoContext_h__
|
||||
#endif // TRINITYCORE_IO_CONTEXT_H
|
||||
|
||||
@@ -27,6 +27,8 @@ namespace Trinity
|
||||
class ThreadPool
|
||||
{
|
||||
public:
|
||||
using executor_type = boost::asio::thread_pool::executor_type;
|
||||
|
||||
explicit ThreadPool(std::size_t numThreads = std::thread::hardware_concurrency()) : _impl(numThreads) { }
|
||||
|
||||
template<typename T>
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H
|
||||
#define TRINITYCORE_ASYNC_ACCEPTOR_H
|
||||
|
||||
#include "Concepts.h"
|
||||
#include "IoContext.h"
|
||||
#include "IpAddress.h"
|
||||
#include "Log.h"
|
||||
@@ -25,40 +26,40 @@
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/ip/v6_only.hpp>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
|
||||
#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
template <typename Callable>
|
||||
concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>;
|
||||
concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&>;
|
||||
|
||||
template <typename Callable>
|
||||
concept SelectIoContextForNewSocketFn = Trinity::invocable_r<Callable, Asio::IoContext*>;
|
||||
|
||||
class AsyncAcceptor
|
||||
{
|
||||
public:
|
||||
AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
|
||||
_acceptor(ioContext), _endpoint(make_address(bindIp), port),
|
||||
_socket(ioContext), _closed(false), _socketFactory([this] { return DefaultSocketFactory(); })
|
||||
_acceptor(ioContext), _endpoint(make_address(bindIp), port), _closed(false)
|
||||
{
|
||||
}
|
||||
|
||||
template <AcceptCallback Callback>
|
||||
void AsyncAccept(Callback&& acceptCallback)
|
||||
template <SelectIoContextForNewSocketFn SelectIoContextForNewSocket, AcceptCallback Callback>
|
||||
void AsyncAccept(SelectIoContextForNewSocket&& selectIoContext, Callback&& acceptCallback)
|
||||
{
|
||||
auto [tmpSocket, tmpThreadIndex] = _socketFactory();
|
||||
// TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures)
|
||||
IoContextTcpSocket* socket = tmpSocket;
|
||||
uint32 threadIndex = tmpThreadIndex;
|
||||
_acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable
|
||||
Asio::IoContext* context = selectIoContext();
|
||||
_acceptor.async_accept(context->get_executor(), [this,
|
||||
selectIoContext = std::forward<SelectIoContextForNewSocket>(selectIoContext),
|
||||
acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error, IoContextTcpSocket&& socket) mutable
|
||||
{
|
||||
if (!error)
|
||||
{
|
||||
try
|
||||
{
|
||||
socket->non_blocking(true);
|
||||
socket.non_blocking(true);
|
||||
|
||||
acceptCallback(std::move(*socket), threadIndex);
|
||||
acceptCallback(std::move(socket));
|
||||
}
|
||||
catch (boost::system::system_error const& err)
|
||||
{
|
||||
@@ -67,7 +68,7 @@ public:
|
||||
}
|
||||
|
||||
if (!_closed)
|
||||
this->AsyncAccept(std::move(acceptCallback));
|
||||
this->AsyncAccept(std::move(selectIoContext), std::move(acceptCallback));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -93,7 +94,14 @@ public:
|
||||
// v6_only is enabled on some *BSD distributions by default
|
||||
// we want to allow both v4 and v6 connections to the same listener
|
||||
if (_endpoint.protocol() == boost::asio::ip::tcp::v6())
|
||||
_acceptor.set_option(boost::asio::ip::v6_only(false));
|
||||
{
|
||||
_acceptor.set_option(boost::asio::ip::v6_only(false), errorCode);
|
||||
if (errorCode)
|
||||
{
|
||||
TC_LOG_INFO("network", "Could not disable v6_only option {}", errorCode.message());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
_acceptor.bind(_endpoint, errorCode);
|
||||
if (errorCode)
|
||||
@@ -121,16 +129,10 @@ public:
|
||||
_acceptor.close(err);
|
||||
}
|
||||
|
||||
void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
|
||||
|
||||
private:
|
||||
std::pair<IoContextTcpSocket*, uint32> DefaultSocketFactory() { return std::make_pair(&_socket, 0); }
|
||||
|
||||
boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
|
||||
boost::asio::ip::tcp::endpoint _endpoint;
|
||||
IoContextTcpSocket _socket;
|
||||
std::atomic<bool> _closed;
|
||||
std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -113,12 +113,12 @@ template<typename Callable, typename SessionImpl>
|
||||
concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>;
|
||||
|
||||
template<typename SessionImpl>
|
||||
class HttpNetworkThread final : public NetworkThread<SessionImpl>
|
||||
class HttpService;
|
||||
|
||||
template<typename SessionImpl>
|
||||
class HttpNetworkThread final : public NetworkThread<SessionImpl, HttpNetworkThread<SessionImpl>>
|
||||
{
|
||||
public:
|
||||
explicit HttpNetworkThread(SessionService* service) : _service(service) { }
|
||||
|
||||
protected:
|
||||
void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
|
||||
{
|
||||
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
|
||||
@@ -126,13 +126,24 @@ protected:
|
||||
}
|
||||
|
||||
private:
|
||||
friend class HttpService<SessionImpl>;
|
||||
SessionService* _service = nullptr;
|
||||
};
|
||||
|
||||
template<typename SessionImpl>
|
||||
class HttpService : public SocketMgr<SessionImpl, HttpNetworkThread<SessionImpl>>, public DispatcherService, public SessionService
|
||||
struct HttpServiceTraits
|
||||
{
|
||||
using BaseSocketMgr = SocketMgr<SessionImpl, HttpNetworkThread<SessionImpl>>;
|
||||
using Self = HttpService<SessionImpl>;
|
||||
using SocketType = SessionImpl;
|
||||
using ThreadType = HttpNetworkThread<SessionImpl>;
|
||||
};
|
||||
|
||||
template<typename SessionImpl>
|
||||
class HttpService : public SocketMgr<HttpServiceTraits<SessionImpl>>, public DispatcherService, public SessionService
|
||||
{
|
||||
using BaseSocketMgr = SocketMgr<HttpServiceTraits<SessionImpl>>;
|
||||
|
||||
friend BaseSocketMgr;
|
||||
|
||||
public:
|
||||
HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.")
|
||||
@@ -176,11 +187,11 @@ public:
|
||||
}
|
||||
|
||||
protected:
|
||||
HttpNetworkThread<SessionImpl>* CreateThreads() const final
|
||||
std::unique_ptr<HttpNetworkThread<SessionImpl>[]> CreateThreads() const override
|
||||
{
|
||||
HttpNetworkThread<SessionImpl>* threads = static_cast<HttpNetworkThread<SessionImpl>*>(::operator new(sizeof(HttpNetworkThread<SessionImpl>) * this->GetNetworkThreadCount()));
|
||||
std::unique_ptr<HttpNetworkThread<SessionImpl>[]> threads = std::make_unique<HttpNetworkThread<SessionImpl>[]>(this->GetNetworkThreadCount());
|
||||
for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i)
|
||||
new (&threads[i]) HttpNetworkThread<SessionImpl>(const_cast<HttpService*>(this));
|
||||
threads[i]._service = const_cast<HttpService*>(this);
|
||||
return threads;
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@
|
||||
#include "Errors.h"
|
||||
#include "IoContext.h"
|
||||
#include "Log.h"
|
||||
#include "Socket.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
@@ -33,12 +31,12 @@
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
template<class SocketType>
|
||||
template<class SocketType, class DerivedThread>
|
||||
class NetworkThread
|
||||
{
|
||||
public:
|
||||
NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1),
|
||||
_acceptSocket(_ioContext), _updateTimer(_ioContext)
|
||||
_updateTimer(_ioContext)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -82,15 +80,15 @@ public:
|
||||
return _connections;
|
||||
}
|
||||
|
||||
void AddSocket(std::shared_ptr<SocketType> sock)
|
||||
void AddSocket(std::shared_ptr<SocketType>&& sock)
|
||||
{
|
||||
std::scoped_lock lock(_newSocketsLock);
|
||||
|
||||
++_connections;
|
||||
SocketAdded(_newSockets.emplace_back(std::move(sock)));
|
||||
static_cast<DerivedThread*>(this)->SocketAdded(_newSockets.emplace_back(std::move(sock)));
|
||||
}
|
||||
|
||||
Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
|
||||
Trinity::Asio::IoContext* GetIoContext() { return &_ioContext; }
|
||||
|
||||
protected:
|
||||
virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
|
||||
@@ -107,7 +105,7 @@ protected:
|
||||
{
|
||||
if (!sock->IsOpen())
|
||||
{
|
||||
SocketRemoved(sock);
|
||||
static_cast<DerivedThread*>(this)->SocketRemoved(sock);
|
||||
--_connections;
|
||||
}
|
||||
else
|
||||
@@ -147,7 +145,7 @@ protected:
|
||||
if (sock->IsOpen())
|
||||
sock->CloseSocket();
|
||||
|
||||
this->SocketRemoved(sock);
|
||||
static_cast<DerivedThread*>(this)->SocketRemoved(sock);
|
||||
|
||||
--this->_connections;
|
||||
return true;
|
||||
@@ -171,7 +169,6 @@ private:
|
||||
SocketContainer _newSockets;
|
||||
|
||||
Trinity::Asio::IoContext _ioContext;
|
||||
Trinity::Net::IoContextTcpSocket _acceptSocket;
|
||||
Trinity::Asio::DeadlineTimer _updateTimer;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ class Socket : public std::enable_shared_from_this<Socket<Stream>>
|
||||
public:
|
||||
template<typename... Args>
|
||||
explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
|
||||
_remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
|
||||
_remoteEndpoint(_socket.remote_endpoint()), _openState(OpenState_Open)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -185,18 +185,17 @@ public:
|
||||
|
||||
boost::asio::ip::address const& GetRemoteIpAddress() const
|
||||
{
|
||||
return _remoteAddress;
|
||||
return _remoteEndpoint.Address;
|
||||
}
|
||||
|
||||
uint16 GetRemotePort() const
|
||||
{
|
||||
return _remotePort;
|
||||
return _remoteEndpoint.Port;
|
||||
}
|
||||
|
||||
void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint)
|
||||
{
|
||||
_remoteAddress = endpoint.address();
|
||||
_remotePort = endpoint.port();
|
||||
_remoteEndpoint = endpoint;
|
||||
}
|
||||
|
||||
template <invocable_r<SocketReadCallbackResult> Callback>
|
||||
@@ -227,7 +226,7 @@ public:
|
||||
|
||||
void CloseSocket()
|
||||
{
|
||||
if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
|
||||
if (_openState.exchange(OpenState_Closed) == OpenState_Closed)
|
||||
return;
|
||||
|
||||
boost::system::error_code shutdownError;
|
||||
@@ -242,7 +241,8 @@ public:
|
||||
/// Marks the socket for closing after write buffer becomes empty
|
||||
void DelayedCloseSocket()
|
||||
{
|
||||
if (_openState.fetch_or(OpenState_Closing) != 0)
|
||||
uint8 oldState = OpenState_Open;
|
||||
if (!_openState.compare_exchange_strong(oldState, OpenState_Closing))
|
||||
return;
|
||||
|
||||
if (_writeQueue.empty())
|
||||
@@ -380,8 +380,14 @@ private:
|
||||
|
||||
Stream _socket;
|
||||
|
||||
boost::asio::ip::address _remoteAddress;
|
||||
uint16 _remotePort = 0;
|
||||
struct Endpoint
|
||||
{
|
||||
Endpoint() : Address(), Port(0) { }
|
||||
explicit(false) Endpoint(boost::asio::ip::tcp_endpoint const& endpoint) : Address(endpoint.address()), Port(endpoint.port()) { }
|
||||
|
||||
boost::asio::ip::address Address;
|
||||
uint16 Port;
|
||||
} _remoteEndpoint;
|
||||
|
||||
MessageBuffer _readBuffer = MessageBuffer(0x1000);
|
||||
std::queue<MessageBuffer> _writeQueue;
|
||||
|
||||
@@ -27,13 +27,14 @@
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
template <typename SocketType, typename ThreadType>
|
||||
template <typename Traits>
|
||||
class SocketMgr
|
||||
{
|
||||
static_assert(std::is_base_of_v<NetworkThread<SocketType>, ThreadType>);
|
||||
static_assert(std::is_final_v<ThreadType>);
|
||||
|
||||
public:
|
||||
using Self = typename Traits::Self;
|
||||
using SocketType = typename Traits::SocketType;
|
||||
using ThreadType = typename Traits::ThreadType;
|
||||
|
||||
SocketMgr(SocketMgr const&) = delete;
|
||||
SocketMgr(SocketMgr&&) = delete;
|
||||
SocketMgr& operator=(SocketMgr const&) = delete;
|
||||
@@ -67,14 +68,16 @@ public:
|
||||
|
||||
_acceptor = std::move(acceptor);
|
||||
_threadCount = threadCount;
|
||||
_threads.reset(CreateThreads());
|
||||
_threads = static_cast<Self*>(this)->CreateThreads();
|
||||
|
||||
ASSERT(_threads);
|
||||
|
||||
for (int32 i = 0; i < _threadCount; ++i)
|
||||
_threads[i].Start();
|
||||
|
||||
_acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
|
||||
_acceptor->AsyncAccept(
|
||||
[this]{ return SelectThreadWithMinConnections(); },
|
||||
[this](IoContextTcpSocket&& sock) { static_cast<Self*>(this)->OnSocketOpen(std::move(sock)); });
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -99,14 +102,19 @@ public:
|
||||
_threads[i].Wait();
|
||||
}
|
||||
|
||||
virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
virtual void OnSocketOpen(IoContextTcpSocket&& sock)
|
||||
{
|
||||
try
|
||||
{
|
||||
int32 threadIndex = 0;
|
||||
for (; threadIndex < _threadCount; ++threadIndex)
|
||||
if (_threads[threadIndex].GetIoContext()->get_executor() == sock.get_executor())
|
||||
break;
|
||||
|
||||
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
|
||||
newSocket->Start();
|
||||
|
||||
_threads[threadIndex].AddSocket(newSocket);
|
||||
_threads[threadIndex].AddSocket(std::move(newSocket));
|
||||
}
|
||||
catch (boost::system::system_error const& err)
|
||||
{
|
||||
@@ -116,21 +124,15 @@ public:
|
||||
|
||||
int32 GetNetworkThreadCount() const { return _threadCount; }
|
||||
|
||||
uint32 SelectThreadWithMinConnections() const
|
||||
Asio::IoContext* SelectThreadWithMinConnections() const
|
||||
{
|
||||
uint32 min = 0;
|
||||
ThreadType* min = &_threads[0];
|
||||
|
||||
for (int32 i = 1; i < _threadCount; ++i)
|
||||
if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
|
||||
for (ThreadType* i = min + 1; i != _threads.get() + _threadCount; ++i)
|
||||
if (i->GetConnectionCount() < min->GetConnectionCount())
|
||||
min = i;
|
||||
|
||||
return min;
|
||||
}
|
||||
|
||||
std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
|
||||
{
|
||||
uint32 threadIndex = SelectThreadWithMinConnections();
|
||||
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
|
||||
return min->GetIoContext();
|
||||
}
|
||||
|
||||
protected:
|
||||
@@ -138,7 +140,10 @@ protected:
|
||||
{
|
||||
}
|
||||
|
||||
virtual ThreadType* CreateThreads() const = 0;
|
||||
virtual std::unique_ptr<ThreadType[]> CreateThreads() const
|
||||
{
|
||||
return std::make_unique<ThreadType[]>(GetNetworkThreadCount());
|
||||
}
|
||||
|
||||
std::unique_ptr<AsyncAcceptor> _acceptor;
|
||||
std::unique_ptr<ThreadType[]> _threads;
|
||||
|
||||
@@ -83,11 +83,13 @@ void ServiceStatusWatcher(std::weak_ptr<Trinity::Asio::DeadlineTimer> serviceSta
|
||||
|
||||
bool StartDB();
|
||||
void StopDB();
|
||||
void SignalHandler(std::weak_ptr<Trinity::Asio::IoContext> ioContextRef, boost::system::error_code const& error, int signalNumber);
|
||||
void SignalHandler(boost::system::error_code const& error, int signalNumber);
|
||||
void KeepDatabaseAliveHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error);
|
||||
void BanExpiryHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> banExpiryCheckTimerRef, int32 banExpiryCheckInterval, boost::system::error_code const& error);
|
||||
variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, std::string& winServiceAction);
|
||||
|
||||
std::atomic<bool> Stopped;
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
signal(SIGABRT, &Trinity::AbortHandler);
|
||||
@@ -203,6 +205,14 @@ int main(int argc, char** argv)
|
||||
|
||||
std::shared_ptr<Trinity::Asio::IoContext> ioContext = std::make_shared<Trinity::Asio::IoContext>();
|
||||
|
||||
auto ioContextWork = boost::asio::make_work_guard(ioContext->get_executor());
|
||||
|
||||
std::thread ioContextThread(&Trinity::Asio::IoContext::run, ioContext.get());
|
||||
|
||||
auto threadJoinHandle = Trinity::make_unique_ptr_with_deleter<[](std::thread* t) { t->join(); }>(&ioContextThread);
|
||||
|
||||
auto ioContextWorkGuardHandle = Trinity::make_unique_ptr_with_deleter<[](auto* wg) { wg->reset(); }>(&ioContextWork);
|
||||
|
||||
Trinity::Net::ScanLocalNetworks();
|
||||
|
||||
std::string httpBindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0");
|
||||
@@ -236,7 +246,7 @@ int main(int argc, char** argv)
|
||||
|
||||
std::string bindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0");
|
||||
|
||||
if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport))
|
||||
if (!sSessionMgr.StartNetwork(*ioContext, bindIp, bnport, 1))
|
||||
{
|
||||
TC_LOG_ERROR("server.bnetserver", "Failed to initialize network");
|
||||
return 1;
|
||||
@@ -245,14 +255,11 @@ int main(int argc, char** argv)
|
||||
auto sSessionMgrHandle = Trinity::make_unique_ptr_with_deleter<&Battlenet::SessionManager::StopNetwork>(&sSessionMgr);
|
||||
|
||||
// Set signal handlers
|
||||
boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM);
|
||||
boost::asio::basic_signal_set<Trinity::Asio::IoContext::Executor> signals(*ioContext, SIGINT, SIGTERM);
|
||||
#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
|
||||
signals.add(SIGBREAK);
|
||||
#endif
|
||||
signals.async_wait([ioContextRef = std::weak_ptr(ioContext)](boost::system::error_code const& error, int signalNumber) mutable
|
||||
{
|
||||
SignalHandler(std::move(ioContextRef), error, signalNumber);
|
||||
});
|
||||
signals.async_wait(SignalHandler);
|
||||
|
||||
// Set process priority according to configuration settings
|
||||
SetProcessPriority("server.bnetserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false));
|
||||
@@ -287,16 +294,11 @@ int main(int argc, char** argv)
|
||||
}
|
||||
#endif
|
||||
|
||||
// Start the io service worker loop
|
||||
ioContext->run();
|
||||
|
||||
banExpiryCheckTimer->cancel();
|
||||
dbPingTimer->cancel();
|
||||
while (!Stopped)
|
||||
std::this_thread::sleep_for(50ms);
|
||||
|
||||
TC_LOG_INFO("server.bnetserver", "Halting process...");
|
||||
|
||||
signals.cancel();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -325,11 +327,10 @@ void StopDB()
|
||||
MySQL::Library_End();
|
||||
}
|
||||
|
||||
void SignalHandler(std::weak_ptr<Trinity::Asio::IoContext> ioContextRef, boost::system::error_code const& error, int /*signalNumber*/)
|
||||
void SignalHandler(boost::system::error_code const& error, int /*signalNumber*/)
|
||||
{
|
||||
if (!error)
|
||||
if (std::shared_ptr<Trinity::Asio::IoContext> ioContext = ioContextRef.lock())
|
||||
ioContext->stop();
|
||||
Stopped = true;
|
||||
}
|
||||
|
||||
void KeepDatabaseAliveHandler(std::weak_ptr<Trinity::Asio::DeadlineTimer> dbPingTimerRef, int32 dbPingInterval, boost::system::error_code const& error)
|
||||
|
||||
@@ -30,7 +30,7 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState
|
||||
std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> Srp;
|
||||
};
|
||||
|
||||
class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
|
||||
class LoginHttpSession final : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
|
||||
{
|
||||
public:
|
||||
static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID=";
|
||||
|
||||
@@ -23,7 +23,6 @@
|
||||
#include "CryptoRandom.h"
|
||||
#include "DatabaseEnv.h"
|
||||
#include "IpNetwork.h"
|
||||
#include "IteratorPair.h"
|
||||
#include "ProtobufJSON.h"
|
||||
#include "Resolver.h"
|
||||
#include "SslContext.h"
|
||||
@@ -40,6 +39,33 @@ LoginRESTService& LoginRESTService::Instance()
|
||||
|
||||
bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount)
|
||||
{
|
||||
Trinity::Net::Resolver resolver(ioContext);
|
||||
|
||||
_externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1");
|
||||
|
||||
std::ranges::transform(resolver.ResolveAll(_externalHostname, ""),
|
||||
std::back_inserter(_addresses),
|
||||
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
|
||||
|
||||
if (_addresses.empty())
|
||||
{
|
||||
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname);
|
||||
return false;
|
||||
}
|
||||
|
||||
_localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1");
|
||||
_firstLocalAddressIndex = _addresses.size();
|
||||
|
||||
std::ranges::transform(resolver.ResolveAll(_localHostname, ""),
|
||||
std::back_inserter(_addresses),
|
||||
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
|
||||
|
||||
if (_addresses.size() == _firstLocalAddressIndex)
|
||||
{
|
||||
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!HttpService::StartNetwork(ioContext, bindIp, port, threadCount))
|
||||
return false;
|
||||
|
||||
@@ -75,36 +101,8 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
|
||||
return HandlePostRefreshLoginTicket(std::move(session), context);
|
||||
});
|
||||
|
||||
_bindIP = bindIp;
|
||||
_port = port;
|
||||
|
||||
Trinity::Net::Resolver resolver(ioContext);
|
||||
|
||||
_externalHostname = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress"sv, "127.0.0.1");
|
||||
|
||||
std::ranges::transform(resolver.ResolveAll(_externalHostname, ""),
|
||||
std::back_inserter(_addresses),
|
||||
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
|
||||
|
||||
if (_addresses.empty())
|
||||
{
|
||||
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.ExternalAddress {}", _externalHostname);
|
||||
return false;
|
||||
}
|
||||
|
||||
_localHostname = sConfigMgr->GetStringDefault("LoginREST.LocalAddress"sv, "127.0.0.1");
|
||||
_firstLocalAddressIndex = _addresses.size();
|
||||
|
||||
std::ranges::transform(resolver.ResolveAll(_localHostname, ""),
|
||||
std::back_inserter(_addresses),
|
||||
[](boost::asio::ip::tcp::endpoint const& endpoint) { return endpoint.address(); });
|
||||
|
||||
if (_addresses.size() == _firstLocalAddressIndex)
|
||||
{
|
||||
TC_LOG_ERROR("server.http.login", "Could not resolve LoginREST.LocalAddress {}", _localHostname);
|
||||
return false;
|
||||
}
|
||||
|
||||
// set up form inputs
|
||||
JSON::Login::FormInput* input;
|
||||
_formInputs.set_type(JSON::Login::LOGIN_FORM);
|
||||
@@ -129,10 +127,6 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
|
||||
|
||||
MigrateLegacyPasswordHashes();
|
||||
|
||||
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
OnSocketOpen(std::move(sock), threadIndex);
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,6 @@ private:
|
||||
void MigrateLegacyPasswordHashes() const;
|
||||
|
||||
JSON::Login::FormInputs _formInputs;
|
||||
std::string _bindIP;
|
||||
uint16 _port;
|
||||
std::string _externalHostname;
|
||||
std::string _localHostname;
|
||||
|
||||
@@ -17,23 +17,6 @@
|
||||
|
||||
#include "SessionManager.h"
|
||||
|
||||
bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
|
||||
{
|
||||
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
|
||||
return false;
|
||||
|
||||
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
OnSocketOpen(std::move(sock), threadIndex);
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
Battlenet::SessionNetworkThread* Battlenet::SessionManager::CreateThreads() const
|
||||
{
|
||||
return new SessionNetworkThread[GetNetworkThreadCount()];
|
||||
}
|
||||
|
||||
Battlenet::SessionManager& Battlenet::SessionManager::Instance()
|
||||
{
|
||||
static SessionManager instance;
|
||||
|
||||
@@ -23,21 +23,21 @@
|
||||
|
||||
namespace Battlenet
|
||||
{
|
||||
class SessionNetworkThread final : public Trinity::Net::NetworkThread<Session>
|
||||
class SessionNetworkThread final : public Trinity::Net::NetworkThread<Session, SessionNetworkThread>
|
||||
{
|
||||
};
|
||||
|
||||
class SessionManager final : public Trinity::Net::SocketMgr<Session, SessionNetworkThread>
|
||||
struct SessionManagerTraits
|
||||
{
|
||||
using BaseSocketMgr = SocketMgr;
|
||||
using Self = class SessionManager;
|
||||
using SocketType = Session;
|
||||
using ThreadType = SessionNetworkThread;
|
||||
};
|
||||
|
||||
class SessionManager final : public Trinity::Net::SocketMgr<SessionManagerTraits>
|
||||
{
|
||||
public:
|
||||
static SessionManager& Instance();
|
||||
|
||||
bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override;
|
||||
|
||||
protected:
|
||||
SessionNetworkThread* CreateThreads() const override;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ void WorldSocketThread::SocketRemoved(std::shared_ptr<WorldSocket>const& sock)
|
||||
sScriptMgr->OnSocketClose(sock);
|
||||
}
|
||||
|
||||
WorldSocketMgr::WorldSocketMgr() : BaseSocketMgr(), _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true)
|
||||
WorldSocketMgr::WorldSocketMgr() : _socketSystemSendBufferSize(-1), _socketApplicationSendBufferSize(65536), _tcpNoDelay(true)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -61,26 +61,21 @@ bool WorldSocketMgr::StartNetwork(Trinity::Asio::IoContext& ioContext, std::stri
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
|
||||
if (!SocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
|
||||
return false;
|
||||
|
||||
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
OnSocketOpen(std::move(sock), threadIndex);
|
||||
});
|
||||
|
||||
sScriptMgr->OnNetworkStart();
|
||||
return true;
|
||||
}
|
||||
|
||||
void WorldSocketMgr::StopNetwork()
|
||||
{
|
||||
BaseSocketMgr::StopNetwork();
|
||||
SocketMgr::StopNetwork();
|
||||
|
||||
sScriptMgr->OnNetworkStop();
|
||||
}
|
||||
|
||||
void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock)
|
||||
{
|
||||
// set some options here
|
||||
if (_socketSystemSendBufferSize >= 0)
|
||||
@@ -106,10 +101,5 @@ void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint3
|
||||
}
|
||||
}
|
||||
|
||||
BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex);
|
||||
}
|
||||
|
||||
WorldSocketThread* WorldSocketMgr::CreateThreads() const
|
||||
{
|
||||
return new WorldSocketThread[GetNetworkThreadCount()];
|
||||
SocketMgr::OnSocketOpen(std::move(sock));
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
#include "SocketMgr.h"
|
||||
#include "WorldSocket.h"
|
||||
|
||||
class WorldSocketThread final : public Trinity::Net::NetworkThread<WorldSocket>
|
||||
class WorldSocketThread final : public Trinity::Net::NetworkThread<WorldSocket, WorldSocketThread>
|
||||
{
|
||||
public:
|
||||
void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override;
|
||||
@@ -29,11 +29,18 @@ public:
|
||||
void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override;
|
||||
};
|
||||
|
||||
/// Manages all sockets connected to peers and network threads
|
||||
class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr<WorldSocket, WorldSocketThread>
|
||||
{
|
||||
using BaseSocketMgr = SocketMgr;
|
||||
class WorldSocketMgr;
|
||||
|
||||
struct WorldSocketMgrTraits
|
||||
{
|
||||
using Self = WorldSocketMgr;
|
||||
using SocketType = WorldSocket;
|
||||
using ThreadType = WorldSocketThread;
|
||||
};
|
||||
|
||||
/// Manages all sockets connected to peers and network threads
|
||||
class TC_GAME_API WorldSocketMgr final : public Trinity::Net::SocketMgr<WorldSocketMgrTraits>
|
||||
{
|
||||
public:
|
||||
~WorldSocketMgr();
|
||||
|
||||
@@ -45,15 +52,13 @@ public:
|
||||
/// Stops all network threads, It will wait for all running threads .
|
||||
void StopNetwork() override;
|
||||
|
||||
void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override;
|
||||
void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock) override;
|
||||
|
||||
std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; }
|
||||
|
||||
protected:
|
||||
WorldSocketMgr();
|
||||
|
||||
WorldSocketThread* CreateThreads() const override;
|
||||
|
||||
private:
|
||||
int32 _socketSystemSendBufferSize;
|
||||
int32 _socketApplicationSendBufferSize;
|
||||
|
||||
@@ -268,7 +268,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
|
||||
// Set signal handlers (this must be done before starting IoContext threads, because otherwise they would unblock and exit)
|
||||
boost::asio::signal_set signals(*ioContext, SIGINT, SIGTERM);
|
||||
boost::asio::basic_signal_set<Trinity::Asio::IoContext::Executor> signals(*ioContext, SIGINT, SIGTERM);
|
||||
#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
|
||||
signals.add(SIGBREAK);
|
||||
#endif
|
||||
@@ -284,7 +284,7 @@ int main(int argc, char** argv)
|
||||
for (int i = 0; i < numThreads; ++i)
|
||||
threadPool->PostWork([ioContext]() { ioContext->run(); });
|
||||
|
||||
auto ioContextStopHandle = Trinity::make_unique_ptr_with_deleter<&Trinity::Asio::IoContext::stop>(ioContext.get());
|
||||
auto signalsCancelHandle = Trinity::make_unique_ptr_with_deleter<[](auto* s) { boost::system::error_code ec; s->cancel(ec); }>(&signals);
|
||||
|
||||
// Set process priority according to configuration settings
|
||||
SetProcessPriority("server.worldserver", sConfigMgr->GetIntDefault(CONFIG_PROCESSOR_AFFINITY, 0), sConfigMgr->GetBoolDefault(CONFIG_HIGH_PRIORITY, false));
|
||||
@@ -445,10 +445,6 @@ int main(int argc, char** argv)
|
||||
WorldPackets::Auth::ConnectTo::ShutdownEncryption();
|
||||
WorldPackets::Auth::EnterEncryptedMode::ShutdownEncryption();
|
||||
|
||||
ioContextStopHandle.reset();
|
||||
|
||||
threadPool.reset();
|
||||
|
||||
sLog->SetSynchronous();
|
||||
|
||||
sScriptMgr->OnShutdown();
|
||||
@@ -636,14 +632,14 @@ std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio
|
||||
if (!acceptor->Bind())
|
||||
{
|
||||
TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor");
|
||||
return nullptr;
|
||||
acceptor = nullptr;
|
||||
return acceptor;
|
||||
}
|
||||
|
||||
acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/)
|
||||
{
|
||||
std::make_shared<RASession>(std::move(sock))->Start();
|
||||
acceptor->AsyncAccept(
|
||||
[&] { return &ioContext; },
|
||||
[](Trinity::Net::IoContextTcpSocket&& sock) { std::make_shared<RASession>(std::move(sock))->Start(); });
|
||||
|
||||
});
|
||||
return acceptor;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user