diff options
author | funjoker <funjoker109@gmail.com> | 2024-03-28 20:18:59 +0100 |
---|---|---|
committer | funjoker <funjoker109@gmail.com> | 2024-03-28 20:20:04 +0100 |
commit | e769af1044f524ee9ef469a3d1bfb728ee5ef4d0 (patch) | |
tree | 8da8700300f27e6070a12356e97f459a7841b081 /src/server/bnetserver/REST/LoginHttpSession.cpp | |
parent | cdaf8ffc068ef00a3427af2b95a46e360d63e1fc (diff) |
Core: port sneaky fixes from "Core: Updated to 10.2.6.53840"
Diffstat (limited to 'src/server/bnetserver/REST/LoginHttpSession.cpp')
-rw-r--r-- | src/server/bnetserver/REST/LoginHttpSession.cpp | 62 |
1 files changed, 43 insertions, 19 deletions
diff --git a/src/server/bnetserver/REST/LoginHttpSession.cpp b/src/server/bnetserver/REST/LoginHttpSession.cpp index 95112cb8836..aff579de7f9 100644 --- a/src/server/bnetserver/REST/LoginHttpSession.cpp +++ b/src/server/bnetserver/REST/LoginHttpSession.cpp @@ -23,17 +23,20 @@ namespace Battlenet { -LoginHttpSession::LoginHttpSession(boost::asio::ip::tcp::socket&& socket) - : SslSocket(std::move(socket), SslContext::instance()) +template<template<typename> typename SocketImpl> +LoginHttpSession<SocketImpl>::LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner) + : BaseSocket(std::move(socket), SslContext::instance()), _owner(owner) { } -LoginHttpSession::~LoginHttpSession() = default; +template<template<typename> typename SocketImpl> +LoginHttpSession<SocketImpl>::~LoginHttpSession() = default; -void LoginHttpSession::Start() +template<template<typename> typename SocketImpl> +void LoginHttpSession<SocketImpl>::Start() { - std::string ip_address = GetRemoteIpAddress().to_string(); - TC_LOG_TRACE("server.http.session", "{} Accepted connection", GetClientInfo()); + std::string ip_address = this->GetRemoteIpAddress().to_string(); + TC_LOG_TRACE("server.http.session", "{} Accepted connection", this->GetClientInfo()); // Verify that this IP is not in the ip_banned table LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); @@ -41,11 +44,12 @@ void LoginHttpSession::Start() LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); stmt->setString(0, ip_address); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) - .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); + this->_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) + .WithPreparedCallback([sess = this->shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); } -void LoginHttpSession::CheckIpCallback(PreparedQueryResult result) +template<template<typename> typename SocketImpl> +void LoginHttpSession<SocketImpl>::CheckIpCallback(PreparedQueryResult result) { if (result) { @@ -60,21 +64,30 @@ void LoginHttpSession::CheckIpCallback(PreparedQueryResult result) if (banned) { - TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", GetClientInfo()); - CloseSocket(); + TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", this->GetClientInfo()); + this->CloseSocket(); return; } } - AsyncHandshake(); + if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket<LoginHttpSession<Trinity::Net::Http::SslSocket>>>) + { + this->AsyncHandshake(); + } + else + { + this->ResetHttpParser(); + this->AsyncRead(); + } } -Trinity::Net::Http::RequestHandlerResult LoginHttpSession::RequestHandler(Trinity::Net::Http::RequestContext& context) +template<template<typename> typename SocketImpl> +Trinity::Net::Http::RequestHandlerResult LoginHttpSession<SocketImpl>::RequestHandler(Trinity::Net::Http::RequestContext& context) { - return sLoginService.HandleRequest(shared_from_this(), context); + return sLoginService.HandleRequest(_owner.shared_from_this(), context); } -std::shared_ptr<Trinity::Net::Http::SessionState> LoginHttpSession::ObtainSessionState(Trinity::Net::Http::RequestContext& context) const +std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress) { using namespace std::string_literals; @@ -92,27 +105,38 @@ std::shared_ptr<Trinity::Net::Http::SessionState> LoginHttpSession::ObtainSessio if (eq != std::string_view::npos) name = cookie.substr(0, eq); - return name == SESSION_ID_COOKIE; + return name == LoginHttpSessionWrapper::SESSION_ID_COOKIE; }); if (sessionIdItr != cookies.end()) { std::string_view value = sessionIdItr->substr(eq + 1); - state = sLoginService.FindAndRefreshSessionState(value, GetRemoteIpAddress()); + state = sLoginService.FindAndRefreshSessionState(value, remoteAddress); } } if (!state) { - state = sLoginService.CreateNewSessionState(GetRemoteIpAddress()); + state = sLoginService.CreateNewSessionState(remoteAddress); std::string_view host = Trinity::Net::Http::ToStdStringView(context.request[boost::beast::http::field::host]); if (std::size_t port = host.find(':'); port != std::string_view::npos) host.remove_suffix(host.length() - port); context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}={}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None", - SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host)); + LoginHttpSessionWrapper::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host)); } return state; } + +template class LoginHttpSession<Trinity::Net::Http::SslSocket>; +template class LoginHttpSession<Trinity::Net::Http::Socket>; + +LoginHttpSessionWrapper::LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket) +{ + if (!SslContext::UsesDevWildcardCertificate()) + _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::SslSocket>>(std::move(socket), *this); + else + _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::Socket>>(std::move(socket), *this); +} } |