diff options
author | Shauren <shauren.trinity@gmail.com> | 2025-04-08 19:15:16 +0200 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2025-04-08 19:15:16 +0200 |
commit | e8b2be3527c7683e8bfca70ed7706fc20da566fd (patch) | |
tree | 54d5099554c8628cad719e6f1a49d387c7eced4f /src | |
parent | 40d80f3476ade4898be24659408e82aa4234b099 (diff) |
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final
* Removed derived class template argument
* Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor
* Make socket initialization easier composable (before entering Read loop)
* Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
Diffstat (limited to 'src')
32 files changed, 962 insertions, 806 deletions
diff --git a/src/common/Utilities/Containers.h b/src/common/Utilities/Containers.h index 4a764629937..11928fa053d 100644 --- a/src/common/Utilities/Containers.h +++ b/src/common/Utilities/Containers.h @@ -259,7 +259,7 @@ namespace Trinity if (!p(*rpos)) { if (rpos != wpos) - std::swap(*rpos, *wpos); + std::ranges::swap(*rpos, *wpos); ++wpos; } } diff --git a/src/server/bnetserver/REST/LoginHttpSession.cpp b/src/server/bnetserver/REST/LoginHttpSession.cpp index aff579de7f9..23a317d3726 100644 --- a/src/server/bnetserver/REST/LoginHttpSession.cpp +++ b/src/server/bnetserver/REST/LoginHttpSession.cpp @@ -17,126 +17,117 @@ #include "LoginHttpSession.h" #include "DatabaseEnv.h" +#include "HttpSocket.h" +#include "HttpSslSocket.h" +#include "IpBanCheckConnectionInitializer.h" #include "LoginRESTService.h" #include "SslContext.h" #include "Util.h" +#include <boost/container/static_vector.hpp> -namespace Battlenet -{ -template<template<typename> typename SocketImpl> -LoginHttpSession<SocketImpl>::LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner) - : BaseSocket(std::move(socket), SslContext::instance()), _owner(owner) -{ -} - -template<template<typename> typename SocketImpl> -LoginHttpSession<SocketImpl>::~LoginHttpSession() = default; - -template<template<typename> typename SocketImpl> -void LoginHttpSession<SocketImpl>::Start() -{ - std::string ip_address = this->GetRemoteIpAddress().to_string(); - TC_LOG_TRACE("server.http.session", "{} Accepted connection", this->GetClientInfo()); - - // Verify that this IP is not in the ip_banned table - LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); - - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); - - this->_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) - .WithPreparedCallback([sess = this->shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); -} - -template<template<typename> typename SocketImpl> -void LoginHttpSession<SocketImpl>::CheckIpCallback(PreparedQueryResult result) -{ - if (result) - { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; - - } while (result->NextRow()); - - if (banned) - { - TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", this->GetClientInfo()); - this->CloseSocket(); - return; - } - } - - if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket<LoginHttpSession<Trinity::Net::Http::SslSocket>>>) - { - this->AsyncHandshake(); - } - else - { - this->ResetHttpParser(); - this->AsyncRead(); - } -} - -template<template<typename> typename SocketImpl> -Trinity::Net::Http::RequestHandlerResult LoginHttpSession<SocketImpl>::RequestHandler(Trinity::Net::Http::RequestContext& context) +namespace { - return sLoginService.HandleRequest(_owner.shared_from_this(), context); -} - std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress) { using namespace std::string_literals; - std::shared_ptr<Trinity::Net::Http::SessionState> state; - auto cookieItr = context.request.find(boost::beast::http::field::cookie); if (cookieItr != context.request.end()) { std::vector<std::string_view> cookies = Trinity::Tokenize(Trinity::Net::Http::ToStdStringView(cookieItr->value()), ';', false); - std::size_t eq = 0; - auto sessionIdItr = std::find_if(cookies.begin(), cookies.end(), [&](std::string_view cookie) + auto sessionIdItr = std::ranges::find_if(cookies, [](std::string_view const& cookie) { - std::string_view name = cookie; - eq = cookie.find('='); - if (eq != std::string_view::npos) - name = cookie.substr(0, eq); - - return name == LoginHttpSessionWrapper::SESSION_ID_COOKIE; + return cookie.length() > Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length() + && cookie.starts_with(Battlenet::LoginHttpSession::SESSION_ID_COOKIE); }); if (sessionIdItr != cookies.end()) { - std::string_view value = sessionIdItr->substr(eq + 1); - state = sLoginService.FindAndRefreshSessionState(value, remoteAddress); + sessionIdItr->remove_prefix(Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length()); + state = sLoginService.FindAndRefreshSessionState(*sessionIdItr, remoteAddress); } } - if (!state) { state = sLoginService.CreateNewSessionState(remoteAddress); std::string_view host = Trinity::Net::Http::ToStdStringView(context.request[boost::beast::http::field::host]); - if (std::size_t port = host.find(':'); port != std::string_view::npos) + if (size_t port = host.find(':'); port != std::string_view::npos) host.remove_suffix(host.length() - port); - context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}={}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None", - LoginHttpSessionWrapper::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host)); + context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}{}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None", + Battlenet::LoginHttpSession::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host)); } - return state; } -template class LoginHttpSession<Trinity::Net::Http::SslSocket>; -template class LoginHttpSession<Trinity::Net::Http::Socket>; +template<typename SocketImpl> +class LoginHttpSocketImpl final : public SocketImpl +{ +public: + using BaseSocket = SocketImpl; + + explicit LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner) + : BaseSocket(std::move(socket)), _owner(owner) + { + } + + LoginHttpSocketImpl(LoginHttpSocketImpl const&) = delete; + LoginHttpSocketImpl(LoginHttpSocketImpl&&) = delete; + LoginHttpSocketImpl& operator=(LoginHttpSocketImpl const&) = delete; + LoginHttpSocketImpl& operator=(LoginHttpSocketImpl&&) = delete; + + ~LoginHttpSocketImpl() = default; + + void Start() override + { + // build initializer chain + boost::container::static_vector<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 4> initializers; + + initializers.stable_emplace_back(std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<BaseSocket>>(this)); + + if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket>) + initializers.stable_emplace_back(std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<BaseSocket>>(this)); + + initializers.stable_emplace_back(std::make_shared<Trinity::Net::Http::HttpConnectionInitializer<BaseSocket>>(this)); + initializers.stable_emplace_back(std::make_shared<Trinity::Net::ReadConnectionInitializer<BaseSocket>>(this)); + + Trinity::Net::SocketConnectionInitializer::SetupChain(std::span(initializers.data(), initializers.size()))->Start(); + } + + Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override + { + return sLoginService.HandleRequest(_owner.shared_from_this(), context); + } + +protected: + std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override + { + return ::ObtainSessionState(context, this->GetRemoteIpAddress()); + } + + Battlenet::LoginHttpSession& _owner; +}; + +template<> +LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>::LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner) + : BaseSocket(std::move(socket), Battlenet::SslContext::instance()), _owner(owner) +{ +} +} + +namespace Battlenet +{ +LoginHttpSession::LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket) + : _socket(!SslContext::UsesDevWildcardCertificate() + ? std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>>(std::move(socket), *this)) + : std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::Socket>>(std::move(socket), *this))) +{ +} -LoginHttpSessionWrapper::LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket) +void LoginHttpSession::Start() { - if (!SslContext::UsesDevWildcardCertificate()) - _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::SslSocket>>(std::move(socket), *this); - else - _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::Socket>>(std::move(socket), *this); + TC_LOG_TRACE("server.http.session", "{} Accepted connection", GetClientInfo()); + + return _socket->Start(); } } diff --git a/src/server/bnetserver/REST/LoginHttpSession.h b/src/server/bnetserver/REST/LoginHttpSession.h index 6bd1ec113d0..c15442f9e0c 100644 --- a/src/server/bnetserver/REST/LoginHttpSession.h +++ b/src/server/bnetserver/REST/LoginHttpSession.h @@ -18,10 +18,8 @@ #ifndef TRINITYCORE_LOGIN_HTTP_SESSION_H #define TRINITYCORE_LOGIN_HTTP_SESSION_H -#include "HttpSocket.h" -#include "HttpSslSocket.h" +#include "BaseHttpSocket.h" #include "SRP6.h" -#include <variant> namespace Battlenet { @@ -30,59 +28,27 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> Srp; }; -class LoginHttpSessionWrapper; -std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress); - -template<template<typename> typename SocketImpl> -class LoginHttpSession : public SocketImpl<LoginHttpSession<SocketImpl>> +class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession> { - using BaseSocket = SocketImpl<LoginHttpSession<SocketImpl>>; - public: - explicit LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner); - ~LoginHttpSession(); - - void Start() override; + static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID="; - void CheckIpCallback(PreparedQueryResult result); + explicit LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket); - Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override; - - LoginSessionState* GetSessionState() const { return static_cast<LoginSessionState*>(this->_state.get()); } - -protected: - std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override - { - return Battlenet::ObtainSessionState(context, this->GetRemoteIpAddress()); - } - - LoginHttpSessionWrapper& _owner; -}; - -class LoginHttpSessionWrapper : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSessionWrapper> -{ -public: - static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID"; - - explicit LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket); - - void Start() { return std::visit([&](auto&& socket) { return socket->Start(); }, _socket); } - bool Update() { return std::visit([&](auto&& socket) { return socket->Update(); }, _socket); } - boost::asio::ip::address GetRemoteIpAddress() const { return std::visit([&](auto&& socket) { return socket->GetRemoteIpAddress(); }, _socket); } - bool IsOpen() const { return std::visit([&](auto&& socket) { return socket->IsOpen(); }, _socket); } - void CloseSocket() { return std::visit([&](auto&& socket) { return socket->CloseSocket(); }, _socket); } + void Start() override; + bool Update() override { return _socket->Update(); } + boost::asio::ip::address const& GetRemoteIpAddress() const override { return _socket->GetRemoteIpAddress(); } + bool IsOpen() const override { return _socket->IsOpen(); } + void CloseSocket() override { return _socket->CloseSocket(); } - void SendResponse(Trinity::Net::Http::RequestContext& context) override { return std::visit([&](auto&& socket) { return socket->SendResponse(context); }, _socket); } - void QueueQuery(QueryCallback&& queryCallback) override { return std::visit([&](auto&& socket) { return socket->QueueQuery(std::move(queryCallback)); }, _socket); } - std::string GetClientInfo() const override { return std::visit([&](auto&& socket) { return socket->GetClientInfo(); }, _socket); } - Optional<boost::uuids::uuid> GetSessionId() const override { return std::visit([&](auto&& socket) { return socket->GetSessionId(); }, _socket); } - LoginSessionState* GetSessionState() const { return std::visit([&](auto&& socket) { return socket->GetSessionState(); }, _socket); } + void SendResponse(Trinity::Net::Http::RequestContext& context) override { return _socket->SendResponse(context); } + void QueueQuery(QueryCallback&& queryCallback) override { return _socket->QueueQuery(std::move(queryCallback)); } + std::string GetClientInfo() const override { return _socket->GetClientInfo(); } + LoginSessionState* GetSessionState() const override { return static_cast<LoginSessionState*>(_socket->GetSessionState()); } private: - std::variant< - std::shared_ptr<LoginHttpSession<Trinity::Net::Http::SslSocket>>, - std::shared_ptr<LoginHttpSession<Trinity::Net::Http::Socket>> - > _socket; + std::shared_ptr<Trinity::Net::Http::AbstractSocket> _socket; }; } + #endif // TRINITYCORE_LOGIN_HTTP_SESSION_H diff --git a/src/server/bnetserver/REST/LoginRESTService.cpp b/src/server/bnetserver/REST/LoginRESTService.cpp index 12f6cf63712..1afcc5c88bb 100644 --- a/src/server/bnetserver/REST/LoginRESTService.cpp +++ b/src/server/bnetserver/REST/LoginRESTService.cpp @@ -44,32 +44,32 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st using Trinity::Net::Http::RequestHandlerFlag; - RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandleGetForm(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandleGetGameAccounts(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandleGetPortal(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandlePostLogin(std::move(session), context); }, RequestHandlerFlag::DoNotLogRequestContent); - RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandlePostLoginSrpChallenge(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandlePostRefreshLoginTicket(std::move(session), context); }); @@ -121,7 +121,10 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st MigrateLegacyPasswordHashes(); - _acceptor->AsyncAcceptWithCallback<&LoginRESTService::OnSocketAccept>(); + _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) + { + OnSocketOpen(std::move(sock), threadIndex); + }); return true; } @@ -165,7 +168,7 @@ std::string LoginRESTService::ExtractAuthorization(HttpRequest const& request) return ticket; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { JSON::Login::FormInputs form = _formInputs; form.set_srp_url(Trinity::StringFormat("http{}://{}:{}/bnetserver/login/srp/", !SslContext::UsesDevWildcardCertificate() ? "s" : "", @@ -176,7 +179,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shar return RequestHandlerResult::Handled; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { std::string ticket = ExtractAuthorization(context.request); if (ticket.empty()) @@ -225,14 +228,14 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(s return RequestHandlerResult::Async; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { context.response.set(boost::beast::http::field::content_type, "text/plain"); context.response.body() = Trinity::StringFormat("{}:{}", GetHostnameForClient(session->GetRemoteIpAddress()), sConfigMgr->GetIntDefault("BattlenetPort", 1119)); return RequestHandlerResult::Handled; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { std::shared_ptr<JSON::Login::LoginForm> loginForm = std::make_shared<JSON::Login::LoginForm>(); if (!::JSON::Deserialize(context.request.body(), loginForm.get())) @@ -396,7 +399,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::sh return RequestHandlerResult::Async; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { JSON::Login::LoginForm loginForm; if (!::JSON::Deserialize(context.request.body(), &loginForm)) @@ -483,7 +486,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChall return RequestHandlerResult::Async; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { std::string ticket = ExtractAuthorization(context.request); if (ticket.empty()) @@ -551,11 +554,6 @@ std::shared_ptr<Trinity::Net::Http::SessionState> LoginRESTService::CreateNewSes return state; } -void LoginRESTService::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) -{ - sLoginService.OnSocketOpen(std::move(sock), threadIndex); -} - void LoginRESTService::MigrateLegacyPasswordHashes() const { if (!LoginDatabase.Query("SELECT 1 FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = SCHEMA() AND TABLE_NAME = 'battlenet_accounts' AND COLUMN_NAME = 'sha_pass_hash'")) diff --git a/src/server/bnetserver/REST/LoginRESTService.h b/src/server/bnetserver/REST/LoginRESTService.h index 079bb6fca9c..517f5e295b7 100644 --- a/src/server/bnetserver/REST/LoginRESTService.h +++ b/src/server/bnetserver/REST/LoginRESTService.h @@ -15,8 +15,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef LoginRESTService_h__ -#define LoginRESTService_h__ +#ifndef TRINITYCORE_LOGIN_REST_SERVICE_H +#define TRINITYCORE_LOGIN_REST_SERVICE_H #include "HttpService.h" #include "Login.pb.h" @@ -42,7 +42,7 @@ enum class BanMode BAN_ACCOUNT = 1 }; -class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSessionWrapper> +class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSession> { public: using RequestHandlerResult = Trinity::Net::Http::RequestHandlerResult; @@ -63,17 +63,15 @@ public: std::shared_ptr<Trinity::Net::Http::SessionState> CreateNewSessionState(boost::asio::ip::address const& address) override; private: - static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex); - static std::string ExtractAuthorization(HttpRequest const& request); - RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; - static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context); - RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; + RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; + static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); + RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; - RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; - static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context); - RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; + RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; + static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); + RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; static std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> CreateSrpImplementation(SrpVersion version, SrpHashFunction hashFunction, std::string const& username, Trinity::Crypto::SRP::Salt const& salt, Trinity::Crypto::SRP::Verifier const& verifier); @@ -90,4 +88,4 @@ private: #define sLoginService Battlenet::LoginRESTService::Instance() -#endif // LoginRESTService_h__ +#endif // TRINITYCORE_LOGIN_REST_SERVICE_H diff --git a/src/server/bnetserver/Server/Session.cpp b/src/server/bnetserver/Server/Session.cpp index c3a8ae6f211..7b9d6f45a3e 100644 --- a/src/server/bnetserver/Server/Session.cpp +++ b/src/server/bnetserver/Server/Session.cpp @@ -23,6 +23,7 @@ #include "Errors.h" #include "Hash.h" #include "IPLocation.h" +#include "IpBanCheckConnectionInitializer.h" #include "LoginRESTService.h" #include "MapUtils.h" #include "ProtobufJSON.h" @@ -74,7 +75,7 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields) DisplayName = Name; } -Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()), +Battlenet::Session::Session(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket), SslContext::instance()), _accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(), _os(), _build(0), _clientInfo(), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0) { @@ -83,54 +84,24 @@ Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSo Battlenet::Session::~Session() = default; -void Battlenet::Session::AsyncHandshake() -{ - underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, - [sess = shared_from_this()](boost::system::error_code const& error) { sess->HandshakeHandler(error); }); -} - void Battlenet::Session::Start() { - std::string ip_address = GetRemoteIpAddress().to_string(); TC_LOG_TRACE("session", "{} Accepted connection", GetClientInfo()); - // Verify that this IP is not in the ip_banned table - LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); + // build initializer chain + std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers = + { { + std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<Session>>(this), + std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<Session>>(this), + std::make_shared<Trinity::Net::ReadConnectionInitializer<Session>>(this), + } }; - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); - - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) - .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); -} - -void Battlenet::Session::CheckIpCallback(PreparedQueryResult result) -{ - if (result) - { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; - - } while (result->NextRow()); - - if (banned) - { - TC_LOG_DEBUG("session", "{} tries to log in using banned IP!", GetClientInfo()); - CloseSocket(); - return; - } - } - - AsyncHandshake(); + Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start(); } bool Battlenet::Session::Update() { - if (!BattlenetSocket::Update()) + if (!BaseSocket::Update()) return false; _queryProcessor.ProcessReadyCallbacks(); @@ -211,6 +182,11 @@ void Battlenet::Session::SendRequest(uint32 serviceHash, uint32 methodId, pb::Me AsyncWrite(&packet); } +void Battlenet::Session::QueueQuery(QueryCallback&& queryCallback) +{ + _queryProcessor.AddCallback(std::move(queryCallback)); +} + uint32 Battlenet::Session::HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation) { if (logonRequest->program() != "WoW") @@ -292,7 +268,7 @@ uint32 Battlenet::Session::HandleGenerateWebCredentials(authentication::v1::Gene LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION_BY_ID); stmt->setUInt32(0, _accountInfo->Id); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result) + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result) { // just send existing credentials back (not the best but it works for now with them being stored in db) Battlenet::Services::Authentication asyncContinuationService(this); @@ -314,7 +290,7 @@ uint32 Battlenet::Session::VerifyWebCredentials(std::string const& webCredential std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)> asyncContinuation = std::move(continuation); std::shared_ptr<AccountInfo> accountInfo = std::make_shared<AccountInfo>(); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result) + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result) { Battlenet::Services::Authentication asyncContinuationService(this); NoData response; @@ -713,20 +689,8 @@ uint32 Battlenet::Session::HandleGetAllValuesForAttribute(game_utilities::v1::Ge return ERROR_RPC_NOT_IMPLEMENTED; } -void Battlenet::Session::HandshakeHandler(boost::system::error_code const& error) -{ - if (error) - { - TC_LOG_ERROR("session", "{} SSL Handshake failed {}", GetClientInfo(), error.message()); - CloseSocket(); - return; - } - - AsyncRead(); -} - template<bool(Battlenet::Session::*processMethod)(), MessageBuffer Battlenet::Session::*outputBuffer> -inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer) +static inline Optional<Trinity::Net::SocketReadCallbackResult> PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer) { MessageBuffer& buffer = session->*outputBuffer; @@ -738,46 +702,45 @@ inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inp buffer.Write(inputBuffer.GetReadPointer(), readDataSize); inputBuffer.ReadCompleted(readDataSize); } + else + return { }; // go to next buffer if (buffer.GetRemainingSpace() > 0) { // Couldn't receive the whole data this time. ASSERT(inputBuffer.GetActiveSize() == 0); - return false; + return Trinity::Net::SocketReadCallbackResult::KeepReading; } // just received fresh new payload if (!(session->*processMethod)()) { session->CloseSocket(); - return false; + return Trinity::Net::SocketReadCallbackResult::Stop; } - return true; + return { }; // go to next buffer } -void Battlenet::Session::ReadHandler() +Trinity::Net::SocketReadCallbackResult Battlenet::Session::ReadHandler() { - if (!IsOpen()) - return; - MessageBuffer& packet = GetReadBuffer(); while (packet.GetActiveSize() > 0) { - if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderLengthHandler, &Battlenet::Session::_headerLengthBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderLengthHandler, &Session::_headerLengthBuffer>(this, packet)) + return *partialResult; - if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderHandler, &Battlenet::Session::_headerBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderHandler, &Session::_headerBuffer>(this, packet)) + return *partialResult; - if (!PartialProcessPacket<&Battlenet::Session::ReadDataHandler, &Battlenet::Session::_packetBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadDataHandler, &Session::_packetBuffer>(this, packet)) + return *partialResult; _headerLengthBuffer.Reset(); _headerBuffer.Reset(); } - AsyncRead(); + return Trinity::Net::SocketReadCallbackResult::KeepReading; } bool Battlenet::Session::ReadHeaderLengthHandler() @@ -835,5 +798,5 @@ std::string Battlenet::Session::GetClientInfo() const stream << ']'; - return stream.str(); + return std::move(stream).str(); } diff --git a/src/server/bnetserver/Server/Session.h b/src/server/bnetserver/Server/Session.h index 8a8b1f0a15b..faf82115f18 100644 --- a/src/server/bnetserver/Server/Session.h +++ b/src/server/bnetserver/Server/Session.h @@ -15,8 +15,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef Session_h__ -#define Session_h__ +#ifndef TRINITYCORE_SESSION_H +#define TRINITYCORE_SESSION_H #include "AsyncCallbackProcessor.h" #include "ClientBuildInfo.h" @@ -24,7 +24,7 @@ #include "QueryResult.h" #include "Realm.h" #include "Socket.h" -#include "SslSocket.h" +#include "SslStream.h" #include <boost/asio/ip/tcp.hpp> #include <google/protobuf/message.h> #include <memory> @@ -65,9 +65,9 @@ using namespace bgs::protocol; namespace Battlenet { - class Session : public Socket<Session, SslSocket<>> + class Session final : public Trinity::Net::Socket<Trinity::Net::SslStream<>> { - typedef Socket<Session, SslSocket<>> BattlenetSocket; + using BaseSocket = Socket<Trinity::Net::SslStream<>>; public: struct LastPlayedCharacterInfo @@ -110,7 +110,7 @@ namespace Battlenet std::unordered_map<uint32, GameAccountInfo> GameAccounts; }; - explicit Session(boost::asio::ip::tcp::socket&& socket); + explicit Session(Trinity::Net::IoContextTcpSocket&& socket); ~Session(); void Start() override; @@ -130,6 +130,8 @@ namespace Battlenet void SendRequest(uint32 serviceHash, uint32 methodId, pb::Message const* request); + void QueueQuery(QueryCallback&& queryCallback); + uint32 HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation); uint32 HandleVerifyWebCredentials(authentication::v1::VerifyWebCredentialsRequest const* verifyWebCredentialsRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation); uint32 HandleGenerateWebCredentials(authentication::v1::GenerateWebCredentialsRequest const* request, std::function<void(ServiceBase*, uint32, google::protobuf::Message const*)>& continuation); @@ -140,9 +142,9 @@ namespace Battlenet std::string GetClientInfo() const; + Trinity::Net::SocketReadCallbackResult ReadHandler() override; + protected: - void HandshakeHandler(boost::system::error_code const& error); - void ReadHandler() override; bool ReadHeaderLengthHandler(); bool ReadHeaderHandler(); bool ReadDataHandler(); @@ -150,10 +152,6 @@ namespace Battlenet private: void AsyncWrite(MessageBuffer* packet); - void AsyncHandshake(); - - void CheckIpCallback(PreparedQueryResult result); - uint32 VerifyWebCredentials(std::string const& webCredentials, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation); typedef uint32(Session::*ClientRequestHandler)(std::unordered_map<std::string, Variant const*> const&, game_utilities::v1::ClientResponse*); @@ -190,4 +188,4 @@ namespace Battlenet }; } -#endif // Session_h__ +#endif // TRINITYCORE_SESSION_H diff --git a/src/server/bnetserver/Server/SessionManager.cpp b/src/server/bnetserver/Server/SessionManager.cpp index cb56972a9c2..4c5b532ee60 100644 --- a/src/server/bnetserver/Server/SessionManager.cpp +++ b/src/server/bnetserver/Server/SessionManager.cpp @@ -16,7 +16,6 @@ */ #include "SessionManager.h" -#include "DatabaseEnv.h" #include "Util.h" bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) @@ -24,18 +23,16 @@ bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) return false; - _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>(); + _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) + { + OnSocketOpen(std::move(sock), threadIndex); + }); return true; } -NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const +Trinity::Net::NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const { - return new NetworkThread<Session>[GetNetworkThreadCount()]; -} - -void Battlenet::SessionManager::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) -{ - sSessionMgr.OnSocketOpen(std::move(sock), threadIndex); + return new Trinity::Net::NetworkThread<Session>[GetNetworkThreadCount()]; } Battlenet::SessionManager& Battlenet::SessionManager::Instance() diff --git a/src/server/bnetserver/Server/SessionManager.h b/src/server/bnetserver/Server/SessionManager.h index c635122c977..528ece8739e 100644 --- a/src/server/bnetserver/Server/SessionManager.h +++ b/src/server/bnetserver/Server/SessionManager.h @@ -15,15 +15,15 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef SessionManager_h__ -#define SessionManager_h__ +#ifndef TRINITYCORE_SESSION_MANAGER_H +#define TRINITYCORE_SESSION_MANAGER_H #include "SocketMgr.h" #include "Session.h" namespace Battlenet { - class SessionManager : public SocketMgr<Session> + class SessionManager : public Trinity::Net::SocketMgr<Session> { typedef SocketMgr<Session> BaseSocketMgr; @@ -33,13 +33,10 @@ namespace Battlenet bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override; protected: - NetworkThread<Session>* CreateThreads() const override; - - private: - static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex); + Trinity::Net::NetworkThread<Session>* CreateThreads() const override; }; } #define sSessionMgr Battlenet::SessionManager::Instance() -#endif // SessionManager_h__ +#endif // TRINITYCORE_SESSION_MANAGER_H diff --git a/src/server/game/Scripting/ScriptMgr.cpp b/src/server/game/Scripting/ScriptMgr.cpp index 99f25980a0d..ac1cb69b6a1 100644 --- a/src/server/game/Scripting/ScriptMgr.cpp +++ b/src/server/game/Scripting/ScriptMgr.cpp @@ -1474,14 +1474,14 @@ void ScriptMgr::OnNetworkStop() FOREACH_SCRIPT(ServerScript)->OnNetworkStop(); } -void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> socket) +void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> const& socket) { ASSERT(socket); FOREACH_SCRIPT(ServerScript)->OnSocketOpen(socket); } -void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> socket) +void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> const& socket) { ASSERT(socket); diff --git a/src/server/game/Scripting/ScriptMgr.h b/src/server/game/Scripting/ScriptMgr.h index d7746b01e04..e3b7dc63981 100644 --- a/src/server/game/Scripting/ScriptMgr.h +++ b/src/server/game/Scripting/ScriptMgr.h @@ -1067,8 +1067,8 @@ class TC_GAME_API ScriptMgr void OnNetworkStart(); void OnNetworkStop(); - void OnSocketOpen(std::shared_ptr<WorldSocket> socket); - void OnSocketClose(std::shared_ptr<WorldSocket> socket); + void OnSocketOpen(std::shared_ptr<WorldSocket> const& socket); + void OnSocketClose(std::shared_ptr<WorldSocket> const& socket); void OnPacketReceive(WorldSession* session, WorldPacket const& packet); void OnPacketSend(WorldSession* session, WorldPacket const& packet); diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index 4b2ad6c3c5c..2b868ff3884 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -26,6 +26,7 @@ #include "GameTime.h" #include "HMAC.h" #include "IPLocation.h" +#include "IpBanCheckConnectionInitializer.h" #include "PacketLog.h" #include "ProtobufJSON.h" #include "RealmList.h" @@ -33,6 +34,7 @@ #include "RealmList.pb.h" #include "ScriptMgr.h" #include "SessionKeyGenerator.h" +#include "SslStream.h" #include "World.h" #include "WorldPacket.h" #include "WorldSession.h" @@ -49,8 +51,6 @@ struct CompressedWorldPacket #pragma pack(pop) -std::string const WorldSocket::ServerConnectionInitialize("WORLD OF WARCRAFT CONNECTION - SERVER TO CLIENT - V2"); -std::string const WorldSocket::ClientConnectionInitialize("WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER - V2"); uint32 const WorldSocket::MinSizeForCompression = 0x400; std::array<uint8, 32> const WorldSocket::AuthCheckSeed = { 0xDE, 0x3A, 0x2A, 0x8E, 0x6B, 0x89, 0x52, 0x66, 0x88, 0x9D, 0x7E, 0x7A, 0x77, 0x1D, 0x5D, 0x1F, @@ -62,7 +62,7 @@ std::array<uint8, 32> const WorldSocket::ContinuedSessionSeed = { 0x56, 0x5C, 0x std::array<uint8, 32> const WorldSocket::EncryptionKeySeed = { 0x71, 0xC9, 0xED, 0x5A, 0xA7, 0x0E, 0x4D, 0xFF, 0x4C, 0x36, 0xA6, 0x5A, 0x3E, 0x46, 0x8A, 0x4A, 0x5D, 0xA1, 0x48, 0xC8, 0x30, 0x47, 0x4A, 0xDE, 0xF6, 0x0D, 0x6C, 0xBE, 0x6F, 0xE4, 0x55, 0x73 }; -WorldSocket::WorldSocket(boost::asio::ip::tcp::socket&& socket) : Socket(std::move(socket)), +WorldSocket::WorldSocket(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket)), _type(CONNECTION_TYPE_REALM), _key(0), _serverChallenge(), _sessionKey(), _encryptKey(), _OverSpeedPings(0), _worldSession(nullptr), _authed(false), _canRequestHotfixes(true), _headerBuffer(sizeof(IncomingPacketHeader)), _sendBufferSize(4096), _compressionStream(nullptr) { @@ -77,127 +77,127 @@ WorldSocket::~WorldSocket() } } -void WorldSocket::Start() +struct WorldSocketProtocolInitializer final : Trinity::Net::SocketConnectionInitializer { - std::string ip_address = GetRemoteIpAddress().to_string(); - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); + static constexpr std::string_view ServerConnectionInitialize = "WORLD OF WARCRAFT CONNECTION - SERVER TO CLIENT - V2\n"; + static constexpr std::string_view ClientConnectionInitialize = "WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER - V2\n"; - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([self = shared_from_this()](PreparedQueryResult result) - { - self->CheckIpCallback(std::move(result)); - })); -} + explicit WorldSocketProtocolInitializer(WorldSocket* socket) : _socket(socket) { } -void WorldSocket::CheckIpCallback(PreparedQueryResult result) -{ - if (result) + void Start() override { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; + _packetBuffer.Resize(ClientConnectionInitialize.length()); - } while (result->NextRow()); + AsyncRead(); - if (banned) - { - TC_LOG_ERROR("network", "WorldSocket::CheckIpCallback: Sent Auth Response (IP {} banned).", GetRemoteIpAddress().to_string()); - DelayedCloseSocket(); - return; - } + MessageBuffer initializer; + initializer.Write(ServerConnectionInitialize.data(), ServerConnectionInitialize.length()); + + // - IoContext.run thread, safe. + _socket->QueuePacket(std::move(initializer)); } - _packetBuffer.Resize(ClientConnectionInitialize.length() + 1); + void AsyncRead() + { + _socket->AsyncRead( + [socketRef = _socket->weak_from_this(), self = static_pointer_cast<WorldSocketProtocolInitializer>(this->shared_from_this())] + { + if (!socketRef.expired()) + return self->ReadHandler(); + + return Trinity::Net::SocketReadCallbackResult::Stop; + }); + } - AsyncReadWithCallback(&WorldSocket::InitializeHandler); + Trinity::Net::SocketReadCallbackResult ReadHandler(); - MessageBuffer initializer; - initializer.Write(ServerConnectionInitialize.c_str(), ServerConnectionInitialize.length()); - initializer.Write("\n", 1); + void HandleDataReady(); - // - IoContext.run thread, safe. - QueuePacket(std::move(initializer)); +private: + WorldSocket* _socket; + MessageBuffer _packetBuffer; +}; + +void WorldSocket::Start() +{ + // build initializer chain + std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers = + { { + std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<WorldSocket>>(this), + std::make_shared<WorldSocketProtocolInitializer>(this), + std::make_shared<Trinity::Net::ReadConnectionInitializer<WorldSocket>>(this), + } }; + + Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start(); } -void WorldSocket::InitializeHandler(boost::system::error_code const& error, std::size_t transferedBytes) +Trinity::Net::SocketReadCallbackResult WorldSocketProtocolInitializer::ReadHandler() { - if (error) + MessageBuffer& packet = _socket->GetReadBuffer(); + if (packet.GetActiveSize() > 0 && _packetBuffer.GetRemainingSpace() > 0) { - CloseSocket(); - return; + // need to receive the header + std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace()); + _packetBuffer.Write(packet.GetReadPointer(), readHeaderSize); + packet.ReadCompleted(readHeaderSize); + + if (_packetBuffer.GetRemainingSpace() == 0) + { + HandleDataReady(); + return Trinity::Net::SocketReadCallbackResult::Stop; + } + + // Couldn't receive the whole header this time. + ASSERT(packet.GetActiveSize() == 0); } - GetReadBuffer().WriteCompleted(transferedBytes); + return Trinity::Net::SocketReadCallbackResult::KeepReading; +} - MessageBuffer& packet = GetReadBuffer(); - if (packet.GetActiveSize() > 0) +void WorldSocketProtocolInitializer::HandleDataReady() +{ + try { - if (_packetBuffer.GetRemainingSpace() > 0) + ByteBuffer buffer(std::move(_packetBuffer)); + if (buffer.ReadString(ClientConnectionInitialize.length()) != ClientConnectionInitialize) { - // need to receive the header - std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace()); - _packetBuffer.Write(packet.GetReadPointer(), readHeaderSize); - packet.ReadCompleted(readHeaderSize); - - if (_packetBuffer.GetRemainingSpace() > 0) - { - // Couldn't receive the whole header this time. - ASSERT(packet.GetActiveSize() == 0); - AsyncReadWithCallback(&WorldSocket::InitializeHandler); - return; - } - - try - { - ByteBuffer buffer(std::move(_packetBuffer)); - std::string initializer(buffer.ReadString(ClientConnectionInitialize.length())); - if (initializer != ClientConnectionInitialize) - { - CloseSocket(); - return; - } + _socket->CloseSocket(); + return; + } + } + catch (ByteBufferException const& ex) + { + TC_LOG_ERROR("network", "WorldSocket::InitializeHandler ByteBufferException {} occured while parsing initial packet from {}", + ex.what(), _socket->GetRemoteIpAddress().to_string()); + _socket->CloseSocket(); + return; + } - uint8 terminator; - buffer >> terminator; - if (terminator != '\n') - { - CloseSocket(); - return; - } - } - catch (ByteBufferException const& ex) - { - TC_LOG_ERROR("network", "WorldSocket::InitializeHandler ByteBufferException {} occured while parsing initial packet from {}", - ex.what(), GetRemoteIpAddress().to_string()); - CloseSocket(); - return; - } + if (!_socket->InitializeCompression()) + return; - _compressionStream = new z_stream(); - _compressionStream->zalloc = (alloc_func)nullptr; - _compressionStream->zfree = (free_func)nullptr; - _compressionStream->opaque = (voidpf)nullptr; - _compressionStream->avail_in = 0; - _compressionStream->next_in = nullptr; - int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY); - if (z_res != Z_OK) - { - CloseSocket(); - TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: {} ({})", z_res, zError(z_res)); - return; - } + _socket->SendAuthSession(); + if (next) + next->Start(); +} - _packetBuffer.Reset(); - HandleSendAuthSession(); - AsyncRead(); - return; - } +bool WorldSocket::InitializeCompression() +{ + _compressionStream = new z_stream(); + _compressionStream->zalloc = (alloc_func)nullptr; + _compressionStream->zfree = (free_func)nullptr; + _compressionStream->opaque = (voidpf)nullptr; + _compressionStream->avail_in = 0; + _compressionStream->next_in = nullptr; + int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY); + if (z_res != Z_OK) + { + CloseSocket(); + TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: {} ({})", z_res, zError(z_res)); + return false; } - AsyncReadWithCallback(&WorldSocket::InitializeHandler); + return true; } bool WorldSocket::Update() @@ -206,7 +206,7 @@ bool WorldSocket::Update() MessageBuffer buffer(_sendBufferSize); while (_bufferQueue.Dequeue(queued)) { - uint32 packetSize = queued->size() + 2 /*opcode*/; + uint32 packetSize = queued->size() + 4 /*opcode*/; if (packetSize > MinSizeForCompression && queued->NeedsEncryption()) packetSize = deflateBound(_compressionStream, packetSize) + sizeof(CompressedWorldPacket); @@ -240,7 +240,7 @@ bool WorldSocket::Update() return true; } -void WorldSocket::HandleSendAuthSession() +void WorldSocket::SendAuthSession() { Trinity::Crypto::GetRandomBytes(_serverChallenge); @@ -260,11 +260,8 @@ void WorldSocket::OnClose() } } -void WorldSocket::ReadHandler() +Trinity::Net::SocketReadCallbackResult WorldSocket::ReadHandler() { - if (!IsOpen()) - return; - MessageBuffer& packet = GetReadBuffer(); while (packet.GetActiveSize() > 0) { @@ -286,7 +283,7 @@ void WorldSocket::ReadHandler() if (!ReadHeaderHandler()) { CloseSocket(); - return; + return Trinity::Net::SocketReadCallbackResult::Stop; } } @@ -314,11 +311,16 @@ void WorldSocket::ReadHandler() if (result != ReadDataHandlerResult::WaitingForQuery) CloseSocket(); - return; + return Trinity::Net::SocketReadCallbackResult::Stop; } } - AsyncRead(); + return Trinity::Net::SocketReadCallbackResult::KeepReading; +} + +void WorldSocket::QueueQuery(QueryCallback&& queryCallback) +{ + _queryProcessor.AddCallback(std::move(queryCallback)); } void WorldSocket::SetWorldSession(WorldSession* session) @@ -510,14 +512,13 @@ WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler() void WorldSocket::LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const { - if (!guard) + if (!guard || !_worldSession) { TC_LOG_TRACE("network.opcode", "C->S: {} {}", GetRemoteIpAddress().to_string(), GetOpcodeNameForLogging(opcode)); } else { - TC_LOG_TRACE("network.opcode", "C->S: {} {}", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()), - GetOpcodeNameForLogging(opcode)); + TC_LOG_TRACE("network.opcode", "C->S: {} {}", _worldSession->GetPlayerInfo(), GetOpcodeNameForLogging(opcode)); } } @@ -688,7 +689,7 @@ void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSess stmt->setInt32(0, int32(sRealmList->GetCurrentRealmId().Realm)); stmt->setString(1, joinTicket->gameaccount()); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession), joinTicket = std::move(joinTicket)](PreparedQueryResult result) mutable + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession), joinTicket = std::move(joinTicket)](PreparedQueryResult result) mutable { HandleAuthSessionCallback(std::move(authSession), std::move(joinTicket), std::move(result)); })); @@ -898,19 +899,20 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth:: sScriptMgr->OnAccountLogin(account.Game.Id); _authed = true; - _worldSession = new WorldSession(account.Game.Id, std::move(*joinTicket->mutable_gameaccount()), account.BattleNet.Id, shared_from_this(), account.Game.Security, - account.Game.Expansion, mutetime, account.Game.OS, account.Game.TimezoneOffset, account.Game.Build, buildVariant, account.Game.Locale, + _worldSession = new WorldSession(account.Game.Id, std::move(*joinTicket->mutable_gameaccount()), account.BattleNet.Id, + static_pointer_cast<WorldSocket>(shared_from_this()), account.Game.Security, account.Game.Expansion, mutetime, + account.Game.OS, account.Game.TimezoneOffset, account.Game.Build, buildVariant, account.Game.Locale, account.Game.Recruiter, account.Game.IsRectuiter); // Initialize Warden system only if it is enabled by config if (wardenActive) _worldSession->InitWarden(_sessionKey); - _queryProcessor.AddCallback(_worldSession->LoadPermissionsAsync().WithPreparedCallback([this](PreparedQueryResult result) + QueueQuery(_worldSession->LoadPermissionsAsync().WithPreparedCallback([this](PreparedQueryResult result) { LoadSessionPermissionsCallback(std::move(result)); })); - AsyncRead(); + AsyncRead(Trinity::Net::InvokeReadHandlerCallback<WorldSocket>{ .Socket = this }); } void WorldSocket::LoadSessionPermissionsCallback(PreparedQueryResult result) @@ -938,7 +940,7 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth: LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION); stmt->setUInt32(0, accountId); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable { HandleAuthContinuedSessionCallback(std::move(authSession), std::move(result)); })); @@ -985,7 +987,7 @@ void WorldSocket::HandleAuthContinuedSessionCallback(std::shared_ptr<WorldPacket memcpy(_encryptKey.data(), encryptKeyGen.GetDigest().data(), 32); SendPacketAndLogOpcode(*WorldPackets::Auth::EnterEncryptedMode(_encryptKey, true).Write()); - AsyncRead(); + AsyncRead(Trinity::Net::InvokeReadHandlerCallback<WorldSocket>{ .Socket = this }); } void WorldSocket::HandleConnectToFailed(WorldPackets::Auth::ConnectToFailed& connectToFailed) @@ -1032,7 +1034,7 @@ void WorldSocket::HandleEnterEncryptedModeAck() if (_type == CONNECTION_TYPE_REALM) sWorld->AddSession(_worldSession); else - sWorld->AddInstanceSocket(shared_from_this(), _key); + sWorld->AddInstanceSocket(static_pointer_cast<WorldSocket>(shared_from_this()), _key); } void WorldSocket::SendAuthResponseError(uint32 code) diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index 04451dc26e8..b5388532b4d 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -15,8 +15,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef __WORLDSOCKET_H__ -#define __WORLDSOCKET_H__ +#ifndef TRINITYCORE_WORLD_SOCKET_H +#define TRINITYCORE_WORLD_SOCKET_H #include "AsyncCallbackProcessor.h" #include "AuthDefines.h" @@ -77,7 +77,7 @@ struct PacketHeader uint32 Size; uint8 Tag[12]; - bool IsValidSize() { return Size < 0x10000; } + bool IsValidSize() const { return Size < 0x10000; } }; struct IncomingPacketHeader : PacketHeader @@ -87,10 +87,8 @@ struct IncomingPacketHeader : PacketHeader #pragma pack(pop) -class TC_GAME_API WorldSocket : public Socket<WorldSocket> +class TC_GAME_API WorldSocket final : public Trinity::Net::Socket<> { - static std::string const ServerConnectionInitialize; - static std::string const ClientConnectionInitialize; static uint32 const MinSizeForCompression; static std::array<uint8, 32> const AuthCheckSeed; @@ -98,14 +96,16 @@ class TC_GAME_API WorldSocket : public Socket<WorldSocket> static std::array<uint8, 32> const ContinuedSessionSeed; static std::array<uint8, 32> const EncryptionKeySeed; - typedef Socket<WorldSocket> BaseSocket; + using BaseSocket = Socket; public: - WorldSocket(boost::asio::ip::tcp::socket&& socket); + WorldSocket(Trinity::Net::IoContextTcpSocket&& socket); ~WorldSocket(); WorldSocket(WorldSocket const& right) = delete; + WorldSocket(WorldSocket&& right) = delete; WorldSocket& operator=(WorldSocket const& right) = delete; + WorldSocket& operator=(WorldSocket&& right) = delete; void Start() override; bool Update() override; @@ -118,9 +118,15 @@ public: void SetWorldSession(WorldSession* session); void SetSendBufferSize(std::size_t sendBufferSize) { _sendBufferSize = sendBufferSize; } -protected: void OnClose() override; - void ReadHandler() override; + Trinity::Net::SocketReadCallbackResult ReadHandler() override; + + void QueueQuery(QueryCallback&& queryCallback); + + void SendAuthSession(); + bool InitializeCompression(); + +protected: bool ReadHeaderHandler(); enum class ReadDataHandlerResult @@ -132,9 +138,6 @@ protected: ReadDataHandlerResult ReadDataHandler(); private: - void CheckIpCallback(PreparedQueryResult result); - void InitializeHandler(boost::system::error_code const& error, std::size_t transferedBytes); - /// writes network.opcode log /// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock void LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const; @@ -143,7 +146,6 @@ private: void WritePacketToBuffer(EncryptablePacket const& packet, MessageBuffer& buffer); uint32 CompressPacket(uint8* buffer, WorldPacket const& packet); - void HandleSendAuthSession(); void HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession); void HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession, std::shared_ptr<JSON::RealmList::RealmJoinTicket> joinTicket, PreparedQueryResult result); diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp index 58f242b52f8..cc44e5bad4b 100644 --- a/src/server/game/Server/WorldSocketMgr.cpp +++ b/src/server/game/Server/WorldSocketMgr.cpp @@ -22,21 +22,16 @@ #include "WorldSocket.h" #include <boost/system/error_code.hpp> -static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) -{ - sWorldSocketMgr.OnSocketOpen(std::move(sock), threadIndex); -} - -class WorldSocketThread : public NetworkThread<WorldSocket> +class WorldSocketThread : public Trinity::Net::NetworkThread<WorldSocket> { public: - void SocketAdded(std::shared_ptr<WorldSocket> sock) override + void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override { sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize()); sScriptMgr->OnSocketOpen(sock); } - void SocketRemoved(std::shared_ptr<WorldSocket> sock) override + void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override { sScriptMgr->OnSocketClose(sock); } @@ -75,7 +70,10 @@ bool WorldSocketMgr::StartNetwork(Trinity::Asio::IoContext& ioContext, std::stri if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) return false; - _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>(); + _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) + { + OnSocketOpen(std::move(sock), threadIndex); + }); sScriptMgr->OnNetworkStart(); return true; @@ -88,7 +86,7 @@ void WorldSocketMgr::StopNetwork() sScriptMgr->OnNetworkStop(); } -void WorldSocketMgr::OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) +void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) { // set some options here if (_socketSystemSendBufferSize >= 0) @@ -117,7 +115,7 @@ void WorldSocketMgr::OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 th BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex); } -NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const +Trinity::Net::NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const { return new WorldSocketThread[GetNetworkThreadCount()]; } diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h index 84b190575a4..8859da81074 100644 --- a/src/server/game/Server/WorldSocketMgr.h +++ b/src/server/game/Server/WorldSocketMgr.h @@ -15,21 +15,15 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -/** \addtogroup u2w User to World Communication - * @{ - * \file WorldSocketMgr.h - * \author Derex <derex101@gmail.com> - */ - -#ifndef __WORLDSOCKETMGR_H -#define __WORLDSOCKETMGR_H +#ifndef TRINITYCORE_WORLD_SOCKET_MGR_H +#define TRINITYCORE_WORLD_SOCKET_MGR_H #include "SocketMgr.h" class WorldSocket; /// Manages all sockets connected to peers and network threads -class TC_GAME_API WorldSocketMgr : public SocketMgr<WorldSocket> +class TC_GAME_API WorldSocketMgr : public Trinity::Net::SocketMgr<WorldSocket> { typedef SocketMgr<WorldSocket> BaseSocketMgr; @@ -44,14 +38,14 @@ public: /// Stops all network threads, It will wait for all running threads . void StopNetwork() override; - void OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) override; + void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override; std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; } protected: WorldSocketMgr(); - NetworkThread<WorldSocket>* CreateThreads() const override; + Trinity::Net::NetworkThread<WorldSocket>* CreateThreads() const override; private: int32 _socketSystemSendBufferSize; @@ -62,4 +56,3 @@ private: #define sWorldSocketMgr WorldSocketMgr::Instance() #endif -/// @} diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h index 95e25e02166..dd0857c2b38 100644 --- a/src/server/shared/Networking/AsyncAcceptor.h +++ b/src/server/shared/Networking/AsyncAcceptor.h @@ -15,12 +15,13 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef __ASYNCACCEPT_H_ -#define __ASYNCACCEPT_H_ +#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H +#define TRINITYCORE_ASYNC_ACCEPTOR_H #include "IoContext.h" #include "IpAddress.h" #include "Log.h" +#include "Socket.h" #include <boost/asio/ip/tcp.hpp> #include <boost/asio/ip/v6_only.hpp> #include <atomic> @@ -28,28 +29,28 @@ #define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections +namespace Trinity::Net +{ +template <typename Callable> +concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>; + class AsyncAcceptor { public: - typedef void(*AcceptCallback)(boost::asio::ip::tcp::socket&& newSocket, uint32 threadIndex); - - AsyncAcceptor(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : - _acceptor(ioContext), _endpoint(Trinity::Net::make_address(bindIp), port), + AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : + _acceptor(ioContext), _endpoint(make_address(bindIp), port), _socket(ioContext), _closed(false), _socketFactory([this] { return DefeaultSocketFactory(); }) { } - template<class T> - void AsyncAccept(); - - template<AcceptCallback acceptCallback> - void AsyncAcceptWithCallback() + template <AcceptCallback Callback> + void AsyncAccept(Callback&& acceptCallback) { auto [tmpSocket, tmpThreadIndex] = _socketFactory(); // TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures) - boost::asio::ip::tcp::socket* socket = tmpSocket; + IoContextTcpSocket* socket = tmpSocket; uint32 threadIndex = tmpThreadIndex; - _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error) + _acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable { if (!error) { @@ -66,7 +67,7 @@ public: } if (!_closed) - this->AsyncAcceptWithCallback<acceptCallback>(); + this->AsyncAccept(std::move(acceptCallback)); }); } @@ -120,40 +121,17 @@ public: _acceptor.close(err); } - void SetSocketFactory(std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> func) { _socketFactory = std::move(func); } + void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); } private: - std::pair<boost::asio::ip::tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } + std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } - boost::asio::ip::tcp::acceptor _acceptor; + boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor; boost::asio::ip::tcp::endpoint _endpoint; - boost::asio::ip::tcp::socket _socket; + IoContextTcpSocket _socket; std::atomic<bool> _closed; - std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> _socketFactory; + std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory; }; - -template<class T> -void AsyncAcceptor::AsyncAccept() -{ - _acceptor.async_accept(_socket, [this](boost::system::error_code error) - { - if (!error) - { - try - { - // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class - std::make_shared<T>(std::move(this->_socket))->Start(); - } - catch (boost::system::system_error const& err) - { - TC_LOG_INFO("network", "Failed to retrieve client's remote address {}", err.what()); - } - } - - // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face - if (!_closed) - this->AsyncAccept<T>(); - }); } -#endif /* __ASYNCACCEPT_H_ */ +#endif // TRINITYCORE_ASYNC_ACCEPTOR_H diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp new file mode 100644 index 00000000000..5996c40faee --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp @@ -0,0 +1,42 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "IpBanCheckConnectionInitializer.h" +#include "DatabaseEnv.h" + +QueryCallback Trinity::Net::IpBanCheckHelpers::AsyncQuery(std::string_view ipAddress) +{ + LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); + stmt->setString(0, ipAddress); + return LoginDatabase.AsyncQuery(stmt); +} + +bool Trinity::Net::IpBanCheckHelpers::IsBanned(PreparedQueryResult const& result) +{ + if (result) + { + do + { + Field* fields = result->Fetch(); + if (fields[0].GetUInt64() != 0) + return true; + + } while (result->NextRow()); + } + + return false; +} diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h new file mode 100644 index 00000000000..ff8210a7f69 --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h @@ -0,0 +1,64 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H +#define TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H + +#include "DatabaseEnvFwd.h" +#include "Log.h" +#include "QueryCallback.h" +#include "SocketConnectionInitializer.h" + +namespace Trinity::Net +{ +namespace IpBanCheckHelpers +{ +TC_SHARED_API QueryCallback AsyncQuery(std::string_view ipAddress); +TC_SHARED_API bool IsBanned(PreparedQueryResult const& result); +} + +template <typename SocketImpl> +struct IpBanCheckConnectionInitializer final : SocketConnectionInitializer +{ + explicit IpBanCheckConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->QueueQuery(IpBanCheckHelpers::AsyncQuery(_socket->GetRemoteIpAddress().to_string()).WithPreparedCallback([socketRef = _socket->weak_from_this(), self = this->shared_from_this()](PreparedQueryResult const& result) + { + std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock()); + if (!socket) + return; + + if (IpBanCheckHelpers::IsBanned(result)) + { + TC_LOG_ERROR("network", "IpBanCheckConnectionInitializer: IP {} is banned.", socket->GetRemoteIpAddress().to_string()); + socket->DelayedCloseSocket(); + return; + } + + if (self->next) + self->next->Start(); + })); + } + +private: + SocketImpl* _socket; +}; +} + +#endif // TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H diff --git a/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h new file mode 100644 index 00000000000..d3f0bb16dbf --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h @@ -0,0 +1,51 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H +#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H + +#include <memory> +#include <span> + +namespace Trinity::Net +{ +struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer> +{ + SocketConnectionInitializer() = default; + + SocketConnectionInitializer(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default; + SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default; + + virtual ~SocketConnectionInitializer() = default; + + virtual void Start() = 0; + + std::shared_ptr<SocketConnectionInitializer> next; + + static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers) + { + for (std::size_t i = initializers.size(); i > 1; --i) + initializers[i - 2]->next.swap(initializers[i - 1]); + + return initializers[0]; + } +}; +} + +#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.cpp b/src/server/shared/Networking/Http/BaseHttpSocket.cpp index ca92c442aa2..33053399c14 100644 --- a/src/server/shared/Networking/Http/BaseHttpSocket.cpp +++ b/src/server/shared/Networking/Http/BaseHttpSocket.cpp @@ -16,6 +16,7 @@ */ #include "BaseHttpSocket.h" +#include <boost/asio/buffers_iterator.hpp> #include <boost/beast/http/serializer.hpp> namespace Trinity::Net::Http @@ -112,4 +113,34 @@ MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response return buffer; } + +void AbstractSocket::LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const +{ + if (Logger const* logger = sLog->GetEnabledLogger("server.http", LOG_LEVEL_DEBUG)) + { + std::string clientInfo = GetClientInfo(); + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_DEBUG, "{} Request {} {} done, status {}", clientInfo, + ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int()); + if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE)) + { + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo, + CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>"); + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo, + CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>"); + } + } +} + +std::string AbstractSocket::GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state) +{ + std::string info = StringFormat("[{}:{}", address.to_string(), port); + if (state) + { + info.append(", Session Id: "); + info.append(boost::uuids::to_string(state->Id)); + } + + info += ']'; + return info; +} } diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.h b/src/server/shared/Networking/Http/BaseHttpSocket.h index c02b13e6463..4b7c3bd9dd1 100644 --- a/src/server/shared/Networking/Http/BaseHttpSocket.h +++ b/src/server/shared/Networking/Http/BaseHttpSocket.h @@ -25,13 +25,46 @@ #include "Optional.h" #include "QueryCallback.h" #include "Socket.h" -#include <boost/asio/buffers_iterator.hpp> +#include "SocketConnectionInitializer.h" +#include <boost/beast/core/basic_stream.hpp> #include <boost/beast/http/parser.hpp> #include <boost/beast/http/string_body.hpp> #include <boost/uuid/uuid_io.hpp> namespace Trinity::Net::Http { +using IoContextHttpSocket = boost::beast::basic_stream<boost::asio::ip::tcp, boost::asio::io_context::executor_type, boost::beast::unlimited_rate_policy>; + +namespace Impl +{ +class BoostBeastSocketWrapper : public IoContextHttpSocket +{ +public: + using IoContextHttpSocket::basic_stream; + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + socket().shutdown(what, shutdownError); + } + + void close(boost::system::error_code& /*error*/) + { + IoContextHttpSocket::close(); + } + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + { + socket().async_wait(type, std::forward<WaitHandlerType>(handler)); + } + + IoContextTcpSocket::endpoint_type remote_endpoint() const + { + return socket().remote_endpoint(); + } +}; +} + using RequestParser = boost::beast::http::request_parser<RequestBody>; class TC_SHARED_API AbstractSocket @@ -51,22 +84,59 @@ public: virtual void SendResponse(RequestContext& context) = 0; + void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const; + virtual void QueueQuery(QueryCallback&& queryCallback) = 0; virtual std::string GetClientInfo() const = 0; - virtual Optional<boost::uuids::uuid> GetSessionId() const = 0; + static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state); + + virtual SessionState* GetSessionState() const = 0; + + Optional<boost::uuids::uuid> GetSessionId() const + { + if (SessionState* state = this->GetSessionState()) + return state->Id; + + return {}; + } + + virtual void Start() = 0; + + virtual bool Update() = 0; + + virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0; + + virtual bool IsOpen() const = 0; + + virtual void CloseSocket() = 0; }; -template<typename Derived, typename Stream> -class BaseSocket : public ::Socket<Derived, Stream>, public AbstractSocket +template <typename SocketImpl> +struct HttpConnectionInitializer final : SocketConnectionInitializer { - using Base = ::Socket<Derived, Stream>; + explicit HttpConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->ResetHttpParser(); + + if (this->next) + this->next->Start(); + } + +private: + SocketImpl* _socket; +}; + +template<typename Stream> +class BaseSocket : public Trinity::Net::Socket<Stream>, public AbstractSocket +{ + using Base = Trinity::Net::Socket<Stream>; public: - template<typename... Args> - explicit BaseSocket(boost::asio::ip::tcp::socket&& socket, Args&&... args) - : Base(std::move(socket), std::forward<Args>(args)...) { } + using Base::Base; BaseSocket(BaseSocket const& other) = delete; BaseSocket(BaseSocket&& other) = delete; @@ -75,11 +145,8 @@ public: ~BaseSocket() = default; - void ReadHandler() override + SocketReadCallbackResult ReadHandler() final { - if (!this->IsOpen()) - return; - MessageBuffer& packet = this->GetReadBuffer(); while (packet.GetActiveSize() > 0) { @@ -92,13 +159,13 @@ public: if (!HandleMessage(_httpParser->get())) { this->CloseSocket(); - break; + return SocketReadCallbackResult::Stop; } this->ResetHttpParser(); } - this->AsyncRead(); + return SocketReadCallbackResult::KeepReading; } bool HandleMessage(Request& request) @@ -118,19 +185,11 @@ public: virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0; - void SendResponse(RequestContext& context) override + void SendResponse(RequestContext& context) final { MessageBuffer buffer = SerializeResponse(context.request, context.response); - TC_LOG_DEBUG("server.http", "{} Request {} {} done, status {}", this->GetClientInfo(), ToStdStringView(context.request.method_string()), - ToStdStringView(context.request.target()), context.response.result_int()); - if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE)) - { - sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Request: {}", this->GetClientInfo(), - CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>"); - sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Response: {}", this->GetClientInfo(), - CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>"); - } + this->LogRequestAndResponse(context, buffer); this->QueuePacket(std::move(buffer)); @@ -138,11 +197,13 @@ public: this->DelayedCloseSocket(); } - void QueueQuery(QueryCallback&& queryCallback) override + void QueueQuery(QueryCallback&& queryCallback) final { this->_queryProcessor.AddCallback(std::move(queryCallback)); } + void Start() override { return this->Base::Start(); } + bool Update() override { if (!this->Base::Update()) @@ -152,27 +213,19 @@ public: return true; } + boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); } + + bool IsOpen() const final { return this->Base::IsOpen(); } + + void CloseSocket() final { return this->Base::CloseSocket(); } + std::string GetClientInfo() const override { - std::string info; - info.reserve(500); - auto itr = StringFormatTo(std::back_inserter(info), "[{}:{}", this->GetRemoteIpAddress().to_string(), this->GetRemotePort()); - if (_state) - itr = StringFormatTo(itr, ", Session Id: {}", boost::uuids::to_string(_state->Id)); - - StringFormatTo(itr, "]"); - return info; + return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get()); } - Optional<boost::uuids::uuid> GetSessionId() const final - { - if (this->_state) - return this->_state->Id; - - return {}; - } + SessionState* GetSessionState() const override { return _state.get(); } -protected: void ResetHttpParser() { this->_httpParser.reset(); @@ -180,6 +233,7 @@ protected: this->_httpParser->eager(true); } +protected: virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0; QueryCallbackProcessor _queryProcessor; diff --git a/src/server/shared/Networking/Http/HttpService.h b/src/server/shared/Networking/Http/HttpService.h index e68f8d2795f..2a377c734da 100644 --- a/src/server/shared/Networking/Http/HttpService.h +++ b/src/server/shared/Networking/Http/HttpService.h @@ -160,7 +160,7 @@ protected: class Thread : public NetworkThread<SessionImpl> { protected: - void SocketRemoved(std::shared_ptr<SessionImpl> session) override + void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override { if (Optional<boost::uuids::uuid> id = session->GetSessionId()) _service->MarkSessionInactive(*id); diff --git a/src/server/shared/Networking/Http/HttpSocket.h b/src/server/shared/Networking/Http/HttpSocket.h index 9a333a8e779..2cfc3ba8ed8 100644 --- a/src/server/shared/Networking/Http/HttpSocket.h +++ b/src/server/shared/Networking/Http/HttpSocket.h @@ -19,43 +19,16 @@ #define TRINITYCORE_HTTP_SOCKET_H #include "BaseHttpSocket.h" -#include <boost/beast/core/tcp_stream.hpp> +#include <array> namespace Trinity::Net::Http { -namespace Impl +class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper> { -class BoostBeastSocketWrapper : public boost::beast::tcp_stream -{ -public: - using boost::beast::tcp_stream::tcp_stream; - - void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) - { - socket().shutdown(what, shutdownError); - } - - void close(boost::system::error_code& /*error*/) - { - boost::beast::tcp_stream::close(); - } - - boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const - { - return socket().remote_endpoint(); - } -}; -} - -template <typename Derived> -class Socket : public BaseSocket<Derived, Impl::BoostBeastSocketWrapper> -{ - using SocketBase = BaseSocket<Derived, Impl::BoostBeastSocketWrapper>; + using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>; public: - template<typename... Args> - explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&...) - : SocketBase(std::move(socket)) { } + using SocketBase::SocketBase; Socket(Socket const& other) = delete; Socket(Socket&& other) = delete; @@ -66,9 +39,13 @@ public: void Start() override { - this->ResetHttpParser(); + std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers = + { { + std::make_shared<HttpConnectionInitializer<SocketBase>>(this), + std::make_shared<ReadConnectionInitializer<SocketBase>>(this), + } }; - this->AsyncRead(); + SocketConnectionInitializer::SetupChain(initializers)->Start(); } }; } diff --git a/src/server/shared/Networking/Http/HttpSslSocket.h b/src/server/shared/Networking/Http/HttpSslSocket.h index cdb70645e05..c789cbfefaf 100644 --- a/src/server/shared/Networking/Http/HttpSslSocket.h +++ b/src/server/shared/Networking/Http/HttpSslSocket.h @@ -19,47 +19,21 @@ #define TRINITYCORE_HTTP_SSL_SOCKET_H #include "BaseHttpSocket.h" -#include "SslSocket.h" -#include <boost/beast/core/stream_traits.hpp> -#include <boost/beast/core/tcp_stream.hpp> -#include <boost/beast/ssl/ssl_stream.hpp> +#include "SslStream.h" namespace Trinity::Net::Http { -namespace Impl +class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>> { -class BoostBeastSslSocketWrapper : public ::SslSocket<boost::beast::ssl_stream<boost::beast::tcp_stream>> -{ -public: - using SslSocket::SslSocket; - - void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) - { - _sslSocket.shutdown(shutdownError); - boost::beast::get_lowest_layer(_sslSocket).socket().shutdown(what, shutdownError); - } - - void close(boost::system::error_code& /*error*/) - { - boost::beast::get_lowest_layer(_sslSocket).close(); - } - - boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const - { - return boost::beast::get_lowest_layer(_sslSocket).socket().remote_endpoint(); - } -}; -} - -template <typename Derived> -class SslSocket : public BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper> -{ - using SocketBase = BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>; + using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>; public: - explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) + explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : SocketBase(std::move(socket), sslContext) { } + explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) + : SocketBase(context, sslContext) { } + SslSocket(SslSocket const& other) = delete; SslSocket(SslSocket&& other) = delete; SslSocket& operator=(SslSocket const& other) = delete; @@ -69,27 +43,14 @@ public: void Start() override { - this->AsyncHandshake(); - } - - void AsyncHandshake() - { - this->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, - [self = this->shared_from_this()](boost::system::error_code const& error) { self->HandshakeHandler(error); }); - } - - void HandshakeHandler(boost::system::error_code const& error) - { - if (error) - { - TC_LOG_ERROR("server.http.session.ssl", "{} SSL Handshake failed {}", this->GetClientInfo(), error.message()); - this->CloseSocket(); - return; - } - - this->ResetHttpParser(); - - this->AsyncRead(); + std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers = + { { + std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this), + std::make_shared<HttpConnectionInitializer<SocketBase>>(this), + std::make_shared<ReadConnectionInitializer<SocketBase>>(this), + } }; + + SocketConnectionInitializer::SetupChain(initializers)->Start(); } }; } diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h index fc7d9647fc1..d16da442149 100644 --- a/src/server/shared/Networking/NetworkThread.h +++ b/src/server/shared/Networking/NetworkThread.h @@ -15,20 +15,24 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef NetworkThread_h__ -#define NetworkThread_h__ +#ifndef TRINITYCORE_NETWORK_THREAD_H +#define TRINITYCORE_NETWORK_THREAD_H -#include "Define.h" +#include "Containers.h" #include "DeadlineTimer.h" +#include "Define.h" #include "Errors.h" #include "IoContext.h" #include "Log.h" +#include "Socket.h" #include <boost/asio/ip/tcp.hpp> #include <atomic> #include <memory> #include <mutex> #include <thread> +namespace Trinity::Net +{ template<class SocketType> class NetworkThread { @@ -38,14 +42,16 @@ public: { } + NetworkThread(NetworkThread const&) = delete; + NetworkThread(NetworkThread&&) = delete; + NetworkThread& operator=(NetworkThread const&) = delete; + NetworkThread& operator=(NetworkThread&&) = delete; + virtual ~NetworkThread() { Stop(); if (_thread) - { Wait(); - delete _thread; - } } void Stop() @@ -59,7 +65,7 @@ public: if (_thread) return false; - _thread = new std::thread(&NetworkThread::Run, this); + _thread = std::make_unique<std::thread>(&NetworkThread::Run, this); return true; } @@ -68,7 +74,6 @@ public: ASSERT(_thread); _thread->join(); - delete _thread; _thread = nullptr; } @@ -82,15 +87,14 @@ public: std::lock_guard<std::mutex> lock(_newSocketsLock); ++_connections; - _newSockets.push_back(sock); - SocketAdded(sock); + SocketAdded(_newSockets.emplace_back(std::move(sock))); } - boost::asio::ip::tcp::socket* GetSocketForAccept() { return &_acceptSocket; } + Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; } protected: - virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { } - virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { } + virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { } + virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { } void AddNewSockets() { @@ -99,7 +103,7 @@ protected: if (_newSockets.empty()) return; - for (std::shared_ptr<SocketType> sock : _newSockets) + for (std::shared_ptr<SocketType>& sock : _newSockets) { if (!sock->IsOpen()) { @@ -107,7 +111,7 @@ protected: --_connections; } else - _sockets.push_back(sock); + _sockets.emplace_back(std::move(sock)); } _newSockets.clear(); @@ -136,7 +140,7 @@ protected: AddNewSockets(); - _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock) + Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock) { if (!sock->Update()) { @@ -150,7 +154,7 @@ protected: } return false; - }), _sockets.end()); + }); } private: @@ -159,7 +163,7 @@ private: std::atomic<int32> _connections; std::atomic<bool> _stopped; - std::thread* _thread; + std::unique_ptr<std::thread> _thread; SocketContainer _sockets; @@ -167,8 +171,9 @@ private: SocketContainer _newSockets; Trinity::Asio::IoContext _ioContext; - boost::asio::ip::tcp::socket _acceptSocket; + Trinity::Net::IoContextTcpSocket _acceptSocket; Trinity::Asio::DeadlineTimer _updateTimer; }; +} -#endif // NetworkThread_h__ +#endif // TRINITYCORE_NETWORK_THREAD_H diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index 40f5820da92..565cc175318 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -15,11 +15,14 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef __SOCKET_H__ -#define __SOCKET_H__ +#ifndef TRINITYCORE_SOCKET_H +#define TRINITYCORE_SOCKET_H +#include "Concepts.h" #include "Log.h" #include "MessageBuffer.h" +#include "SocketConnectionInitializer.h" +#include <boost/asio/io_context.hpp> #include <boost/asio/ip/tcp.hpp> #include <atomic> #include <memory> @@ -31,12 +34,51 @@ #define TC_SOCKET_USE_IOCP #endif +namespace Trinity::Net +{ +using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>; + +enum class SocketReadCallbackResult +{ + KeepReading, + Stop +}; + +template <typename Callable> +concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>; + +template <typename SocketType> +struct InvokeReadHandlerCallback +{ + SocketReadCallbackResult operator()() const + { + return this->Socket->ReadHandler(); + } + + SocketType* Socket; +}; + +template <typename SocketType> +struct ReadConnectionInitializer final : SocketConnectionInitializer +{ + explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { } + + void Start() override + { + ReadCallback.Socket->AsyncRead(std::move(ReadCallback)); + + if (this->next) + this->next->Start(); + } + + InvokeReadHandlerCallback<SocketType> ReadCallback; +}; + /** @class Socket Base async socket implementation - @tparam T derived class type (CRTP) @tparam Stream stream type used for operations on socket Stream must implement the following methods: @@ -53,21 +95,27 @@ template<typename ConstBufferSequence> std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error); + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler); + template<typename SettableSocketOption> void set_option(SettableSocketOption const& option, boost::system::error_code& error); tcp::socket::endpoint_type remote_endpoint() const; */ -template<class T, class Stream = boost::asio::ip::tcp::socket> -class Socket : public std::enable_shared_from_this<T> +template<class Stream = IoContextTcpSocket> +class Socket : public std::enable_shared_from_this<Socket<Stream>> { public: template<typename... Args> - explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...), - _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), - _closed(false), _closing(false), _isWritingAsync(false) + explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...), + _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open) + { + } + + template<typename... Args> + explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed) { - _readBuffer.Resize(READ_BLOCK_SIZE); } Socket(Socket const& other) = delete; @@ -77,20 +125,20 @@ public: virtual ~Socket() { - _closed = true; + _openState = OpenState_Closed; boost::system::error_code error; _socket.close(error); } - virtual void Start() = 0; + virtual void Start() { } virtual bool Update() { - if (_closed) + if (_openState == OpenState_Closed) return false; #ifndef TC_SOCKET_USE_IOCP - if (_isWritingAsync || (_writeQueue.empty() && !_closing)) + if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open)) return true; for (; HandleQueue();) @@ -100,7 +148,7 @@ public: return true; } - boost::asio::ip::address GetRemoteIpAddress() const + boost::asio::ip::address const& GetRemoteIpAddress() const { return _remoteAddress; } @@ -110,7 +158,8 @@ public: return _remotePort; } - void AsyncRead() + template <SocketReadCallback Callback> + void AsyncRead(Callback&& callback) { if (!IsOpen()) return; @@ -118,23 +167,11 @@ public: _readBuffer.Normalize(); _readBuffer.EnsureFreeSpace(); _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), - [self = this->shared_from_this()](boost::system::error_code const& error, size_t transferredBytes) + [self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable { - self->ReadHandlerInternal(error, transferredBytes); - }); - } - - void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code const&, std::size_t)) - { - if (!IsOpen()) - return; - - _readBuffer.Normalize(); - _readBuffer.EnsureFreeSpace(); - _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), - [self = this->shared_from_this(), callback](boost::system::error_code const& error, size_t transferredBytes) - { - (self.get()->*callback)(error, transferredBytes); + if (self->ReadHandlerInternal(error, transferredBytes)) + if (callback() == SocketReadCallbackResult::KeepReading) + self->AsyncRead(std::forward<Callback>(callback)); }); } @@ -147,11 +184,11 @@ public: #endif } - bool IsOpen() const { return !_closed && !_closing; } + bool IsOpen() const { return _openState == OpenState_Open; } void CloseSocket() { - if (_closed.exchange(true)) + if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0) return; boost::system::error_code shutdownError; @@ -160,13 +197,13 @@ public: TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(), shutdownError.value(), shutdownError.message()); - OnClose(); + this->OnClose(); } /// Marks the socket for closing after write buffer becomes empty void DelayedCloseSocket() { - if (_closing.exchange(true)) + if (_openState.fetch_or(OpenState_Closing) != 0) return; if (_writeQueue.empty()) @@ -175,10 +212,15 @@ public: MessageBuffer& GetReadBuffer() { return _readBuffer; } + Stream& underlying_stream() + { + return _socket; + } + protected: virtual void OnClose() { } - virtual void ReadHandler() = 0; + virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; } bool AsyncProcessQueue() { @@ -195,10 +237,10 @@ protected: self->WriteHandler(error, transferedBytes); }); #else - _socket.async_write_some(boost::asio::null_buffers(), - [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes) + _socket.async_wait(boost::asio::socket_base::wait_type::wait_write, + [self = this->shared_from_this()](boost::system::error_code const& error) { - self->WriteHandlerWrapper(error, transferedBytes); + self->WriteHandlerWrapper(error); }); #endif @@ -214,22 +256,17 @@ protected: GetRemoteIpAddress().to_string(), err.value(), err.message()); } - Stream& underlying_stream() - { - return _socket; - } - private: - void ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes) + bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes) { if (error) { CloseSocket(); - return; + return false; } _readBuffer.WriteCompleted(transferredBytes); - ReadHandler(); + return IsOpen(); } #ifdef TC_SOCKET_USE_IOCP @@ -245,7 +282,7 @@ private: if (!_writeQueue.empty()) AsyncProcessQueue(); - else if (_closing) + else if (_openState == OpenState_Closing) CloseSocket(); } else @@ -254,7 +291,7 @@ private: #else - void WriteHandlerWrapper(boost::system::error_code const& /*error*/, std::size_t /*transferedBytes*/) + void WriteHandlerWrapper(boost::system::error_code const& /*error*/) { _isWritingAsync = false; HandleQueue(); @@ -278,14 +315,14 @@ private: return AsyncProcessQueue(); _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return false; } else if (bytesSent == 0) { _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return false; } @@ -296,7 +333,7 @@ private: } _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return !_writeQueue.empty(); } @@ -306,15 +343,20 @@ private: Stream _socket; boost::asio::ip::address _remoteAddress; - uint16 _remotePort; + uint16 _remotePort = 0; - MessageBuffer _readBuffer; + MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE); std::queue<MessageBuffer> _writeQueue; - std::atomic<bool> _closed; - std::atomic<bool> _closing; + // Socket open state "enum" (not enum to enable integral std::atomic api) + static constexpr uint8 OpenState_Open = 0x0; + static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data + static constexpr uint8 OpenState_Closed = 0x2; + + std::atomic<uint8> _openState; - bool _isWritingAsync; + bool _isWritingAsync = false; }; +} -#endif // __SOCKET_H__ +#endif // TRINITYCORE_SOCKET_H diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h index 0b2d03e0944..07252355308 100644 --- a/src/server/shared/Networking/SocketMgr.h +++ b/src/server/shared/Networking/SocketMgr.h @@ -15,32 +15,40 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef SocketMgr_h__ -#define SocketMgr_h__ +#ifndef TRINITYCORE_SOCKET_MGR_H +#define TRINITYCORE_SOCKET_MGR_H #include "AsyncAcceptor.h" #include "Errors.h" #include "NetworkThread.h" +#include "Socket.h" #include <boost/asio/ip/tcp.hpp> #include <memory> +namespace Trinity::Net +{ template<class SocketType> class SocketMgr { public: + SocketMgr(SocketMgr const&) = delete; + SocketMgr(SocketMgr&&) = delete; + SocketMgr& operator=(SocketMgr const&) = delete; + SocketMgr& operator=(SocketMgr&&) = delete; + virtual ~SocketMgr() { ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction"); } - virtual bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) + virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) { ASSERT(threadCount > 0); - AsyncAcceptor* acceptor = nullptr; + std::unique_ptr<AsyncAcceptor> acceptor = nullptr; try { - acceptor = new AsyncAcceptor(ioContext, bindIp, port); + acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port); } catch (boost::system::system_error const& err) { @@ -51,13 +59,12 @@ public: if (!acceptor->Bind()) { TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor"); - delete acceptor; return false; } - _acceptor = acceptor; + _acceptor = std::move(acceptor); _threadCount = threadCount; - _threads = CreateThreads(); + _threads.reset(CreateThreads()); ASSERT(_threads); @@ -73,27 +80,23 @@ public: { _acceptor->Close(); - if (_threadCount != 0) - for (int32 i = 0; i < _threadCount; ++i) - _threads[i].Stop(); + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Stop(); Wait(); - delete _acceptor; _acceptor = nullptr; - delete[] _threads; _threads = nullptr; _threadCount = 0; } void Wait() { - if (_threadCount != 0) - for (int32 i = 0; i < _threadCount; ++i) - _threads[i].Wait(); + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Wait(); } - virtual void OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) + virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex) { try { @@ -121,22 +124,23 @@ public: return min; } - std::pair<boost::asio::ip::tcp::socket*, uint32> GetSocketForAccept() + std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept() { uint32 threadIndex = SelectThreadWithMinConnections(); return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex); } protected: - SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(0) + SocketMgr() : _threadCount(0) { } virtual NetworkThread<SocketType>* CreateThreads() const = 0; - AsyncAcceptor* _acceptor; - NetworkThread<SocketType>* _threads; + std::unique_ptr<AsyncAcceptor> _acceptor; + std::unique_ptr<NetworkThread<SocketType>[]> _threads; int32 _threadCount; }; +} -#endif // SocketMgr_h__ +#endif // TRINITYCORE_SOCKET_MGR_H diff --git a/src/server/shared/Networking/SslSocket.h b/src/server/shared/Networking/SslSocket.h deleted file mode 100644 index c19c8612edf..00000000000 --- a/src/server/shared/Networking/SslSocket.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information - * - * This program is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation; either version 2 of the License, or (at your - * option) any later version. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef SslSocket_h__ -#define SslSocket_h__ - -#include <boost/asio/ip/tcp.hpp> -#include <boost/asio/ssl/stream.hpp> -#include <boost/system/error_code.hpp> - -namespace boostssl = boost::asio::ssl; - -template<class Stream = boostssl::stream<boost::asio::ip::tcp::socket>> -class SslSocket -{ -public: - explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) - { - _sslSocket.set_verify_mode(boostssl::verify_none); - } - - // adapting tcp::socket api - void close(boost::system::error_code& error) - { - _sslSocket.lowest_layer().close(error); - } - - void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) - { - _sslSocket.shutdown(shutdownError); - _sslSocket.lowest_layer().shutdown(what, shutdownError); - } - - template<typename MutableBufferSequence, typename ReadHandlerType> - void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) - { - _sslSocket.async_read_some(buffers, std::move(handler)); - } - - template<typename ConstBufferSequence, typename WriteHandlerType> - void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) - { - _sslSocket.async_write_some(buffers, std::move(handler)); - } - - template<typename ConstBufferSequence> - std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error) - { - return _sslSocket.write_some(buffers, error); - } - - template<typename SettableSocketOption> - void set_option(SettableSocketOption const& option, boost::system::error_code& error) - { - _sslSocket.lowest_layer().set_option(option, error); - } - - boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const - { - return _sslSocket.lowest_layer().remote_endpoint(); - } - - // ssl api - template<typename HandshakeHandlerType> - void async_handshake(boostssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) - { - _sslSocket.async_handshake(type, std::move(handler)); - } - -protected: - Stream _sslSocket; -}; - -#endif // SslSocket_h__ diff --git a/src/server/shared/Networking/SslStream.h b/src/server/shared/Networking/SslStream.h new file mode 100644 index 00000000000..2cced44e5ff --- /dev/null +++ b/src/server/shared/Networking/SslStream.h @@ -0,0 +1,131 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SSL_STREAM_H +#define TRINITYCORE_SSL_STREAM_H + +#include "SocketConnectionInitializer.h" +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/ssl/stream.hpp> +#include <boost/system/error_code.hpp> + +namespace Trinity::Net +{ +template <typename SocketImpl> +struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer +{ + explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, + [socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error) + { + std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock()); + if (!socket) + return; + + if (error) + { + TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message()); + socket->CloseSocket(); + return; + } + + if (self->next) + self->next->Start(); + }); + } + +private: + SocketImpl* _socket; +}; + +template<class WrappedStream = IoContextTcpSocket> +class SslStream +{ +public: + explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + // adapting tcp::socket api + void close(boost::system::error_code& error) + { + _sslSocket.next_layer().close(error); + } + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + _sslSocket.shutdown(shutdownError); + _sslSocket.next_layer().shutdown(what, shutdownError); + } + + template<typename MutableBufferSequence, typename ReadHandlerType> + void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) + { + _sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler)); + } + + template<typename ConstBufferSequence, typename WriteHandlerType> + void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) + { + _sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler)); + } + + template<typename ConstBufferSequence> + std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error) + { + return _sslSocket.write_some(buffers, error); + } + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + { + _sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler)); + } + + template<typename SettableSocketOption> + void set_option(SettableSocketOption const& option, boost::system::error_code& error) + { + _sslSocket.next_layer().set_option(option, error); + } + + IoContextTcpSocket::endpoint_type remote_endpoint() const + { + return _sslSocket.next_layer().remote_endpoint(); + } + + // ssl api + template<typename HandshakeHandlerType> + void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) + { + _sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler)); + } + +protected: + boost::asio::ssl::stream<WrappedStream> _sslSocket; +}; +} + +#endif // TRINITYCORE_SSL_STREAM_H diff --git a/src/server/worldserver/Main.cpp b/src/server/worldserver/Main.cpp index a482d109890..ba5c8f9c758 100644 --- a/src/server/worldserver/Main.cpp +++ b/src/server/worldserver/Main.cpp @@ -15,10 +15,6 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -/// \addtogroup Trinityd Trinity Daemon -/// @{ -/// \file - #include "Common.h" #include "AppenderDB.h" #include "AsyncAcceptor.h" @@ -122,7 +118,7 @@ private: }; void SignalHandler(boost::system::error_code const& error, int signalNumber); -AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext); +std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext); bool StartDB(); void StopDB(); void WorldUpdateLoop(); @@ -372,9 +368,9 @@ int main(int argc, char** argv) auto battlegroundMgrHandle = Trinity::make_unique_ptr_with_deleter<&BattlegroundMgr::DeleteAllBattlegrounds>(sBattlegroundMgr); // Start the Remote Access port (acceptor) if enabled - std::unique_ptr<AsyncAcceptor> raAcceptor; + std::unique_ptr<Trinity::Net::AsyncAcceptor> raAcceptor; if (sConfigMgr->GetBoolDefault("Ra.Enable", false)) - raAcceptor.reset(StartRaSocketAcceptor(*ioContext)); + raAcceptor = StartRaSocketAcceptor(*ioContext); // Start soap serving thread if enabled std::unique_ptr<std::thread, ShutdownTCSoapThread> soapThread; @@ -632,20 +628,23 @@ void FreezeDetector::Handler(std::weak_ptr<FreezeDetector> freezeDetectorRef, bo } } -AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext) +std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext) { uint16 raPort = uint16(sConfigMgr->GetIntDefault("Ra.Port", 3443)); std::string raListener = sConfigMgr->GetStringDefault("Ra.IP", "0.0.0.0"); - AsyncAcceptor* acceptor = new AsyncAcceptor(ioContext, raListener, raPort); + std::unique_ptr<Trinity::Net::AsyncAcceptor> acceptor = std::make_unique<Trinity::Net::AsyncAcceptor>(ioContext, raListener, raPort); if (!acceptor->Bind()) { TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor"); - delete acceptor; return nullptr; } - acceptor->AsyncAccept<RASession>(); + acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/) + { + std::make_shared<RASession>(std::move(sock))->Start(); + + }); return acceptor; } @@ -696,7 +695,6 @@ void ClearOnlineAccounts(uint32 realmId) // Battleground instance ids reset at server restart CharacterDatabase.DirectExecute("UPDATE character_battleground_data SET instanceId = 0"); } -/// @} variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, [[maybe_unused]] std::string& winServiceAction) { diff --git a/src/server/worldserver/RemoteAccess/RASession.cpp b/src/server/worldserver/RemoteAccess/RASession.cpp index b4e9e6317be..910cfdf5e4f 100644 --- a/src/server/worldserver/RemoteAccess/RASession.cpp +++ b/src/server/worldserver/RemoteAccess/RASession.cpp @@ -29,9 +29,11 @@ void RASession::Start() { + _socket.non_blocking(false); + // wait 1 second for active connections to send negotiation request for (int counter = 0; counter < 10 && _socket.available() == 0; counter++) - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(100ms); // Check if there are bytes available, if they are, then the client is requesting the negotiation if (_socket.available() > 0) diff --git a/src/server/worldserver/RemoteAccess/RASession.h b/src/server/worldserver/RemoteAccess/RASession.h index e0f4b373f74..23fd4e70c55 100644 --- a/src/server/worldserver/RemoteAccess/RASession.h +++ b/src/server/worldserver/RemoteAccess/RASession.h @@ -15,10 +15,11 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef __RASESSION_H__ -#define __RASESSION_H__ +#ifndef TRINITYCORE_RA_SESSION_H +#define TRINITYCORE_RA_SESSION_H #include "Define.h" +#include "Socket.h" #include <boost/asio/ip/tcp.hpp> #include <boost/asio/streambuf.hpp> #include <future> @@ -29,7 +30,7 @@ const size_t bufferSize = 4096; class RASession : public std::enable_shared_from_this <RASession> { public: - RASession(boost::asio::ip::tcp::socket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr) + RASession(Trinity::Net::IoContextTcpSocket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr) { } @@ -47,7 +48,7 @@ private: static void CommandPrint(void* callbackArg, std::string_view text); static void CommandFinished(void* callbackArg, bool); - boost::asio::ip::tcp::socket _socket; + Trinity::Net::IoContextTcpSocket _socket; boost::asio::streambuf _readBuffer; boost::asio::streambuf _writeBuffer; std::promise<void>* _commandExecuting; |