aboutsummaryrefslogtreecommitdiff
path: root/src/server/bnetserver/Server
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2025-04-08 19:15:16 +0200
committerShauren <shauren.trinity@gmail.com>2025-04-08 19:15:16 +0200
commite8b2be3527c7683e8bfca70ed7706fc20da566fd (patch)
tree54d5099554c8628cad719e6f1a49d387c7eced4f /src/server/bnetserver/Server
parent40d80f3476ade4898be24659408e82aa4234b099 (diff)
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final * Removed derived class template argument * Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor * Make socket initialization easier composable (before entering Read loop) * Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
Diffstat (limited to 'src/server/bnetserver/Server')
-rw-r--r--src/server/bnetserver/Server/Session.cpp103
-rw-r--r--src/server/bnetserver/Server/Session.h24
-rw-r--r--src/server/bnetserver/Server/SessionManager.cpp15
-rw-r--r--src/server/bnetserver/Server/SessionManager.h13
4 files changed, 55 insertions, 100 deletions
diff --git a/src/server/bnetserver/Server/Session.cpp b/src/server/bnetserver/Server/Session.cpp
index c3a8ae6f211..7b9d6f45a3e 100644
--- a/src/server/bnetserver/Server/Session.cpp
+++ b/src/server/bnetserver/Server/Session.cpp
@@ -23,6 +23,7 @@
#include "Errors.h"
#include "Hash.h"
#include "IPLocation.h"
+#include "IpBanCheckConnectionInitializer.h"
#include "LoginRESTService.h"
#include "MapUtils.h"
#include "ProtobufJSON.h"
@@ -74,7 +75,7 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields)
DisplayName = Name;
}
-Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()),
+Battlenet::Session::Session(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket), SslContext::instance()),
_accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(),
_os(), _build(0), _clientInfo(), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0)
{
@@ -83,54 +84,24 @@ Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSo
Battlenet::Session::~Session() = default;
-void Battlenet::Session::AsyncHandshake()
-{
- underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
- [sess = shared_from_this()](boost::system::error_code const& error) { sess->HandshakeHandler(error); });
-}
-
void Battlenet::Session::Start()
{
- std::string ip_address = GetRemoteIpAddress().to_string();
TC_LOG_TRACE("session", "{} Accepted connection", GetClientInfo());
- // Verify that this IP is not in the ip_banned table
- LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS));
+ // build initializer chain
+ std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
+ { {
+ std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<Session>>(this),
+ std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<Session>>(this),
+ std::make_shared<Trinity::Net::ReadConnectionInitializer<Session>>(this),
+ } };
- LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
- stmt->setString(0, ip_address);
-
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
- .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
-}
-
-void Battlenet::Session::CheckIpCallback(PreparedQueryResult result)
-{
- if (result)
- {
- bool banned = false;
- do
- {
- Field* fields = result->Fetch();
- if (fields[0].GetUInt64() != 0)
- banned = true;
-
- } while (result->NextRow());
-
- if (banned)
- {
- TC_LOG_DEBUG("session", "{} tries to log in using banned IP!", GetClientInfo());
- CloseSocket();
- return;
- }
- }
-
- AsyncHandshake();
+ Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
}
bool Battlenet::Session::Update()
{
- if (!BattlenetSocket::Update())
+ if (!BaseSocket::Update())
return false;
_queryProcessor.ProcessReadyCallbacks();
@@ -211,6 +182,11 @@ void Battlenet::Session::SendRequest(uint32 serviceHash, uint32 methodId, pb::Me
AsyncWrite(&packet);
}
+void Battlenet::Session::QueueQuery(QueryCallback&& queryCallback)
+{
+ _queryProcessor.AddCallback(std::move(queryCallback));
+}
+
uint32 Battlenet::Session::HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation)
{
if (logonRequest->program() != "WoW")
@@ -292,7 +268,7 @@ uint32 Battlenet::Session::HandleGenerateWebCredentials(authentication::v1::Gene
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION_BY_ID);
stmt->setUInt32(0, _accountInfo->Id);
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
+ QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
{
// just send existing credentials back (not the best but it works for now with them being stored in db)
Battlenet::Services::Authentication asyncContinuationService(this);
@@ -314,7 +290,7 @@ uint32 Battlenet::Session::VerifyWebCredentials(std::string const& webCredential
std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)> asyncContinuation = std::move(continuation);
std::shared_ptr<AccountInfo> accountInfo = std::make_shared<AccountInfo>();
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
+ QueueQuery(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
{
Battlenet::Services::Authentication asyncContinuationService(this);
NoData response;
@@ -713,20 +689,8 @@ uint32 Battlenet::Session::HandleGetAllValuesForAttribute(game_utilities::v1::Ge
return ERROR_RPC_NOT_IMPLEMENTED;
}
-void Battlenet::Session::HandshakeHandler(boost::system::error_code const& error)
-{
- if (error)
- {
- TC_LOG_ERROR("session", "{} SSL Handshake failed {}", GetClientInfo(), error.message());
- CloseSocket();
- return;
- }
-
- AsyncRead();
-}
-
template<bool(Battlenet::Session::*processMethod)(), MessageBuffer Battlenet::Session::*outputBuffer>
-inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
+static inline Optional<Trinity::Net::SocketReadCallbackResult> PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
{
MessageBuffer& buffer = session->*outputBuffer;
@@ -738,46 +702,45 @@ inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inp
buffer.Write(inputBuffer.GetReadPointer(), readDataSize);
inputBuffer.ReadCompleted(readDataSize);
}
+ else
+ return { }; // go to next buffer
if (buffer.GetRemainingSpace() > 0)
{
// Couldn't receive the whole data this time.
ASSERT(inputBuffer.GetActiveSize() == 0);
- return false;
+ return Trinity::Net::SocketReadCallbackResult::KeepReading;
}
// just received fresh new payload
if (!(session->*processMethod)())
{
session->CloseSocket();
- return false;
+ return Trinity::Net::SocketReadCallbackResult::Stop;
}
- return true;
+ return { }; // go to next buffer
}
-void Battlenet::Session::ReadHandler()
+Trinity::Net::SocketReadCallbackResult Battlenet::Session::ReadHandler()
{
- if (!IsOpen())
- return;
-
MessageBuffer& packet = GetReadBuffer();
while (packet.GetActiveSize() > 0)
{
- if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderLengthHandler, &Battlenet::Session::_headerLengthBuffer>(this, packet))
- break;
+ if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderLengthHandler, &Session::_headerLengthBuffer>(this, packet))
+ return *partialResult;
- if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderHandler, &Battlenet::Session::_headerBuffer>(this, packet))
- break;
+ if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderHandler, &Session::_headerBuffer>(this, packet))
+ return *partialResult;
- if (!PartialProcessPacket<&Battlenet::Session::ReadDataHandler, &Battlenet::Session::_packetBuffer>(this, packet))
- break;
+ if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadDataHandler, &Session::_packetBuffer>(this, packet))
+ return *partialResult;
_headerLengthBuffer.Reset();
_headerBuffer.Reset();
}
- AsyncRead();
+ return Trinity::Net::SocketReadCallbackResult::KeepReading;
}
bool Battlenet::Session::ReadHeaderLengthHandler()
@@ -835,5 +798,5 @@ std::string Battlenet::Session::GetClientInfo() const
stream << ']';
- return stream.str();
+ return std::move(stream).str();
}
diff --git a/src/server/bnetserver/Server/Session.h b/src/server/bnetserver/Server/Session.h
index 8a8b1f0a15b..faf82115f18 100644
--- a/src/server/bnetserver/Server/Session.h
+++ b/src/server/bnetserver/Server/Session.h
@@ -15,8 +15,8 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef Session_h__
-#define Session_h__
+#ifndef TRINITYCORE_SESSION_H
+#define TRINITYCORE_SESSION_H
#include "AsyncCallbackProcessor.h"
#include "ClientBuildInfo.h"
@@ -24,7 +24,7 @@
#include "QueryResult.h"
#include "Realm.h"
#include "Socket.h"
-#include "SslSocket.h"
+#include "SslStream.h"
#include <boost/asio/ip/tcp.hpp>
#include <google/protobuf/message.h>
#include <memory>
@@ -65,9 +65,9 @@ using namespace bgs::protocol;
namespace Battlenet
{
- class Session : public Socket<Session, SslSocket<>>
+ class Session final : public Trinity::Net::Socket<Trinity::Net::SslStream<>>
{
- typedef Socket<Session, SslSocket<>> BattlenetSocket;
+ using BaseSocket = Socket<Trinity::Net::SslStream<>>;
public:
struct LastPlayedCharacterInfo
@@ -110,7 +110,7 @@ namespace Battlenet
std::unordered_map<uint32, GameAccountInfo> GameAccounts;
};
- explicit Session(boost::asio::ip::tcp::socket&& socket);
+ explicit Session(Trinity::Net::IoContextTcpSocket&& socket);
~Session();
void Start() override;
@@ -130,6 +130,8 @@ namespace Battlenet
void SendRequest(uint32 serviceHash, uint32 methodId, pb::Message const* request);
+ void QueueQuery(QueryCallback&& queryCallback);
+
uint32 HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
uint32 HandleVerifyWebCredentials(authentication::v1::VerifyWebCredentialsRequest const* verifyWebCredentialsRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
uint32 HandleGenerateWebCredentials(authentication::v1::GenerateWebCredentialsRequest const* request, std::function<void(ServiceBase*, uint32, google::protobuf::Message const*)>& continuation);
@@ -140,9 +142,9 @@ namespace Battlenet
std::string GetClientInfo() const;
+ Trinity::Net::SocketReadCallbackResult ReadHandler() override;
+
protected:
- void HandshakeHandler(boost::system::error_code const& error);
- void ReadHandler() override;
bool ReadHeaderLengthHandler();
bool ReadHeaderHandler();
bool ReadDataHandler();
@@ -150,10 +152,6 @@ namespace Battlenet
private:
void AsyncWrite(MessageBuffer* packet);
- void AsyncHandshake();
-
- void CheckIpCallback(PreparedQueryResult result);
-
uint32 VerifyWebCredentials(std::string const& webCredentials, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
typedef uint32(Session::*ClientRequestHandler)(std::unordered_map<std::string, Variant const*> const&, game_utilities::v1::ClientResponse*);
@@ -190,4 +188,4 @@ namespace Battlenet
};
}
-#endif // Session_h__
+#endif // TRINITYCORE_SESSION_H
diff --git a/src/server/bnetserver/Server/SessionManager.cpp b/src/server/bnetserver/Server/SessionManager.cpp
index cb56972a9c2..4c5b532ee60 100644
--- a/src/server/bnetserver/Server/SessionManager.cpp
+++ b/src/server/bnetserver/Server/SessionManager.cpp
@@ -16,7 +16,6 @@
*/
#include "SessionManager.h"
-#include "DatabaseEnv.h"
#include "Util.h"
bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
@@ -24,18 +23,16 @@ bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
- _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
+ _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
+ {
+ OnSocketOpen(std::move(sock), threadIndex);
+ });
return true;
}
-NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const
+Trinity::Net::NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const
{
- return new NetworkThread<Session>[GetNetworkThreadCount()];
-}
-
-void Battlenet::SessionManager::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
-{
- sSessionMgr.OnSocketOpen(std::move(sock), threadIndex);
+ return new Trinity::Net::NetworkThread<Session>[GetNetworkThreadCount()];
}
Battlenet::SessionManager& Battlenet::SessionManager::Instance()
diff --git a/src/server/bnetserver/Server/SessionManager.h b/src/server/bnetserver/Server/SessionManager.h
index c635122c977..528ece8739e 100644
--- a/src/server/bnetserver/Server/SessionManager.h
+++ b/src/server/bnetserver/Server/SessionManager.h
@@ -15,15 +15,15 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef SessionManager_h__
-#define SessionManager_h__
+#ifndef TRINITYCORE_SESSION_MANAGER_H
+#define TRINITYCORE_SESSION_MANAGER_H
#include "SocketMgr.h"
#include "Session.h"
namespace Battlenet
{
- class SessionManager : public SocketMgr<Session>
+ class SessionManager : public Trinity::Net::SocketMgr<Session>
{
typedef SocketMgr<Session> BaseSocketMgr;
@@ -33,13 +33,10 @@ namespace Battlenet
bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override;
protected:
- NetworkThread<Session>* CreateThreads() const override;
-
- private:
- static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex);
+ Trinity::Net::NetworkThread<Session>* CreateThreads() const override;
};
}
#define sSessionMgr Battlenet::SessionManager::Instance()
-#endif // SessionManager_h__
+#endif // TRINITYCORE_SESSION_MANAGER_H