mirror of
https://github.com/TrinityCore/TrinityCore.git
synced 2026-01-15 23:20:36 +01:00
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final * Removed derived class template argument * Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor * Make socket initialization easier composable (before entering Read loop) * Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
#include "Errors.h"
|
||||
#include "Hash.h"
|
||||
#include "IPLocation.h"
|
||||
#include "IpBanCheckConnectionInitializer.h"
|
||||
#include "LoginRESTService.h"
|
||||
#include "MapUtils.h"
|
||||
#include "ProtobufJSON.h"
|
||||
@@ -74,7 +75,7 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields)
|
||||
DisplayName = Name;
|
||||
}
|
||||
|
||||
Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()),
|
||||
Battlenet::Session::Session(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket), SslContext::instance()),
|
||||
_accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(),
|
||||
_os(), _build(0), _clientInfo(), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0)
|
||||
{
|
||||
@@ -83,54 +84,24 @@ Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSo
|
||||
|
||||
Battlenet::Session::~Session() = default;
|
||||
|
||||
void Battlenet::Session::AsyncHandshake()
|
||||
{
|
||||
underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
|
||||
[sess = shared_from_this()](boost::system::error_code const& error) { sess->HandshakeHandler(error); });
|
||||
}
|
||||
|
||||
void Battlenet::Session::Start()
|
||||
{
|
||||
std::string ip_address = GetRemoteIpAddress().to_string();
|
||||
TC_LOG_TRACE("session", "{} Accepted connection", GetClientInfo());
|
||||
|
||||
// Verify that this IP is not in the ip_banned table
|
||||
LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS));
|
||||
// build initializer chain
|
||||
std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
|
||||
{ {
|
||||
std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<Session>>(this),
|
||||
std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<Session>>(this),
|
||||
std::make_shared<Trinity::Net::ReadConnectionInitializer<Session>>(this),
|
||||
} };
|
||||
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
|
||||
stmt->setString(0, ip_address);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
|
||||
.WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
|
||||
}
|
||||
|
||||
void Battlenet::Session::CheckIpCallback(PreparedQueryResult result)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
bool banned = false;
|
||||
do
|
||||
{
|
||||
Field* fields = result->Fetch();
|
||||
if (fields[0].GetUInt64() != 0)
|
||||
banned = true;
|
||||
|
||||
} while (result->NextRow());
|
||||
|
||||
if (banned)
|
||||
{
|
||||
TC_LOG_DEBUG("session", "{} tries to log in using banned IP!", GetClientInfo());
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
AsyncHandshake();
|
||||
Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
|
||||
}
|
||||
|
||||
bool Battlenet::Session::Update()
|
||||
{
|
||||
if (!BattlenetSocket::Update())
|
||||
if (!BaseSocket::Update())
|
||||
return false;
|
||||
|
||||
_queryProcessor.ProcessReadyCallbacks();
|
||||
@@ -211,6 +182,11 @@ void Battlenet::Session::SendRequest(uint32 serviceHash, uint32 methodId, pb::Me
|
||||
AsyncWrite(&packet);
|
||||
}
|
||||
|
||||
void Battlenet::Session::QueueQuery(QueryCallback&& queryCallback)
|
||||
{
|
||||
_queryProcessor.AddCallback(std::move(queryCallback));
|
||||
}
|
||||
|
||||
uint32 Battlenet::Session::HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation)
|
||||
{
|
||||
if (logonRequest->program() != "WoW")
|
||||
@@ -292,7 +268,7 @@ uint32 Battlenet::Session::HandleGenerateWebCredentials(authentication::v1::Gene
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION_BY_ID);
|
||||
stmt->setUInt32(0, _accountInfo->Id);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
|
||||
{
|
||||
// just send existing credentials back (not the best but it works for now with them being stored in db)
|
||||
Battlenet::Services::Authentication asyncContinuationService(this);
|
||||
@@ -314,7 +290,7 @@ uint32 Battlenet::Session::VerifyWebCredentials(std::string const& webCredential
|
||||
|
||||
std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)> asyncContinuation = std::move(continuation);
|
||||
std::shared_ptr<AccountInfo> accountInfo = std::make_shared<AccountInfo>();
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
|
||||
{
|
||||
Battlenet::Services::Authentication asyncContinuationService(this);
|
||||
NoData response;
|
||||
@@ -713,20 +689,8 @@ uint32 Battlenet::Session::HandleGetAllValuesForAttribute(game_utilities::v1::Ge
|
||||
return ERROR_RPC_NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
void Battlenet::Session::HandshakeHandler(boost::system::error_code const& error)
|
||||
{
|
||||
if (error)
|
||||
{
|
||||
TC_LOG_ERROR("session", "{} SSL Handshake failed {}", GetClientInfo(), error.message());
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
}
|
||||
|
||||
template<bool(Battlenet::Session::*processMethod)(), MessageBuffer Battlenet::Session::*outputBuffer>
|
||||
inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
|
||||
static inline Optional<Trinity::Net::SocketReadCallbackResult> PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
|
||||
{
|
||||
MessageBuffer& buffer = session->*outputBuffer;
|
||||
|
||||
@@ -738,46 +702,45 @@ inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inp
|
||||
buffer.Write(inputBuffer.GetReadPointer(), readDataSize);
|
||||
inputBuffer.ReadCompleted(readDataSize);
|
||||
}
|
||||
else
|
||||
return { }; // go to next buffer
|
||||
|
||||
if (buffer.GetRemainingSpace() > 0)
|
||||
{
|
||||
// Couldn't receive the whole data this time.
|
||||
ASSERT(inputBuffer.GetActiveSize() == 0);
|
||||
return false;
|
||||
return Trinity::Net::SocketReadCallbackResult::KeepReading;
|
||||
}
|
||||
|
||||
// just received fresh new payload
|
||||
if (!(session->*processMethod)())
|
||||
{
|
||||
session->CloseSocket();
|
||||
return false;
|
||||
return Trinity::Net::SocketReadCallbackResult::Stop;
|
||||
}
|
||||
|
||||
return true;
|
||||
return { }; // go to next buffer
|
||||
}
|
||||
|
||||
void Battlenet::Session::ReadHandler()
|
||||
Trinity::Net::SocketReadCallbackResult Battlenet::Session::ReadHandler()
|
||||
{
|
||||
if (!IsOpen())
|
||||
return;
|
||||
|
||||
MessageBuffer& packet = GetReadBuffer();
|
||||
while (packet.GetActiveSize() > 0)
|
||||
{
|
||||
if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderLengthHandler, &Battlenet::Session::_headerLengthBuffer>(this, packet))
|
||||
break;
|
||||
if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderLengthHandler, &Session::_headerLengthBuffer>(this, packet))
|
||||
return *partialResult;
|
||||
|
||||
if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderHandler, &Battlenet::Session::_headerBuffer>(this, packet))
|
||||
break;
|
||||
if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderHandler, &Session::_headerBuffer>(this, packet))
|
||||
return *partialResult;
|
||||
|
||||
if (!PartialProcessPacket<&Battlenet::Session::ReadDataHandler, &Battlenet::Session::_packetBuffer>(this, packet))
|
||||
break;
|
||||
if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadDataHandler, &Session::_packetBuffer>(this, packet))
|
||||
return *partialResult;
|
||||
|
||||
_headerLengthBuffer.Reset();
|
||||
_headerBuffer.Reset();
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
return Trinity::Net::SocketReadCallbackResult::KeepReading;
|
||||
}
|
||||
|
||||
bool Battlenet::Session::ReadHeaderLengthHandler()
|
||||
@@ -835,5 +798,5 @@ std::string Battlenet::Session::GetClientInfo() const
|
||||
|
||||
stream << ']';
|
||||
|
||||
return stream.str();
|
||||
return std::move(stream).str();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user