aboutsummaryrefslogtreecommitdiff
path: root/src/server/bnetserver/REST/LoginHttpSession.cpp
diff options
context:
space:
mode:
authorfunjoker <funjoker109@gmail.com>2024-03-28 20:18:59 +0100
committerfunjoker <funjoker109@gmail.com>2024-03-28 20:20:04 +0100
commite769af1044f524ee9ef469a3d1bfb728ee5ef4d0 (patch)
tree8da8700300f27e6070a12356e97f459a7841b081 /src/server/bnetserver/REST/LoginHttpSession.cpp
parentcdaf8ffc068ef00a3427af2b95a46e360d63e1fc (diff)
Core: port sneaky fixes from "Core: Updated to 10.2.6.53840"
Diffstat (limited to 'src/server/bnetserver/REST/LoginHttpSession.cpp')
-rw-r--r--src/server/bnetserver/REST/LoginHttpSession.cpp62
1 files changed, 43 insertions, 19 deletions
diff --git a/src/server/bnetserver/REST/LoginHttpSession.cpp b/src/server/bnetserver/REST/LoginHttpSession.cpp
index 95112cb8836..aff579de7f9 100644
--- a/src/server/bnetserver/REST/LoginHttpSession.cpp
+++ b/src/server/bnetserver/REST/LoginHttpSession.cpp
@@ -23,17 +23,20 @@
namespace Battlenet
{
-LoginHttpSession::LoginHttpSession(boost::asio::ip::tcp::socket&& socket)
- : SslSocket(std::move(socket), SslContext::instance())
+template<template<typename> typename SocketImpl>
+LoginHttpSession<SocketImpl>::LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner)
+ : BaseSocket(std::move(socket), SslContext::instance()), _owner(owner)
{
}
-LoginHttpSession::~LoginHttpSession() = default;
+template<template<typename> typename SocketImpl>
+LoginHttpSession<SocketImpl>::~LoginHttpSession() = default;
-void LoginHttpSession::Start()
+template<template<typename> typename SocketImpl>
+void LoginHttpSession<SocketImpl>::Start()
{
- std::string ip_address = GetRemoteIpAddress().to_string();
- TC_LOG_TRACE("server.http.session", "{} Accepted connection", GetClientInfo());
+ std::string ip_address = this->GetRemoteIpAddress().to_string();
+ TC_LOG_TRACE("server.http.session", "{} Accepted connection", this->GetClientInfo());
// Verify that this IP is not in the ip_banned table
LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS));
@@ -41,11 +44,12 @@ void LoginHttpSession::Start()
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
stmt->setString(0, ip_address);
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
- .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
+ this->_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
+ .WithPreparedCallback([sess = this->shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
}
-void LoginHttpSession::CheckIpCallback(PreparedQueryResult result)
+template<template<typename> typename SocketImpl>
+void LoginHttpSession<SocketImpl>::CheckIpCallback(PreparedQueryResult result)
{
if (result)
{
@@ -60,21 +64,30 @@ void LoginHttpSession::CheckIpCallback(PreparedQueryResult result)
if (banned)
{
- TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", GetClientInfo());
- CloseSocket();
+ TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", this->GetClientInfo());
+ this->CloseSocket();
return;
}
}
- AsyncHandshake();
+ if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket<LoginHttpSession<Trinity::Net::Http::SslSocket>>>)
+ {
+ this->AsyncHandshake();
+ }
+ else
+ {
+ this->ResetHttpParser();
+ this->AsyncRead();
+ }
}
-Trinity::Net::Http::RequestHandlerResult LoginHttpSession::RequestHandler(Trinity::Net::Http::RequestContext& context)
+template<template<typename> typename SocketImpl>
+Trinity::Net::Http::RequestHandlerResult LoginHttpSession<SocketImpl>::RequestHandler(Trinity::Net::Http::RequestContext& context)
{
- return sLoginService.HandleRequest(shared_from_this(), context);
+ return sLoginService.HandleRequest(_owner.shared_from_this(), context);
}
-std::shared_ptr<Trinity::Net::Http::SessionState> LoginHttpSession::ObtainSessionState(Trinity::Net::Http::RequestContext& context) const
+std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress)
{
using namespace std::string_literals;
@@ -92,27 +105,38 @@ std::shared_ptr<Trinity::Net::Http::SessionState> LoginHttpSession::ObtainSessio
if (eq != std::string_view::npos)
name = cookie.substr(0, eq);
- return name == SESSION_ID_COOKIE;
+ return name == LoginHttpSessionWrapper::SESSION_ID_COOKIE;
});
if (sessionIdItr != cookies.end())
{
std::string_view value = sessionIdItr->substr(eq + 1);
- state = sLoginService.FindAndRefreshSessionState(value, GetRemoteIpAddress());
+ state = sLoginService.FindAndRefreshSessionState(value, remoteAddress);
}
}
if (!state)
{
- state = sLoginService.CreateNewSessionState(GetRemoteIpAddress());
+ state = sLoginService.CreateNewSessionState(remoteAddress);
std::string_view host = Trinity::Net::Http::ToStdStringView(context.request[boost::beast::http::field::host]);
if (std::size_t port = host.find(':'); port != std::string_view::npos)
host.remove_suffix(host.length() - port);
context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}={}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None",
- SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host));
+ LoginHttpSessionWrapper::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host));
}
return state;
}
+
+template class LoginHttpSession<Trinity::Net::Http::SslSocket>;
+template class LoginHttpSession<Trinity::Net::Http::Socket>;
+
+LoginHttpSessionWrapper::LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket)
+{
+ if (!SslContext::UsesDevWildcardCertificate())
+ _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::SslSocket>>(std::move(socket), *this);
+ else
+ _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::Socket>>(std::move(socket), *this);
+}
}