mirror of
https://github.com/TrinityCore/TrinityCore.git
synced 2026-01-15 23:20:36 +01:00
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final * Removed derived class template argument * Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor * Make socket initialization easier composable (before entering Read loop) * Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
This commit is contained in:
@@ -17,126 +17,117 @@
|
||||
|
||||
#include "LoginHttpSession.h"
|
||||
#include "DatabaseEnv.h"
|
||||
#include "HttpSocket.h"
|
||||
#include "HttpSslSocket.h"
|
||||
#include "IpBanCheckConnectionInitializer.h"
|
||||
#include "LoginRESTService.h"
|
||||
#include "SslContext.h"
|
||||
#include "Util.h"
|
||||
#include <boost/container/static_vector.hpp>
|
||||
|
||||
namespace Battlenet
|
||||
namespace
|
||||
{
|
||||
template<template<typename> typename SocketImpl>
|
||||
LoginHttpSession<SocketImpl>::LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner)
|
||||
: BaseSocket(std::move(socket), SslContext::instance()), _owner(owner)
|
||||
{
|
||||
}
|
||||
|
||||
template<template<typename> typename SocketImpl>
|
||||
LoginHttpSession<SocketImpl>::~LoginHttpSession() = default;
|
||||
|
||||
template<template<typename> typename SocketImpl>
|
||||
void LoginHttpSession<SocketImpl>::Start()
|
||||
{
|
||||
std::string ip_address = this->GetRemoteIpAddress().to_string();
|
||||
TC_LOG_TRACE("server.http.session", "{} Accepted connection", this->GetClientInfo());
|
||||
|
||||
// Verify that this IP is not in the ip_banned table
|
||||
LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS));
|
||||
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
|
||||
stmt->setString(0, ip_address);
|
||||
|
||||
this->_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
|
||||
.WithPreparedCallback([sess = this->shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
|
||||
}
|
||||
|
||||
template<template<typename> typename SocketImpl>
|
||||
void LoginHttpSession<SocketImpl>::CheckIpCallback(PreparedQueryResult result)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
bool banned = false;
|
||||
do
|
||||
{
|
||||
Field* fields = result->Fetch();
|
||||
if (fields[0].GetUInt64() != 0)
|
||||
banned = true;
|
||||
|
||||
} while (result->NextRow());
|
||||
|
||||
if (banned)
|
||||
{
|
||||
TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", this->GetClientInfo());
|
||||
this->CloseSocket();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket<LoginHttpSession<Trinity::Net::Http::SslSocket>>>)
|
||||
{
|
||||
this->AsyncHandshake();
|
||||
}
|
||||
else
|
||||
{
|
||||
this->ResetHttpParser();
|
||||
this->AsyncRead();
|
||||
}
|
||||
}
|
||||
|
||||
template<template<typename> typename SocketImpl>
|
||||
Trinity::Net::Http::RequestHandlerResult LoginHttpSession<SocketImpl>::RequestHandler(Trinity::Net::Http::RequestContext& context)
|
||||
{
|
||||
return sLoginService.HandleRequest(_owner.shared_from_this(), context);
|
||||
}
|
||||
|
||||
std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress)
|
||||
{
|
||||
using namespace std::string_literals;
|
||||
|
||||
std::shared_ptr<Trinity::Net::Http::SessionState> state;
|
||||
|
||||
auto cookieItr = context.request.find(boost::beast::http::field::cookie);
|
||||
if (cookieItr != context.request.end())
|
||||
{
|
||||
std::vector<std::string_view> cookies = Trinity::Tokenize(Trinity::Net::Http::ToStdStringView(cookieItr->value()), ';', false);
|
||||
std::size_t eq = 0;
|
||||
auto sessionIdItr = std::find_if(cookies.begin(), cookies.end(), [&](std::string_view cookie)
|
||||
auto sessionIdItr = std::ranges::find_if(cookies, [](std::string_view const& cookie)
|
||||
{
|
||||
std::string_view name = cookie;
|
||||
eq = cookie.find('=');
|
||||
if (eq != std::string_view::npos)
|
||||
name = cookie.substr(0, eq);
|
||||
|
||||
return name == LoginHttpSessionWrapper::SESSION_ID_COOKIE;
|
||||
return cookie.length() > Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length()
|
||||
&& cookie.starts_with(Battlenet::LoginHttpSession::SESSION_ID_COOKIE);
|
||||
});
|
||||
if (sessionIdItr != cookies.end())
|
||||
{
|
||||
std::string_view value = sessionIdItr->substr(eq + 1);
|
||||
state = sLoginService.FindAndRefreshSessionState(value, remoteAddress);
|
||||
sessionIdItr->remove_prefix(Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length());
|
||||
state = sLoginService.FindAndRefreshSessionState(*sessionIdItr, remoteAddress);
|
||||
}
|
||||
}
|
||||
|
||||
if (!state)
|
||||
{
|
||||
state = sLoginService.CreateNewSessionState(remoteAddress);
|
||||
|
||||
std::string_view host = Trinity::Net::Http::ToStdStringView(context.request[boost::beast::http::field::host]);
|
||||
if (std::size_t port = host.find(':'); port != std::string_view::npos)
|
||||
if (size_t port = host.find(':'); port != std::string_view::npos)
|
||||
host.remove_suffix(host.length() - port);
|
||||
|
||||
context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}={}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None",
|
||||
LoginHttpSessionWrapper::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host));
|
||||
context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}{}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None",
|
||||
Battlenet::LoginHttpSession::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host));
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
template class LoginHttpSession<Trinity::Net::Http::SslSocket>;
|
||||
template class LoginHttpSession<Trinity::Net::Http::Socket>;
|
||||
|
||||
LoginHttpSessionWrapper::LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket)
|
||||
template<typename SocketImpl>
|
||||
class LoginHttpSocketImpl final : public SocketImpl
|
||||
{
|
||||
public:
|
||||
using BaseSocket = SocketImpl;
|
||||
|
||||
explicit LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner)
|
||||
: BaseSocket(std::move(socket)), _owner(owner)
|
||||
{
|
||||
}
|
||||
|
||||
LoginHttpSocketImpl(LoginHttpSocketImpl const&) = delete;
|
||||
LoginHttpSocketImpl(LoginHttpSocketImpl&&) = delete;
|
||||
LoginHttpSocketImpl& operator=(LoginHttpSocketImpl const&) = delete;
|
||||
LoginHttpSocketImpl& operator=(LoginHttpSocketImpl&&) = delete;
|
||||
|
||||
~LoginHttpSocketImpl() = default;
|
||||
|
||||
void Start() override
|
||||
{
|
||||
// build initializer chain
|
||||
boost::container::static_vector<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 4> initializers;
|
||||
|
||||
initializers.stable_emplace_back(std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<BaseSocket>>(this));
|
||||
|
||||
if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket>)
|
||||
initializers.stable_emplace_back(std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<BaseSocket>>(this));
|
||||
|
||||
initializers.stable_emplace_back(std::make_shared<Trinity::Net::Http::HttpConnectionInitializer<BaseSocket>>(this));
|
||||
initializers.stable_emplace_back(std::make_shared<Trinity::Net::ReadConnectionInitializer<BaseSocket>>(this));
|
||||
|
||||
Trinity::Net::SocketConnectionInitializer::SetupChain(std::span(initializers.data(), initializers.size()))->Start();
|
||||
}
|
||||
|
||||
Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override
|
||||
{
|
||||
return sLoginService.HandleRequest(_owner.shared_from_this(), context);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override
|
||||
{
|
||||
return ::ObtainSessionState(context, this->GetRemoteIpAddress());
|
||||
}
|
||||
|
||||
Battlenet::LoginHttpSession& _owner;
|
||||
};
|
||||
|
||||
template<>
|
||||
LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>::LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner)
|
||||
: BaseSocket(std::move(socket), Battlenet::SslContext::instance()), _owner(owner)
|
||||
{
|
||||
if (!SslContext::UsesDevWildcardCertificate())
|
||||
_socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::SslSocket>>(std::move(socket), *this);
|
||||
else
|
||||
_socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::Socket>>(std::move(socket), *this);
|
||||
}
|
||||
}
|
||||
|
||||
namespace Battlenet
|
||||
{
|
||||
LoginHttpSession::LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket)
|
||||
: _socket(!SslContext::UsesDevWildcardCertificate()
|
||||
? std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>>(std::move(socket), *this))
|
||||
: std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::Socket>>(std::move(socket), *this)))
|
||||
{
|
||||
}
|
||||
|
||||
void LoginHttpSession::Start()
|
||||
{
|
||||
TC_LOG_TRACE("server.http.session", "{} Accepted connection", GetClientInfo());
|
||||
|
||||
return _socket->Start();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,10 +18,8 @@
|
||||
#ifndef TRINITYCORE_LOGIN_HTTP_SESSION_H
|
||||
#define TRINITYCORE_LOGIN_HTTP_SESSION_H
|
||||
|
||||
#include "HttpSocket.h"
|
||||
#include "HttpSslSocket.h"
|
||||
#include "BaseHttpSocket.h"
|
||||
#include "SRP6.h"
|
||||
#include <variant>
|
||||
|
||||
namespace Battlenet
|
||||
{
|
||||
@@ -30,59 +28,27 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState
|
||||
std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> Srp;
|
||||
};
|
||||
|
||||
class LoginHttpSessionWrapper;
|
||||
std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress);
|
||||
|
||||
template<template<typename> typename SocketImpl>
|
||||
class LoginHttpSession : public SocketImpl<LoginHttpSession<SocketImpl>>
|
||||
class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
|
||||
{
|
||||
using BaseSocket = SocketImpl<LoginHttpSession<SocketImpl>>;
|
||||
|
||||
public:
|
||||
explicit LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner);
|
||||
~LoginHttpSession();
|
||||
static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID=";
|
||||
|
||||
explicit LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket);
|
||||
|
||||
void Start() override;
|
||||
bool Update() override { return _socket->Update(); }
|
||||
boost::asio::ip::address const& GetRemoteIpAddress() const override { return _socket->GetRemoteIpAddress(); }
|
||||
bool IsOpen() const override { return _socket->IsOpen(); }
|
||||
void CloseSocket() override { return _socket->CloseSocket(); }
|
||||
|
||||
void CheckIpCallback(PreparedQueryResult result);
|
||||
|
||||
Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override;
|
||||
|
||||
LoginSessionState* GetSessionState() const { return static_cast<LoginSessionState*>(this->_state.get()); }
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override
|
||||
{
|
||||
return Battlenet::ObtainSessionState(context, this->GetRemoteIpAddress());
|
||||
}
|
||||
|
||||
LoginHttpSessionWrapper& _owner;
|
||||
};
|
||||
|
||||
class LoginHttpSessionWrapper : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSessionWrapper>
|
||||
{
|
||||
public:
|
||||
static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID";
|
||||
|
||||
explicit LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket);
|
||||
|
||||
void Start() { return std::visit([&](auto&& socket) { return socket->Start(); }, _socket); }
|
||||
bool Update() { return std::visit([&](auto&& socket) { return socket->Update(); }, _socket); }
|
||||
boost::asio::ip::address GetRemoteIpAddress() const { return std::visit([&](auto&& socket) { return socket->GetRemoteIpAddress(); }, _socket); }
|
||||
bool IsOpen() const { return std::visit([&](auto&& socket) { return socket->IsOpen(); }, _socket); }
|
||||
void CloseSocket() { return std::visit([&](auto&& socket) { return socket->CloseSocket(); }, _socket); }
|
||||
|
||||
void SendResponse(Trinity::Net::Http::RequestContext& context) override { return std::visit([&](auto&& socket) { return socket->SendResponse(context); }, _socket); }
|
||||
void QueueQuery(QueryCallback&& queryCallback) override { return std::visit([&](auto&& socket) { return socket->QueueQuery(std::move(queryCallback)); }, _socket); }
|
||||
std::string GetClientInfo() const override { return std::visit([&](auto&& socket) { return socket->GetClientInfo(); }, _socket); }
|
||||
Optional<boost::uuids::uuid> GetSessionId() const override { return std::visit([&](auto&& socket) { return socket->GetSessionId(); }, _socket); }
|
||||
LoginSessionState* GetSessionState() const { return std::visit([&](auto&& socket) { return socket->GetSessionState(); }, _socket); }
|
||||
void SendResponse(Trinity::Net::Http::RequestContext& context) override { return _socket->SendResponse(context); }
|
||||
void QueueQuery(QueryCallback&& queryCallback) override { return _socket->QueueQuery(std::move(queryCallback)); }
|
||||
std::string GetClientInfo() const override { return _socket->GetClientInfo(); }
|
||||
LoginSessionState* GetSessionState() const override { return static_cast<LoginSessionState*>(_socket->GetSessionState()); }
|
||||
|
||||
private:
|
||||
std::variant<
|
||||
std::shared_ptr<LoginHttpSession<Trinity::Net::Http::SslSocket>>,
|
||||
std::shared_ptr<LoginHttpSession<Trinity::Net::Http::Socket>>
|
||||
> _socket;
|
||||
std::shared_ptr<Trinity::Net::Http::AbstractSocket> _socket;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // TRINITYCORE_LOGIN_HTTP_SESSION_H
|
||||
|
||||
@@ -44,32 +44,32 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
|
||||
|
||||
using Trinity::Net::Http::RequestHandlerFlag;
|
||||
|
||||
RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
return HandleGetForm(std::move(session), context);
|
||||
});
|
||||
|
||||
RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
return HandleGetGameAccounts(std::move(session), context);
|
||||
});
|
||||
|
||||
RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
return HandleGetPortal(std::move(session), context);
|
||||
});
|
||||
|
||||
RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
return HandlePostLogin(std::move(session), context);
|
||||
}, RequestHandlerFlag::DoNotLogRequestContent);
|
||||
|
||||
RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
return HandlePostLoginSrpChallenge(std::move(session), context);
|
||||
});
|
||||
|
||||
RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
return HandlePostRefreshLoginTicket(std::move(session), context);
|
||||
});
|
||||
@@ -121,7 +121,10 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
|
||||
|
||||
MigrateLegacyPasswordHashes();
|
||||
|
||||
_acceptor->AsyncAcceptWithCallback<&LoginRESTService::OnSocketAccept>();
|
||||
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
OnSocketOpen(std::move(sock), threadIndex);
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -165,7 +168,7 @@ std::string LoginRESTService::ExtractAuthorization(HttpRequest const& request)
|
||||
return ticket;
|
||||
}
|
||||
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
|
||||
{
|
||||
JSON::Login::FormInputs form = _formInputs;
|
||||
form.set_srp_url(Trinity::StringFormat("http{}://{}:{}/bnetserver/login/srp/", !SslContext::UsesDevWildcardCertificate() ? "s" : "",
|
||||
@@ -176,7 +179,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shar
|
||||
return RequestHandlerResult::Handled;
|
||||
}
|
||||
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
std::string ticket = ExtractAuthorization(context.request);
|
||||
if (ticket.empty())
|
||||
@@ -225,14 +228,14 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(s
|
||||
return RequestHandlerResult::Async;
|
||||
}
|
||||
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
|
||||
{
|
||||
context.response.set(boost::beast::http::field::content_type, "text/plain");
|
||||
context.response.body() = Trinity::StringFormat("{}:{}", GetHostnameForClient(session->GetRemoteIpAddress()), sConfigMgr->GetIntDefault("BattlenetPort", 1119));
|
||||
return RequestHandlerResult::Handled;
|
||||
}
|
||||
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
|
||||
{
|
||||
std::shared_ptr<JSON::Login::LoginForm> loginForm = std::make_shared<JSON::Login::LoginForm>();
|
||||
if (!::JSON::Deserialize(context.request.body(), loginForm.get()))
|
||||
@@ -396,7 +399,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::sh
|
||||
return RequestHandlerResult::Async;
|
||||
}
|
||||
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
|
||||
{
|
||||
JSON::Login::LoginForm loginForm;
|
||||
if (!::JSON::Deserialize(context.request.body(), &loginForm))
|
||||
@@ -483,7 +486,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChall
|
||||
return RequestHandlerResult::Async;
|
||||
}
|
||||
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
|
||||
LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
|
||||
{
|
||||
std::string ticket = ExtractAuthorization(context.request);
|
||||
if (ticket.empty())
|
||||
@@ -551,11 +554,6 @@ std::shared_ptr<Trinity::Net::Http::SessionState> LoginRESTService::CreateNewSes
|
||||
return state;
|
||||
}
|
||||
|
||||
void LoginRESTService::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
sLoginService.OnSocketOpen(std::move(sock), threadIndex);
|
||||
}
|
||||
|
||||
void LoginRESTService::MigrateLegacyPasswordHashes() const
|
||||
{
|
||||
if (!LoginDatabase.Query("SELECT 1 FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = SCHEMA() AND TABLE_NAME = 'battlenet_accounts' AND COLUMN_NAME = 'sha_pass_hash'"))
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef LoginRESTService_h__
|
||||
#define LoginRESTService_h__
|
||||
#ifndef TRINITYCORE_LOGIN_REST_SERVICE_H
|
||||
#define TRINITYCORE_LOGIN_REST_SERVICE_H
|
||||
|
||||
#include "HttpService.h"
|
||||
#include "Login.pb.h"
|
||||
@@ -42,7 +42,7 @@ enum class BanMode
|
||||
BAN_ACCOUNT = 1
|
||||
};
|
||||
|
||||
class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSessionWrapper>
|
||||
class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSession>
|
||||
{
|
||||
public:
|
||||
using RequestHandlerResult = Trinity::Net::Http::RequestHandlerResult;
|
||||
@@ -63,17 +63,15 @@ public:
|
||||
std::shared_ptr<Trinity::Net::Http::SessionState> CreateNewSessionState(boost::asio::ip::address const& address) override;
|
||||
|
||||
private:
|
||||
static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex);
|
||||
|
||||
static std::string ExtractAuthorization(HttpRequest const& request);
|
||||
|
||||
RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
|
||||
static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context);
|
||||
RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
|
||||
RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
|
||||
static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context);
|
||||
RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
|
||||
|
||||
RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
|
||||
static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context);
|
||||
RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
|
||||
RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
|
||||
static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context);
|
||||
RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
|
||||
|
||||
static std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> CreateSrpImplementation(SrpVersion version, SrpHashFunction hashFunction,
|
||||
std::string const& username, Trinity::Crypto::SRP::Salt const& salt, Trinity::Crypto::SRP::Verifier const& verifier);
|
||||
@@ -90,4 +88,4 @@ private:
|
||||
|
||||
#define sLoginService Battlenet::LoginRESTService::Instance()
|
||||
|
||||
#endif // LoginRESTService_h__
|
||||
#endif // TRINITYCORE_LOGIN_REST_SERVICE_H
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "Errors.h"
|
||||
#include "Hash.h"
|
||||
#include "IPLocation.h"
|
||||
#include "IpBanCheckConnectionInitializer.h"
|
||||
#include "LoginRESTService.h"
|
||||
#include "MapUtils.h"
|
||||
#include "ProtobufJSON.h"
|
||||
@@ -74,7 +75,7 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields)
|
||||
DisplayName = Name;
|
||||
}
|
||||
|
||||
Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()),
|
||||
Battlenet::Session::Session(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket), SslContext::instance()),
|
||||
_accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(),
|
||||
_os(), _build(0), _clientInfo(), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0)
|
||||
{
|
||||
@@ -83,54 +84,24 @@ Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSo
|
||||
|
||||
Battlenet::Session::~Session() = default;
|
||||
|
||||
void Battlenet::Session::AsyncHandshake()
|
||||
{
|
||||
underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
|
||||
[sess = shared_from_this()](boost::system::error_code const& error) { sess->HandshakeHandler(error); });
|
||||
}
|
||||
|
||||
void Battlenet::Session::Start()
|
||||
{
|
||||
std::string ip_address = GetRemoteIpAddress().to_string();
|
||||
TC_LOG_TRACE("session", "{} Accepted connection", GetClientInfo());
|
||||
|
||||
// Verify that this IP is not in the ip_banned table
|
||||
LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS));
|
||||
// build initializer chain
|
||||
std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
|
||||
{ {
|
||||
std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<Session>>(this),
|
||||
std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<Session>>(this),
|
||||
std::make_shared<Trinity::Net::ReadConnectionInitializer<Session>>(this),
|
||||
} };
|
||||
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
|
||||
stmt->setString(0, ip_address);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
|
||||
.WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
|
||||
}
|
||||
|
||||
void Battlenet::Session::CheckIpCallback(PreparedQueryResult result)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
bool banned = false;
|
||||
do
|
||||
{
|
||||
Field* fields = result->Fetch();
|
||||
if (fields[0].GetUInt64() != 0)
|
||||
banned = true;
|
||||
|
||||
} while (result->NextRow());
|
||||
|
||||
if (banned)
|
||||
{
|
||||
TC_LOG_DEBUG("session", "{} tries to log in using banned IP!", GetClientInfo());
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
AsyncHandshake();
|
||||
Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
|
||||
}
|
||||
|
||||
bool Battlenet::Session::Update()
|
||||
{
|
||||
if (!BattlenetSocket::Update())
|
||||
if (!BaseSocket::Update())
|
||||
return false;
|
||||
|
||||
_queryProcessor.ProcessReadyCallbacks();
|
||||
@@ -211,6 +182,11 @@ void Battlenet::Session::SendRequest(uint32 serviceHash, uint32 methodId, pb::Me
|
||||
AsyncWrite(&packet);
|
||||
}
|
||||
|
||||
void Battlenet::Session::QueueQuery(QueryCallback&& queryCallback)
|
||||
{
|
||||
_queryProcessor.AddCallback(std::move(queryCallback));
|
||||
}
|
||||
|
||||
uint32 Battlenet::Session::HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation)
|
||||
{
|
||||
if (logonRequest->program() != "WoW")
|
||||
@@ -292,7 +268,7 @@ uint32 Battlenet::Session::HandleGenerateWebCredentials(authentication::v1::Gene
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION_BY_ID);
|
||||
stmt->setUInt32(0, _accountInfo->Id);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
|
||||
{
|
||||
// just send existing credentials back (not the best but it works for now with them being stored in db)
|
||||
Battlenet::Services::Authentication asyncContinuationService(this);
|
||||
@@ -314,7 +290,7 @@ uint32 Battlenet::Session::VerifyWebCredentials(std::string const& webCredential
|
||||
|
||||
std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)> asyncContinuation = std::move(continuation);
|
||||
std::shared_ptr<AccountInfo> accountInfo = std::make_shared<AccountInfo>();
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
|
||||
{
|
||||
Battlenet::Services::Authentication asyncContinuationService(this);
|
||||
NoData response;
|
||||
@@ -713,20 +689,8 @@ uint32 Battlenet::Session::HandleGetAllValuesForAttribute(game_utilities::v1::Ge
|
||||
return ERROR_RPC_NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
void Battlenet::Session::HandshakeHandler(boost::system::error_code const& error)
|
||||
{
|
||||
if (error)
|
||||
{
|
||||
TC_LOG_ERROR("session", "{} SSL Handshake failed {}", GetClientInfo(), error.message());
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
}
|
||||
|
||||
template<bool(Battlenet::Session::*processMethod)(), MessageBuffer Battlenet::Session::*outputBuffer>
|
||||
inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
|
||||
static inline Optional<Trinity::Net::SocketReadCallbackResult> PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
|
||||
{
|
||||
MessageBuffer& buffer = session->*outputBuffer;
|
||||
|
||||
@@ -738,46 +702,45 @@ inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inp
|
||||
buffer.Write(inputBuffer.GetReadPointer(), readDataSize);
|
||||
inputBuffer.ReadCompleted(readDataSize);
|
||||
}
|
||||
else
|
||||
return { }; // go to next buffer
|
||||
|
||||
if (buffer.GetRemainingSpace() > 0)
|
||||
{
|
||||
// Couldn't receive the whole data this time.
|
||||
ASSERT(inputBuffer.GetActiveSize() == 0);
|
||||
return false;
|
||||
return Trinity::Net::SocketReadCallbackResult::KeepReading;
|
||||
}
|
||||
|
||||
// just received fresh new payload
|
||||
if (!(session->*processMethod)())
|
||||
{
|
||||
session->CloseSocket();
|
||||
return false;
|
||||
return Trinity::Net::SocketReadCallbackResult::Stop;
|
||||
}
|
||||
|
||||
return true;
|
||||
return { }; // go to next buffer
|
||||
}
|
||||
|
||||
void Battlenet::Session::ReadHandler()
|
||||
Trinity::Net::SocketReadCallbackResult Battlenet::Session::ReadHandler()
|
||||
{
|
||||
if (!IsOpen())
|
||||
return;
|
||||
|
||||
MessageBuffer& packet = GetReadBuffer();
|
||||
while (packet.GetActiveSize() > 0)
|
||||
{
|
||||
if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderLengthHandler, &Battlenet::Session::_headerLengthBuffer>(this, packet))
|
||||
break;
|
||||
if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderLengthHandler, &Session::_headerLengthBuffer>(this, packet))
|
||||
return *partialResult;
|
||||
|
||||
if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderHandler, &Battlenet::Session::_headerBuffer>(this, packet))
|
||||
break;
|
||||
if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderHandler, &Session::_headerBuffer>(this, packet))
|
||||
return *partialResult;
|
||||
|
||||
if (!PartialProcessPacket<&Battlenet::Session::ReadDataHandler, &Battlenet::Session::_packetBuffer>(this, packet))
|
||||
break;
|
||||
if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadDataHandler, &Session::_packetBuffer>(this, packet))
|
||||
return *partialResult;
|
||||
|
||||
_headerLengthBuffer.Reset();
|
||||
_headerBuffer.Reset();
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
return Trinity::Net::SocketReadCallbackResult::KeepReading;
|
||||
}
|
||||
|
||||
bool Battlenet::Session::ReadHeaderLengthHandler()
|
||||
@@ -835,5 +798,5 @@ std::string Battlenet::Session::GetClientInfo() const
|
||||
|
||||
stream << ']';
|
||||
|
||||
return stream.str();
|
||||
return std::move(stream).str();
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef Session_h__
|
||||
#define Session_h__
|
||||
#ifndef TRINITYCORE_SESSION_H
|
||||
#define TRINITYCORE_SESSION_H
|
||||
|
||||
#include "AsyncCallbackProcessor.h"
|
||||
#include "ClientBuildInfo.h"
|
||||
@@ -24,7 +24,7 @@
|
||||
#include "QueryResult.h"
|
||||
#include "Realm.h"
|
||||
#include "Socket.h"
|
||||
#include "SslSocket.h"
|
||||
#include "SslStream.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <google/protobuf/message.h>
|
||||
#include <memory>
|
||||
@@ -65,9 +65,9 @@ using namespace bgs::protocol;
|
||||
|
||||
namespace Battlenet
|
||||
{
|
||||
class Session : public Socket<Session, SslSocket<>>
|
||||
class Session final : public Trinity::Net::Socket<Trinity::Net::SslStream<>>
|
||||
{
|
||||
typedef Socket<Session, SslSocket<>> BattlenetSocket;
|
||||
using BaseSocket = Socket<Trinity::Net::SslStream<>>;
|
||||
|
||||
public:
|
||||
struct LastPlayedCharacterInfo
|
||||
@@ -110,7 +110,7 @@ namespace Battlenet
|
||||
std::unordered_map<uint32, GameAccountInfo> GameAccounts;
|
||||
};
|
||||
|
||||
explicit Session(boost::asio::ip::tcp::socket&& socket);
|
||||
explicit Session(Trinity::Net::IoContextTcpSocket&& socket);
|
||||
~Session();
|
||||
|
||||
void Start() override;
|
||||
@@ -130,6 +130,8 @@ namespace Battlenet
|
||||
|
||||
void SendRequest(uint32 serviceHash, uint32 methodId, pb::Message const* request);
|
||||
|
||||
void QueueQuery(QueryCallback&& queryCallback);
|
||||
|
||||
uint32 HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
|
||||
uint32 HandleVerifyWebCredentials(authentication::v1::VerifyWebCredentialsRequest const* verifyWebCredentialsRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
|
||||
uint32 HandleGenerateWebCredentials(authentication::v1::GenerateWebCredentialsRequest const* request, std::function<void(ServiceBase*, uint32, google::protobuf::Message const*)>& continuation);
|
||||
@@ -140,9 +142,9 @@ namespace Battlenet
|
||||
|
||||
std::string GetClientInfo() const;
|
||||
|
||||
Trinity::Net::SocketReadCallbackResult ReadHandler() override;
|
||||
|
||||
protected:
|
||||
void HandshakeHandler(boost::system::error_code const& error);
|
||||
void ReadHandler() override;
|
||||
bool ReadHeaderLengthHandler();
|
||||
bool ReadHeaderHandler();
|
||||
bool ReadDataHandler();
|
||||
@@ -150,10 +152,6 @@ namespace Battlenet
|
||||
private:
|
||||
void AsyncWrite(MessageBuffer* packet);
|
||||
|
||||
void AsyncHandshake();
|
||||
|
||||
void CheckIpCallback(PreparedQueryResult result);
|
||||
|
||||
uint32 VerifyWebCredentials(std::string const& webCredentials, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
|
||||
|
||||
typedef uint32(Session::*ClientRequestHandler)(std::unordered_map<std::string, Variant const*> const&, game_utilities::v1::ClientResponse*);
|
||||
@@ -190,4 +188,4 @@ namespace Battlenet
|
||||
};
|
||||
}
|
||||
|
||||
#endif // Session_h__
|
||||
#endif // TRINITYCORE_SESSION_H
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
*/
|
||||
|
||||
#include "SessionManager.h"
|
||||
#include "DatabaseEnv.h"
|
||||
#include "Util.h"
|
||||
|
||||
bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
|
||||
@@ -24,18 +23,16 @@ bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext
|
||||
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
|
||||
return false;
|
||||
|
||||
_acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
|
||||
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
OnSocketOpen(std::move(sock), threadIndex);
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const
|
||||
Trinity::Net::NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const
|
||||
{
|
||||
return new NetworkThread<Session>[GetNetworkThreadCount()];
|
||||
}
|
||||
|
||||
void Battlenet::SessionManager::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
sSessionMgr.OnSocketOpen(std::move(sock), threadIndex);
|
||||
return new Trinity::Net::NetworkThread<Session>[GetNetworkThreadCount()];
|
||||
}
|
||||
|
||||
Battlenet::SessionManager& Battlenet::SessionManager::Instance()
|
||||
|
||||
@@ -15,15 +15,15 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef SessionManager_h__
|
||||
#define SessionManager_h__
|
||||
#ifndef TRINITYCORE_SESSION_MANAGER_H
|
||||
#define TRINITYCORE_SESSION_MANAGER_H
|
||||
|
||||
#include "SocketMgr.h"
|
||||
#include "Session.h"
|
||||
|
||||
namespace Battlenet
|
||||
{
|
||||
class SessionManager : public SocketMgr<Session>
|
||||
class SessionManager : public Trinity::Net::SocketMgr<Session>
|
||||
{
|
||||
typedef SocketMgr<Session> BaseSocketMgr;
|
||||
|
||||
@@ -33,13 +33,10 @@ namespace Battlenet
|
||||
bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override;
|
||||
|
||||
protected:
|
||||
NetworkThread<Session>* CreateThreads() const override;
|
||||
|
||||
private:
|
||||
static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex);
|
||||
Trinity::Net::NetworkThread<Session>* CreateThreads() const override;
|
||||
};
|
||||
}
|
||||
|
||||
#define sSessionMgr Battlenet::SessionManager::Instance()
|
||||
|
||||
#endif // SessionManager_h__
|
||||
#endif // TRINITYCORE_SESSION_MANAGER_H
|
||||
|
||||
Reference in New Issue
Block a user