diff options
| author | Shauren <shauren.trinity@gmail.com> | 2025-04-08 19:15:16 +0200 |
|---|---|---|
| committer | Ovahlord <dreadkiller@gmx.de> | 2025-04-08 19:57:45 +0200 |
| commit | 4aa991e7e4450232df4ceda0b2f439bccce1d260 (patch) | |
| tree | 6ba65344105e100dd29da50cb667a53c465252a7 /src/server/bnetserver | |
| parent | a28b2999b1b500185aeecf343c3b1e14a39c26cf (diff) | |
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final
* Removed derived class template argument
* Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor
* Make socket initialization easier composable (before entering Read loop)
* Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
(cherry picked from commit e8b2be3527c7683e8bfca70ed7706fc20da566fd)
Diffstat (limited to 'src/server/bnetserver')
| -rw-r--r-- | src/server/bnetserver/REST/LoginHttpSession.cpp | 169 | ||||
| -rw-r--r-- | src/server/bnetserver/REST/LoginHttpSession.h | 64 | ||||
| -rw-r--r-- | src/server/bnetserver/REST/LoginRESTService.cpp | 34 | ||||
| -rw-r--r-- | src/server/bnetserver/REST/LoginRESTService.h | 22 | ||||
| -rw-r--r-- | src/server/bnetserver/Server/Session.cpp | 103 | ||||
| -rw-r--r-- | src/server/bnetserver/Server/Session.h | 24 | ||||
| -rw-r--r-- | src/server/bnetserver/Server/SessionManager.cpp | 15 | ||||
| -rw-r--r-- | src/server/bnetserver/Server/SessionManager.h | 13 |
8 files changed, 176 insertions, 268 deletions
diff --git a/src/server/bnetserver/REST/LoginHttpSession.cpp b/src/server/bnetserver/REST/LoginHttpSession.cpp index aff579de7f9..23a317d3726 100644 --- a/src/server/bnetserver/REST/LoginHttpSession.cpp +++ b/src/server/bnetserver/REST/LoginHttpSession.cpp @@ -17,126 +17,117 @@ #include "LoginHttpSession.h" #include "DatabaseEnv.h" +#include "HttpSocket.h" +#include "HttpSslSocket.h" +#include "IpBanCheckConnectionInitializer.h" #include "LoginRESTService.h" #include "SslContext.h" #include "Util.h" +#include <boost/container/static_vector.hpp> -namespace Battlenet -{ -template<template<typename> typename SocketImpl> -LoginHttpSession<SocketImpl>::LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner) - : BaseSocket(std::move(socket), SslContext::instance()), _owner(owner) -{ -} - -template<template<typename> typename SocketImpl> -LoginHttpSession<SocketImpl>::~LoginHttpSession() = default; - -template<template<typename> typename SocketImpl> -void LoginHttpSession<SocketImpl>::Start() -{ - std::string ip_address = this->GetRemoteIpAddress().to_string(); - TC_LOG_TRACE("server.http.session", "{} Accepted connection", this->GetClientInfo()); - - // Verify that this IP is not in the ip_banned table - LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); - - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); - - this->_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) - .WithPreparedCallback([sess = this->shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); -} - -template<template<typename> typename SocketImpl> -void LoginHttpSession<SocketImpl>::CheckIpCallback(PreparedQueryResult result) -{ - if (result) - { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; - - } while (result->NextRow()); - - if (banned) - { - TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", this->GetClientInfo()); - this->CloseSocket(); - return; - } - } - - if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket<LoginHttpSession<Trinity::Net::Http::SslSocket>>>) - { - this->AsyncHandshake(); - } - else - { - this->ResetHttpParser(); - this->AsyncRead(); - } -} - -template<template<typename> typename SocketImpl> -Trinity::Net::Http::RequestHandlerResult LoginHttpSession<SocketImpl>::RequestHandler(Trinity::Net::Http::RequestContext& context) +namespace { - return sLoginService.HandleRequest(_owner.shared_from_this(), context); -} - std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress) { using namespace std::string_literals; - std::shared_ptr<Trinity::Net::Http::SessionState> state; - auto cookieItr = context.request.find(boost::beast::http::field::cookie); if (cookieItr != context.request.end()) { std::vector<std::string_view> cookies = Trinity::Tokenize(Trinity::Net::Http::ToStdStringView(cookieItr->value()), ';', false); - std::size_t eq = 0; - auto sessionIdItr = std::find_if(cookies.begin(), cookies.end(), [&](std::string_view cookie) + auto sessionIdItr = std::ranges::find_if(cookies, [](std::string_view const& cookie) { - std::string_view name = cookie; - eq = cookie.find('='); - if (eq != std::string_view::npos) - name = cookie.substr(0, eq); - - return name == LoginHttpSessionWrapper::SESSION_ID_COOKIE; + return cookie.length() > Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length() + && cookie.starts_with(Battlenet::LoginHttpSession::SESSION_ID_COOKIE); }); if (sessionIdItr != cookies.end()) { - std::string_view value = sessionIdItr->substr(eq + 1); - state = sLoginService.FindAndRefreshSessionState(value, remoteAddress); + sessionIdItr->remove_prefix(Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length()); + state = sLoginService.FindAndRefreshSessionState(*sessionIdItr, remoteAddress); } } - if (!state) { state = sLoginService.CreateNewSessionState(remoteAddress); std::string_view host = Trinity::Net::Http::ToStdStringView(context.request[boost::beast::http::field::host]); - if (std::size_t port = host.find(':'); port != std::string_view::npos) + if (size_t port = host.find(':'); port != std::string_view::npos) host.remove_suffix(host.length() - port); - context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}={}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None", - LoginHttpSessionWrapper::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host)); + context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}{}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None", + Battlenet::LoginHttpSession::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host)); } - return state; } -template class LoginHttpSession<Trinity::Net::Http::SslSocket>; -template class LoginHttpSession<Trinity::Net::Http::Socket>; +template<typename SocketImpl> +class LoginHttpSocketImpl final : public SocketImpl +{ +public: + using BaseSocket = SocketImpl; + + explicit LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner) + : BaseSocket(std::move(socket)), _owner(owner) + { + } + + LoginHttpSocketImpl(LoginHttpSocketImpl const&) = delete; + LoginHttpSocketImpl(LoginHttpSocketImpl&&) = delete; + LoginHttpSocketImpl& operator=(LoginHttpSocketImpl const&) = delete; + LoginHttpSocketImpl& operator=(LoginHttpSocketImpl&&) = delete; + + ~LoginHttpSocketImpl() = default; + + void Start() override + { + // build initializer chain + boost::container::static_vector<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 4> initializers; + + initializers.stable_emplace_back(std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<BaseSocket>>(this)); + + if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket>) + initializers.stable_emplace_back(std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<BaseSocket>>(this)); + + initializers.stable_emplace_back(std::make_shared<Trinity::Net::Http::HttpConnectionInitializer<BaseSocket>>(this)); + initializers.stable_emplace_back(std::make_shared<Trinity::Net::ReadConnectionInitializer<BaseSocket>>(this)); + + Trinity::Net::SocketConnectionInitializer::SetupChain(std::span(initializers.data(), initializers.size()))->Start(); + } + + Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override + { + return sLoginService.HandleRequest(_owner.shared_from_this(), context); + } + +protected: + std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override + { + return ::ObtainSessionState(context, this->GetRemoteIpAddress()); + } + + Battlenet::LoginHttpSession& _owner; +}; + +template<> +LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>::LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner) + : BaseSocket(std::move(socket), Battlenet::SslContext::instance()), _owner(owner) +{ +} +} + +namespace Battlenet +{ +LoginHttpSession::LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket) + : _socket(!SslContext::UsesDevWildcardCertificate() + ? std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>>(std::move(socket), *this)) + : std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::Socket>>(std::move(socket), *this))) +{ +} -LoginHttpSessionWrapper::LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket) +void LoginHttpSession::Start() { - if (!SslContext::UsesDevWildcardCertificate()) - _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::SslSocket>>(std::move(socket), *this); - else - _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::Socket>>(std::move(socket), *this); + TC_LOG_TRACE("server.http.session", "{} Accepted connection", GetClientInfo()); + + return _socket->Start(); } } diff --git a/src/server/bnetserver/REST/LoginHttpSession.h b/src/server/bnetserver/REST/LoginHttpSession.h index 6bd1ec113d0..c15442f9e0c 100644 --- a/src/server/bnetserver/REST/LoginHttpSession.h +++ b/src/server/bnetserver/REST/LoginHttpSession.h @@ -18,10 +18,8 @@ #ifndef TRINITYCORE_LOGIN_HTTP_SESSION_H #define TRINITYCORE_LOGIN_HTTP_SESSION_H -#include "HttpSocket.h" -#include "HttpSslSocket.h" +#include "BaseHttpSocket.h" #include "SRP6.h" -#include <variant> namespace Battlenet { @@ -30,59 +28,27 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> Srp; }; -class LoginHttpSessionWrapper; -std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress); - -template<template<typename> typename SocketImpl> -class LoginHttpSession : public SocketImpl<LoginHttpSession<SocketImpl>> +class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession> { - using BaseSocket = SocketImpl<LoginHttpSession<SocketImpl>>; - public: - explicit LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner); - ~LoginHttpSession(); - - void Start() override; + static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID="; - void CheckIpCallback(PreparedQueryResult result); + explicit LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket); - Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override; - - LoginSessionState* GetSessionState() const { return static_cast<LoginSessionState*>(this->_state.get()); } - -protected: - std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override - { - return Battlenet::ObtainSessionState(context, this->GetRemoteIpAddress()); - } - - LoginHttpSessionWrapper& _owner; -}; - -class LoginHttpSessionWrapper : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSessionWrapper> -{ -public: - static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID"; - - explicit LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket); - - void Start() { return std::visit([&](auto&& socket) { return socket->Start(); }, _socket); } - bool Update() { return std::visit([&](auto&& socket) { return socket->Update(); }, _socket); } - boost::asio::ip::address GetRemoteIpAddress() const { return std::visit([&](auto&& socket) { return socket->GetRemoteIpAddress(); }, _socket); } - bool IsOpen() const { return std::visit([&](auto&& socket) { return socket->IsOpen(); }, _socket); } - void CloseSocket() { return std::visit([&](auto&& socket) { return socket->CloseSocket(); }, _socket); } + void Start() override; + bool Update() override { return _socket->Update(); } + boost::asio::ip::address const& GetRemoteIpAddress() const override { return _socket->GetRemoteIpAddress(); } + bool IsOpen() const override { return _socket->IsOpen(); } + void CloseSocket() override { return _socket->CloseSocket(); } - void SendResponse(Trinity::Net::Http::RequestContext& context) override { return std::visit([&](auto&& socket) { return socket->SendResponse(context); }, _socket); } - void QueueQuery(QueryCallback&& queryCallback) override { return std::visit([&](auto&& socket) { return socket->QueueQuery(std::move(queryCallback)); }, _socket); } - std::string GetClientInfo() const override { return std::visit([&](auto&& socket) { return socket->GetClientInfo(); }, _socket); } - Optional<boost::uuids::uuid> GetSessionId() const override { return std::visit([&](auto&& socket) { return socket->GetSessionId(); }, _socket); } - LoginSessionState* GetSessionState() const { return std::visit([&](auto&& socket) { return socket->GetSessionState(); }, _socket); } + void SendResponse(Trinity::Net::Http::RequestContext& context) override { return _socket->SendResponse(context); } + void QueueQuery(QueryCallback&& queryCallback) override { return _socket->QueueQuery(std::move(queryCallback)); } + std::string GetClientInfo() const override { return _socket->GetClientInfo(); } + LoginSessionState* GetSessionState() const override { return static_cast<LoginSessionState*>(_socket->GetSessionState()); } private: - std::variant< - std::shared_ptr<LoginHttpSession<Trinity::Net::Http::SslSocket>>, - std::shared_ptr<LoginHttpSession<Trinity::Net::Http::Socket>> - > _socket; + std::shared_ptr<Trinity::Net::Http::AbstractSocket> _socket; }; } + #endif // TRINITYCORE_LOGIN_HTTP_SESSION_H diff --git a/src/server/bnetserver/REST/LoginRESTService.cpp b/src/server/bnetserver/REST/LoginRESTService.cpp index 12f6cf63712..1afcc5c88bb 100644 --- a/src/server/bnetserver/REST/LoginRESTService.cpp +++ b/src/server/bnetserver/REST/LoginRESTService.cpp @@ -44,32 +44,32 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st using Trinity::Net::Http::RequestHandlerFlag; - RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandleGetForm(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandleGetGameAccounts(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandleGetPortal(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandlePostLogin(std::move(session), context); }, RequestHandlerFlag::DoNotLogRequestContent); - RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandlePostLoginSrpChallenge(std::move(session), context); }); - RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { return HandlePostRefreshLoginTicket(std::move(session), context); }); @@ -121,7 +121,10 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st MigrateLegacyPasswordHashes(); - _acceptor->AsyncAcceptWithCallback<&LoginRESTService::OnSocketAccept>(); + _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) + { + OnSocketOpen(std::move(sock), threadIndex); + }); return true; } @@ -165,7 +168,7 @@ std::string LoginRESTService::ExtractAuthorization(HttpRequest const& request) return ticket; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { JSON::Login::FormInputs form = _formInputs; form.set_srp_url(Trinity::StringFormat("http{}://{}:{}/bnetserver/login/srp/", !SslContext::UsesDevWildcardCertificate() ? "s" : "", @@ -176,7 +179,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shar return RequestHandlerResult::Handled; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { std::string ticket = ExtractAuthorization(context.request); if (ticket.empty()) @@ -225,14 +228,14 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(s return RequestHandlerResult::Async; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { context.response.set(boost::beast::http::field::content_type, "text/plain"); context.response.body() = Trinity::StringFormat("{}:{}", GetHostnameForClient(session->GetRemoteIpAddress()), sConfigMgr->GetIntDefault("BattlenetPort", 1119)); return RequestHandlerResult::Handled; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { std::shared_ptr<JSON::Login::LoginForm> loginForm = std::make_shared<JSON::Login::LoginForm>(); if (!::JSON::Deserialize(context.request.body(), loginForm.get())) @@ -396,7 +399,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::sh return RequestHandlerResult::Async; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { JSON::Login::LoginForm loginForm; if (!::JSON::Deserialize(context.request.body(), &loginForm)) @@ -483,7 +486,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChall return RequestHandlerResult::Async; } -LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const { std::string ticket = ExtractAuthorization(context.request); if (ticket.empty()) @@ -551,11 +554,6 @@ std::shared_ptr<Trinity::Net::Http::SessionState> LoginRESTService::CreateNewSes return state; } -void LoginRESTService::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) -{ - sLoginService.OnSocketOpen(std::move(sock), threadIndex); -} - void LoginRESTService::MigrateLegacyPasswordHashes() const { if (!LoginDatabase.Query("SELECT 1 FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = SCHEMA() AND TABLE_NAME = 'battlenet_accounts' AND COLUMN_NAME = 'sha_pass_hash'")) diff --git a/src/server/bnetserver/REST/LoginRESTService.h b/src/server/bnetserver/REST/LoginRESTService.h index 079bb6fca9c..517f5e295b7 100644 --- a/src/server/bnetserver/REST/LoginRESTService.h +++ b/src/server/bnetserver/REST/LoginRESTService.h @@ -15,8 +15,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef LoginRESTService_h__ -#define LoginRESTService_h__ +#ifndef TRINITYCORE_LOGIN_REST_SERVICE_H +#define TRINITYCORE_LOGIN_REST_SERVICE_H #include "HttpService.h" #include "Login.pb.h" @@ -42,7 +42,7 @@ enum class BanMode BAN_ACCOUNT = 1 }; -class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSessionWrapper> +class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSession> { public: using RequestHandlerResult = Trinity::Net::Http::RequestHandlerResult; @@ -63,17 +63,15 @@ public: std::shared_ptr<Trinity::Net::Http::SessionState> CreateNewSessionState(boost::asio::ip::address const& address) override; private: - static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex); - static std::string ExtractAuthorization(HttpRequest const& request); - RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; - static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context); - RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; + RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; + static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); + RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; - RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; - static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context); - RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const; + RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; + static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); + RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const; static std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> CreateSrpImplementation(SrpVersion version, SrpHashFunction hashFunction, std::string const& username, Trinity::Crypto::SRP::Salt const& salt, Trinity::Crypto::SRP::Verifier const& verifier); @@ -90,4 +88,4 @@ private: #define sLoginService Battlenet::LoginRESTService::Instance() -#endif // LoginRESTService_h__ +#endif // TRINITYCORE_LOGIN_REST_SERVICE_H diff --git a/src/server/bnetserver/Server/Session.cpp b/src/server/bnetserver/Server/Session.cpp index c3a8ae6f211..7b9d6f45a3e 100644 --- a/src/server/bnetserver/Server/Session.cpp +++ b/src/server/bnetserver/Server/Session.cpp @@ -23,6 +23,7 @@ #include "Errors.h" #include "Hash.h" #include "IPLocation.h" +#include "IpBanCheckConnectionInitializer.h" #include "LoginRESTService.h" #include "MapUtils.h" #include "ProtobufJSON.h" @@ -74,7 +75,7 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields) DisplayName = Name; } -Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()), +Battlenet::Session::Session(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket), SslContext::instance()), _accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(), _os(), _build(0), _clientInfo(), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0) { @@ -83,54 +84,24 @@ Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSo Battlenet::Session::~Session() = default; -void Battlenet::Session::AsyncHandshake() -{ - underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, - [sess = shared_from_this()](boost::system::error_code const& error) { sess->HandshakeHandler(error); }); -} - void Battlenet::Session::Start() { - std::string ip_address = GetRemoteIpAddress().to_string(); TC_LOG_TRACE("session", "{} Accepted connection", GetClientInfo()); - // Verify that this IP is not in the ip_banned table - LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); + // build initializer chain + std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers = + { { + std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<Session>>(this), + std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<Session>>(this), + std::make_shared<Trinity::Net::ReadConnectionInitializer<Session>>(this), + } }; - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); - - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) - .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); -} - -void Battlenet::Session::CheckIpCallback(PreparedQueryResult result) -{ - if (result) - { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; - - } while (result->NextRow()); - - if (banned) - { - TC_LOG_DEBUG("session", "{} tries to log in using banned IP!", GetClientInfo()); - CloseSocket(); - return; - } - } - - AsyncHandshake(); + Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start(); } bool Battlenet::Session::Update() { - if (!BattlenetSocket::Update()) + if (!BaseSocket::Update()) return false; _queryProcessor.ProcessReadyCallbacks(); @@ -211,6 +182,11 @@ void Battlenet::Session::SendRequest(uint32 serviceHash, uint32 methodId, pb::Me AsyncWrite(&packet); } +void Battlenet::Session::QueueQuery(QueryCallback&& queryCallback) +{ + _queryProcessor.AddCallback(std::move(queryCallback)); +} + uint32 Battlenet::Session::HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation) { if (logonRequest->program() != "WoW") @@ -292,7 +268,7 @@ uint32 Battlenet::Session::HandleGenerateWebCredentials(authentication::v1::Gene LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION_BY_ID); stmt->setUInt32(0, _accountInfo->Id); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result) + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result) { // just send existing credentials back (not the best but it works for now with them being stored in db) Battlenet::Services::Authentication asyncContinuationService(this); @@ -314,7 +290,7 @@ uint32 Battlenet::Session::VerifyWebCredentials(std::string const& webCredential std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)> asyncContinuation = std::move(continuation); std::shared_ptr<AccountInfo> accountInfo = std::make_shared<AccountInfo>(); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result) + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result) { Battlenet::Services::Authentication asyncContinuationService(this); NoData response; @@ -713,20 +689,8 @@ uint32 Battlenet::Session::HandleGetAllValuesForAttribute(game_utilities::v1::Ge return ERROR_RPC_NOT_IMPLEMENTED; } -void Battlenet::Session::HandshakeHandler(boost::system::error_code const& error) -{ - if (error) - { - TC_LOG_ERROR("session", "{} SSL Handshake failed {}", GetClientInfo(), error.message()); - CloseSocket(); - return; - } - - AsyncRead(); -} - template<bool(Battlenet::Session::*processMethod)(), MessageBuffer Battlenet::Session::*outputBuffer> -inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer) +static inline Optional<Trinity::Net::SocketReadCallbackResult> PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer) { MessageBuffer& buffer = session->*outputBuffer; @@ -738,46 +702,45 @@ inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inp buffer.Write(inputBuffer.GetReadPointer(), readDataSize); inputBuffer.ReadCompleted(readDataSize); } + else + return { }; // go to next buffer if (buffer.GetRemainingSpace() > 0) { // Couldn't receive the whole data this time. ASSERT(inputBuffer.GetActiveSize() == 0); - return false; + return Trinity::Net::SocketReadCallbackResult::KeepReading; } // just received fresh new payload if (!(session->*processMethod)()) { session->CloseSocket(); - return false; + return Trinity::Net::SocketReadCallbackResult::Stop; } - return true; + return { }; // go to next buffer } -void Battlenet::Session::ReadHandler() +Trinity::Net::SocketReadCallbackResult Battlenet::Session::ReadHandler() { - if (!IsOpen()) - return; - MessageBuffer& packet = GetReadBuffer(); while (packet.GetActiveSize() > 0) { - if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderLengthHandler, &Battlenet::Session::_headerLengthBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderLengthHandler, &Session::_headerLengthBuffer>(this, packet)) + return *partialResult; - if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderHandler, &Battlenet::Session::_headerBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderHandler, &Session::_headerBuffer>(this, packet)) + return *partialResult; - if (!PartialProcessPacket<&Battlenet::Session::ReadDataHandler, &Battlenet::Session::_packetBuffer>(this, packet)) - break; + if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadDataHandler, &Session::_packetBuffer>(this, packet)) + return *partialResult; _headerLengthBuffer.Reset(); _headerBuffer.Reset(); } - AsyncRead(); + return Trinity::Net::SocketReadCallbackResult::KeepReading; } bool Battlenet::Session::ReadHeaderLengthHandler() @@ -835,5 +798,5 @@ std::string Battlenet::Session::GetClientInfo() const stream << ']'; - return stream.str(); + return std::move(stream).str(); } diff --git a/src/server/bnetserver/Server/Session.h b/src/server/bnetserver/Server/Session.h index 8a8b1f0a15b..faf82115f18 100644 --- a/src/server/bnetserver/Server/Session.h +++ b/src/server/bnetserver/Server/Session.h @@ -15,8 +15,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef Session_h__ -#define Session_h__ +#ifndef TRINITYCORE_SESSION_H +#define TRINITYCORE_SESSION_H #include "AsyncCallbackProcessor.h" #include "ClientBuildInfo.h" @@ -24,7 +24,7 @@ #include "QueryResult.h" #include "Realm.h" #include "Socket.h" -#include "SslSocket.h" +#include "SslStream.h" #include <boost/asio/ip/tcp.hpp> #include <google/protobuf/message.h> #include <memory> @@ -65,9 +65,9 @@ using namespace bgs::protocol; namespace Battlenet { - class Session : public Socket<Session, SslSocket<>> + class Session final : public Trinity::Net::Socket<Trinity::Net::SslStream<>> { - typedef Socket<Session, SslSocket<>> BattlenetSocket; + using BaseSocket = Socket<Trinity::Net::SslStream<>>; public: struct LastPlayedCharacterInfo @@ -110,7 +110,7 @@ namespace Battlenet std::unordered_map<uint32, GameAccountInfo> GameAccounts; }; - explicit Session(boost::asio::ip::tcp::socket&& socket); + explicit Session(Trinity::Net::IoContextTcpSocket&& socket); ~Session(); void Start() override; @@ -130,6 +130,8 @@ namespace Battlenet void SendRequest(uint32 serviceHash, uint32 methodId, pb::Message const* request); + void QueueQuery(QueryCallback&& queryCallback); + uint32 HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation); uint32 HandleVerifyWebCredentials(authentication::v1::VerifyWebCredentialsRequest const* verifyWebCredentialsRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation); uint32 HandleGenerateWebCredentials(authentication::v1::GenerateWebCredentialsRequest const* request, std::function<void(ServiceBase*, uint32, google::protobuf::Message const*)>& continuation); @@ -140,9 +142,9 @@ namespace Battlenet std::string GetClientInfo() const; + Trinity::Net::SocketReadCallbackResult ReadHandler() override; + protected: - void HandshakeHandler(boost::system::error_code const& error); - void ReadHandler() override; bool ReadHeaderLengthHandler(); bool ReadHeaderHandler(); bool ReadDataHandler(); @@ -150,10 +152,6 @@ namespace Battlenet private: void AsyncWrite(MessageBuffer* packet); - void AsyncHandshake(); - - void CheckIpCallback(PreparedQueryResult result); - uint32 VerifyWebCredentials(std::string const& webCredentials, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation); typedef uint32(Session::*ClientRequestHandler)(std::unordered_map<std::string, Variant const*> const&, game_utilities::v1::ClientResponse*); @@ -190,4 +188,4 @@ namespace Battlenet }; } -#endif // Session_h__ +#endif // TRINITYCORE_SESSION_H diff --git a/src/server/bnetserver/Server/SessionManager.cpp b/src/server/bnetserver/Server/SessionManager.cpp index cb56972a9c2..4c5b532ee60 100644 --- a/src/server/bnetserver/Server/SessionManager.cpp +++ b/src/server/bnetserver/Server/SessionManager.cpp @@ -16,7 +16,6 @@ */ #include "SessionManager.h" -#include "DatabaseEnv.h" #include "Util.h" bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) @@ -24,18 +23,16 @@ bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) return false; - _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>(); + _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) + { + OnSocketOpen(std::move(sock), threadIndex); + }); return true; } -NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const +Trinity::Net::NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const { - return new NetworkThread<Session>[GetNetworkThreadCount()]; -} - -void Battlenet::SessionManager::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) -{ - sSessionMgr.OnSocketOpen(std::move(sock), threadIndex); + return new Trinity::Net::NetworkThread<Session>[GetNetworkThreadCount()]; } Battlenet::SessionManager& Battlenet::SessionManager::Instance() diff --git a/src/server/bnetserver/Server/SessionManager.h b/src/server/bnetserver/Server/SessionManager.h index c635122c977..528ece8739e 100644 --- a/src/server/bnetserver/Server/SessionManager.h +++ b/src/server/bnetserver/Server/SessionManager.h @@ -15,15 +15,15 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef SessionManager_h__ -#define SessionManager_h__ +#ifndef TRINITYCORE_SESSION_MANAGER_H +#define TRINITYCORE_SESSION_MANAGER_H #include "SocketMgr.h" #include "Session.h" namespace Battlenet { - class SessionManager : public SocketMgr<Session> + class SessionManager : public Trinity::Net::SocketMgr<Session> { typedef SocketMgr<Session> BaseSocketMgr; @@ -33,13 +33,10 @@ namespace Battlenet bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override; protected: - NetworkThread<Session>* CreateThreads() const override; - - private: - static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex); + Trinity::Net::NetworkThread<Session>* CreateThreads() const override; }; } #define sSessionMgr Battlenet::SessionManager::Instance() -#endif // SessionManager_h__ +#endif // TRINITYCORE_SESSION_MANAGER_H |
