aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2025-04-08 19:15:16 +0200
committerShauren <shauren.trinity@gmail.com>2025-04-08 19:15:16 +0200
commite8b2be3527c7683e8bfca70ed7706fc20da566fd (patch)
tree54d5099554c8628cad719e6f1a49d387c7eced4f /src
parent40d80f3476ade4898be24659408e82aa4234b099 (diff)
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final * Removed derived class template argument * Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor * Make socket initialization easier composable (before entering Read loop) * Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
Diffstat (limited to 'src')
-rw-r--r--src/common/Utilities/Containers.h2
-rw-r--r--src/server/bnetserver/REST/LoginHttpSession.cpp169
-rw-r--r--src/server/bnetserver/REST/LoginHttpSession.h64
-rw-r--r--src/server/bnetserver/REST/LoginRESTService.cpp34
-rw-r--r--src/server/bnetserver/REST/LoginRESTService.h22
-rw-r--r--src/server/bnetserver/Server/Session.cpp103
-rw-r--r--src/server/bnetserver/Server/Session.h24
-rw-r--r--src/server/bnetserver/Server/SessionManager.cpp15
-rw-r--r--src/server/bnetserver/Server/SessionManager.h13
-rw-r--r--src/server/game/Scripting/ScriptMgr.cpp4
-rw-r--r--src/server/game/Scripting/ScriptMgr.h4
-rw-r--r--src/server/game/Server/WorldSocket.cpp240
-rw-r--r--src/server/game/Server/WorldSocket.h30
-rw-r--r--src/server/game/Server/WorldSocketMgr.cpp20
-rw-r--r--src/server/game/Server/WorldSocketMgr.h17
-rw-r--r--src/server/shared/Networking/AsyncAcceptor.h64
-rw-r--r--src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp42
-rw-r--r--src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h64
-rw-r--r--src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h51
-rw-r--r--src/server/shared/Networking/Http/BaseHttpSocket.cpp31
-rw-r--r--src/server/shared/Networking/Http/BaseHttpSocket.h136
-rw-r--r--src/server/shared/Networking/Http/HttpService.h2
-rw-r--r--src/server/shared/Networking/Http/HttpSocket.h43
-rw-r--r--src/server/shared/Networking/Http/HttpSslSocket.h69
-rw-r--r--src/server/shared/Networking/NetworkThread.h45
-rw-r--r--src/server/shared/Networking/Socket.h158
-rw-r--r--src/server/shared/Networking/SocketMgr.h48
-rw-r--r--src/server/shared/Networking/SslSocket.h88
-rw-r--r--src/server/shared/Networking/SslStream.h131
-rw-r--r--src/server/worldserver/Main.cpp22
-rw-r--r--src/server/worldserver/RemoteAccess/RASession.cpp4
-rw-r--r--src/server/worldserver/RemoteAccess/RASession.h9
32 files changed, 962 insertions, 806 deletions
diff --git a/src/common/Utilities/Containers.h b/src/common/Utilities/Containers.h
index 4a764629937..11928fa053d 100644
--- a/src/common/Utilities/Containers.h
+++ b/src/common/Utilities/Containers.h
@@ -259,7 +259,7 @@ namespace Trinity
if (!p(*rpos))
{
if (rpos != wpos)
- std::swap(*rpos, *wpos);
+ std::ranges::swap(*rpos, *wpos);
++wpos;
}
}
diff --git a/src/server/bnetserver/REST/LoginHttpSession.cpp b/src/server/bnetserver/REST/LoginHttpSession.cpp
index aff579de7f9..23a317d3726 100644
--- a/src/server/bnetserver/REST/LoginHttpSession.cpp
+++ b/src/server/bnetserver/REST/LoginHttpSession.cpp
@@ -17,126 +17,117 @@
#include "LoginHttpSession.h"
#include "DatabaseEnv.h"
+#include "HttpSocket.h"
+#include "HttpSslSocket.h"
+#include "IpBanCheckConnectionInitializer.h"
#include "LoginRESTService.h"
#include "SslContext.h"
#include "Util.h"
+#include <boost/container/static_vector.hpp>
-namespace Battlenet
-{
-template<template<typename> typename SocketImpl>
-LoginHttpSession<SocketImpl>::LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner)
- : BaseSocket(std::move(socket), SslContext::instance()), _owner(owner)
-{
-}
-
-template<template<typename> typename SocketImpl>
-LoginHttpSession<SocketImpl>::~LoginHttpSession() = default;
-
-template<template<typename> typename SocketImpl>
-void LoginHttpSession<SocketImpl>::Start()
-{
- std::string ip_address = this->GetRemoteIpAddress().to_string();
- TC_LOG_TRACE("server.http.session", "{} Accepted connection", this->GetClientInfo());
-
- // Verify that this IP is not in the ip_banned table
- LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS));
-
- LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
- stmt->setString(0, ip_address);
-
- this->_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
- .WithPreparedCallback([sess = this->shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
-}
-
-template<template<typename> typename SocketImpl>
-void LoginHttpSession<SocketImpl>::CheckIpCallback(PreparedQueryResult result)
-{
- if (result)
- {
- bool banned = false;
- do
- {
- Field* fields = result->Fetch();
- if (fields[0].GetUInt64() != 0)
- banned = true;
-
- } while (result->NextRow());
-
- if (banned)
- {
- TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", this->GetClientInfo());
- this->CloseSocket();
- return;
- }
- }
-
- if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket<LoginHttpSession<Trinity::Net::Http::SslSocket>>>)
- {
- this->AsyncHandshake();
- }
- else
- {
- this->ResetHttpParser();
- this->AsyncRead();
- }
-}
-
-template<template<typename> typename SocketImpl>
-Trinity::Net::Http::RequestHandlerResult LoginHttpSession<SocketImpl>::RequestHandler(Trinity::Net::Http::RequestContext& context)
+namespace
{
- return sLoginService.HandleRequest(_owner.shared_from_this(), context);
-}
-
std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress)
{
using namespace std::string_literals;
-
std::shared_ptr<Trinity::Net::Http::SessionState> state;
-
auto cookieItr = context.request.find(boost::beast::http::field::cookie);
if (cookieItr != context.request.end())
{
std::vector<std::string_view> cookies = Trinity::Tokenize(Trinity::Net::Http::ToStdStringView(cookieItr->value()), ';', false);
- std::size_t eq = 0;
- auto sessionIdItr = std::find_if(cookies.begin(), cookies.end(), [&](std::string_view cookie)
+ auto sessionIdItr = std::ranges::find_if(cookies, [](std::string_view const& cookie)
{
- std::string_view name = cookie;
- eq = cookie.find('=');
- if (eq != std::string_view::npos)
- name = cookie.substr(0, eq);
-
- return name == LoginHttpSessionWrapper::SESSION_ID_COOKIE;
+ return cookie.length() > Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length()
+ && cookie.starts_with(Battlenet::LoginHttpSession::SESSION_ID_COOKIE);
});
if (sessionIdItr != cookies.end())
{
- std::string_view value = sessionIdItr->substr(eq + 1);
- state = sLoginService.FindAndRefreshSessionState(value, remoteAddress);
+ sessionIdItr->remove_prefix(Battlenet::LoginHttpSession::SESSION_ID_COOKIE.length());
+ state = sLoginService.FindAndRefreshSessionState(*sessionIdItr, remoteAddress);
}
}
-
if (!state)
{
state = sLoginService.CreateNewSessionState(remoteAddress);
std::string_view host = Trinity::Net::Http::ToStdStringView(context.request[boost::beast::http::field::host]);
- if (std::size_t port = host.find(':'); port != std::string_view::npos)
+ if (size_t port = host.find(':'); port != std::string_view::npos)
host.remove_suffix(host.length() - port);
- context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}={}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None",
- LoginHttpSessionWrapper::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host));
+ context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}{}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None",
+ Battlenet::LoginHttpSession::SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host));
}
-
return state;
}
-template class LoginHttpSession<Trinity::Net::Http::SslSocket>;
-template class LoginHttpSession<Trinity::Net::Http::Socket>;
+template<typename SocketImpl>
+class LoginHttpSocketImpl final : public SocketImpl
+{
+public:
+ using BaseSocket = SocketImpl;
+
+ explicit LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner)
+ : BaseSocket(std::move(socket)), _owner(owner)
+ {
+ }
+
+ LoginHttpSocketImpl(LoginHttpSocketImpl const&) = delete;
+ LoginHttpSocketImpl(LoginHttpSocketImpl&&) = delete;
+ LoginHttpSocketImpl& operator=(LoginHttpSocketImpl const&) = delete;
+ LoginHttpSocketImpl& operator=(LoginHttpSocketImpl&&) = delete;
+
+ ~LoginHttpSocketImpl() = default;
+
+ void Start() override
+ {
+ // build initializer chain
+ boost::container::static_vector<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 4> initializers;
+
+ initializers.stable_emplace_back(std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<BaseSocket>>(this));
+
+ if constexpr (std::is_same_v<BaseSocket, Trinity::Net::Http::SslSocket>)
+ initializers.stable_emplace_back(std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<BaseSocket>>(this));
+
+ initializers.stable_emplace_back(std::make_shared<Trinity::Net::Http::HttpConnectionInitializer<BaseSocket>>(this));
+ initializers.stable_emplace_back(std::make_shared<Trinity::Net::ReadConnectionInitializer<BaseSocket>>(this));
+
+ Trinity::Net::SocketConnectionInitializer::SetupChain(std::span(initializers.data(), initializers.size()))->Start();
+ }
+
+ Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override
+ {
+ return sLoginService.HandleRequest(_owner.shared_from_this(), context);
+ }
+
+protected:
+ std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override
+ {
+ return ::ObtainSessionState(context, this->GetRemoteIpAddress());
+ }
+
+ Battlenet::LoginHttpSession& _owner;
+};
+
+template<>
+LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>::LoginHttpSocketImpl(Trinity::Net::IoContextTcpSocket&& socket, Battlenet::LoginHttpSession& owner)
+ : BaseSocket(std::move(socket), Battlenet::SslContext::instance()), _owner(owner)
+{
+}
+}
+
+namespace Battlenet
+{
+LoginHttpSession::LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket)
+ : _socket(!SslContext::UsesDevWildcardCertificate()
+ ? std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::SslSocket>>(std::move(socket), *this))
+ : std::shared_ptr<AbstractSocket>(std::make_shared<LoginHttpSocketImpl<Trinity::Net::Http::Socket>>(std::move(socket), *this)))
+{
+}
-LoginHttpSessionWrapper::LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket)
+void LoginHttpSession::Start()
{
- if (!SslContext::UsesDevWildcardCertificate())
- _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::SslSocket>>(std::move(socket), *this);
- else
- _socket = std::make_shared<LoginHttpSession<Trinity::Net::Http::Socket>>(std::move(socket), *this);
+ TC_LOG_TRACE("server.http.session", "{} Accepted connection", GetClientInfo());
+
+ return _socket->Start();
}
}
diff --git a/src/server/bnetserver/REST/LoginHttpSession.h b/src/server/bnetserver/REST/LoginHttpSession.h
index 6bd1ec113d0..c15442f9e0c 100644
--- a/src/server/bnetserver/REST/LoginHttpSession.h
+++ b/src/server/bnetserver/REST/LoginHttpSession.h
@@ -18,10 +18,8 @@
#ifndef TRINITYCORE_LOGIN_HTTP_SESSION_H
#define TRINITYCORE_LOGIN_HTTP_SESSION_H
-#include "HttpSocket.h"
-#include "HttpSslSocket.h"
+#include "BaseHttpSocket.h"
#include "SRP6.h"
-#include <variant>
namespace Battlenet
{
@@ -30,59 +28,27 @@ struct LoginSessionState : public Trinity::Net::Http::SessionState
std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> Srp;
};
-class LoginHttpSessionWrapper;
-std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context, boost::asio::ip::address const& remoteAddress);
-
-template<template<typename> typename SocketImpl>
-class LoginHttpSession : public SocketImpl<LoginHttpSession<SocketImpl>>
+class LoginHttpSession : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSession>
{
- using BaseSocket = SocketImpl<LoginHttpSession<SocketImpl>>;
-
public:
- explicit LoginHttpSession(boost::asio::ip::tcp::socket&& socket, LoginHttpSessionWrapper& owner);
- ~LoginHttpSession();
-
- void Start() override;
+ static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID=";
- void CheckIpCallback(PreparedQueryResult result);
+ explicit LoginHttpSession(Trinity::Net::IoContextTcpSocket&& socket);
- Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override;
-
- LoginSessionState* GetSessionState() const { return static_cast<LoginSessionState*>(this->_state.get()); }
-
-protected:
- std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override
- {
- return Battlenet::ObtainSessionState(context, this->GetRemoteIpAddress());
- }
-
- LoginHttpSessionWrapper& _owner;
-};
-
-class LoginHttpSessionWrapper : public Trinity::Net::Http::AbstractSocket, public std::enable_shared_from_this<LoginHttpSessionWrapper>
-{
-public:
- static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID";
-
- explicit LoginHttpSessionWrapper(boost::asio::ip::tcp::socket&& socket);
-
- void Start() { return std::visit([&](auto&& socket) { return socket->Start(); }, _socket); }
- bool Update() { return std::visit([&](auto&& socket) { return socket->Update(); }, _socket); }
- boost::asio::ip::address GetRemoteIpAddress() const { return std::visit([&](auto&& socket) { return socket->GetRemoteIpAddress(); }, _socket); }
- bool IsOpen() const { return std::visit([&](auto&& socket) { return socket->IsOpen(); }, _socket); }
- void CloseSocket() { return std::visit([&](auto&& socket) { return socket->CloseSocket(); }, _socket); }
+ void Start() override;
+ bool Update() override { return _socket->Update(); }
+ boost::asio::ip::address const& GetRemoteIpAddress() const override { return _socket->GetRemoteIpAddress(); }
+ bool IsOpen() const override { return _socket->IsOpen(); }
+ void CloseSocket() override { return _socket->CloseSocket(); }
- void SendResponse(Trinity::Net::Http::RequestContext& context) override { return std::visit([&](auto&& socket) { return socket->SendResponse(context); }, _socket); }
- void QueueQuery(QueryCallback&& queryCallback) override { return std::visit([&](auto&& socket) { return socket->QueueQuery(std::move(queryCallback)); }, _socket); }
- std::string GetClientInfo() const override { return std::visit([&](auto&& socket) { return socket->GetClientInfo(); }, _socket); }
- Optional<boost::uuids::uuid> GetSessionId() const override { return std::visit([&](auto&& socket) { return socket->GetSessionId(); }, _socket); }
- LoginSessionState* GetSessionState() const { return std::visit([&](auto&& socket) { return socket->GetSessionState(); }, _socket); }
+ void SendResponse(Trinity::Net::Http::RequestContext& context) override { return _socket->SendResponse(context); }
+ void QueueQuery(QueryCallback&& queryCallback) override { return _socket->QueueQuery(std::move(queryCallback)); }
+ std::string GetClientInfo() const override { return _socket->GetClientInfo(); }
+ LoginSessionState* GetSessionState() const override { return static_cast<LoginSessionState*>(_socket->GetSessionState()); }
private:
- std::variant<
- std::shared_ptr<LoginHttpSession<Trinity::Net::Http::SslSocket>>,
- std::shared_ptr<LoginHttpSession<Trinity::Net::Http::Socket>>
- > _socket;
+ std::shared_ptr<Trinity::Net::Http::AbstractSocket> _socket;
};
}
+
#endif // TRINITYCORE_LOGIN_HTTP_SESSION_H
diff --git a/src/server/bnetserver/REST/LoginRESTService.cpp b/src/server/bnetserver/REST/LoginRESTService.cpp
index 12f6cf63712..1afcc5c88bb 100644
--- a/src/server/bnetserver/REST/LoginRESTService.cpp
+++ b/src/server/bnetserver/REST/LoginRESTService.cpp
@@ -44,32 +44,32 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
using Trinity::Net::Http::RequestHandlerFlag;
- RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+ RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
return HandleGetForm(std::move(session), context);
});
- RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+ RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
return HandleGetGameAccounts(std::move(session), context);
});
- RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+ RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
return HandleGetPortal(std::move(session), context);
});
- RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+ RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
return HandlePostLogin(std::move(session), context);
}, RequestHandlerFlag::DoNotLogRequestContent);
- RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+ RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/srp/", [](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
return HandlePostLoginSrpChallenge(std::move(session), context);
});
- RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+ RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
return HandlePostRefreshLoginTicket(std::move(session), context);
});
@@ -121,7 +121,10 @@ bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::st
MigrateLegacyPasswordHashes();
- _acceptor->AsyncAcceptWithCallback<&LoginRESTService::OnSocketAccept>();
+ _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
+ {
+ OnSocketOpen(std::move(sock), threadIndex);
+ });
return true;
}
@@ -165,7 +168,7 @@ std::string LoginRESTService::ExtractAuthorization(HttpRequest const& request)
return ticket;
}
-LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
+LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
{
JSON::Login::FormInputs form = _formInputs;
form.set_srp_url(Trinity::StringFormat("http{}://{}:{}/bnetserver/login/srp/", !SslContext::UsesDevWildcardCertificate() ? "s" : "",
@@ -176,7 +179,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shar
return RequestHandlerResult::Handled;
}
-LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
std::string ticket = ExtractAuthorization(context.request);
if (ticket.empty())
@@ -225,14 +228,14 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(s
return RequestHandlerResult::Async;
}
-LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
+LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
{
context.response.set(boost::beast::http::field::content_type, "text/plain");
context.response.body() = Trinity::StringFormat("{}:{}", GetHostnameForClient(session->GetRemoteIpAddress()), sConfigMgr->GetIntDefault("BattlenetPort", 1119));
return RequestHandlerResult::Handled;
}
-LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
+LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
{
std::shared_ptr<JSON::Login::LoginForm> loginForm = std::make_shared<JSON::Login::LoginForm>();
if (!::JSON::Deserialize(context.request.body(), loginForm.get()))
@@ -396,7 +399,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::sh
return RequestHandlerResult::Async;
}
-LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context)
+LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context)
{
JSON::Login::LoginForm loginForm;
if (!::JSON::Deserialize(context.request.body(), &loginForm))
@@ -483,7 +486,7 @@ LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLoginSrpChall
return RequestHandlerResult::Async;
}
-LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const
+LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const
{
std::string ticket = ExtractAuthorization(context.request);
if (ticket.empty())
@@ -551,11 +554,6 @@ std::shared_ptr<Trinity::Net::Http::SessionState> LoginRESTService::CreateNewSes
return state;
}
-void LoginRESTService::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
-{
- sLoginService.OnSocketOpen(std::move(sock), threadIndex);
-}
-
void LoginRESTService::MigrateLegacyPasswordHashes() const
{
if (!LoginDatabase.Query("SELECT 1 FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = SCHEMA() AND TABLE_NAME = 'battlenet_accounts' AND COLUMN_NAME = 'sha_pass_hash'"))
diff --git a/src/server/bnetserver/REST/LoginRESTService.h b/src/server/bnetserver/REST/LoginRESTService.h
index 079bb6fca9c..517f5e295b7 100644
--- a/src/server/bnetserver/REST/LoginRESTService.h
+++ b/src/server/bnetserver/REST/LoginRESTService.h
@@ -15,8 +15,8 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef LoginRESTService_h__
-#define LoginRESTService_h__
+#ifndef TRINITYCORE_LOGIN_REST_SERVICE_H
+#define TRINITYCORE_LOGIN_REST_SERVICE_H
#include "HttpService.h"
#include "Login.pb.h"
@@ -42,7 +42,7 @@ enum class BanMode
BAN_ACCOUNT = 1
};
-class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSessionWrapper>
+class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSession>
{
public:
using RequestHandlerResult = Trinity::Net::Http::RequestHandlerResult;
@@ -63,17 +63,15 @@ public:
std::shared_ptr<Trinity::Net::Http::SessionState> CreateNewSessionState(boost::asio::ip::address const& address) override;
private:
- static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex);
-
static std::string ExtractAuthorization(HttpRequest const& request);
- RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
- static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context);
- RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
+ RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
+ static RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context);
+ RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
- RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
- static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context);
- RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSessionWrapper> session, HttpRequestContext& context) const;
+ RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
+ static RequestHandlerResult HandlePostLoginSrpChallenge(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context);
+ RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) const;
static std::unique_ptr<Trinity::Crypto::SRP::BnetSRP6Base> CreateSrpImplementation(SrpVersion version, SrpHashFunction hashFunction,
std::string const& username, Trinity::Crypto::SRP::Salt const& salt, Trinity::Crypto::SRP::Verifier const& verifier);
@@ -90,4 +88,4 @@ private:
#define sLoginService Battlenet::LoginRESTService::Instance()
-#endif // LoginRESTService_h__
+#endif // TRINITYCORE_LOGIN_REST_SERVICE_H
diff --git a/src/server/bnetserver/Server/Session.cpp b/src/server/bnetserver/Server/Session.cpp
index c3a8ae6f211..7b9d6f45a3e 100644
--- a/src/server/bnetserver/Server/Session.cpp
+++ b/src/server/bnetserver/Server/Session.cpp
@@ -23,6 +23,7 @@
#include "Errors.h"
#include "Hash.h"
#include "IPLocation.h"
+#include "IpBanCheckConnectionInitializer.h"
#include "LoginRESTService.h"
#include "MapUtils.h"
#include "ProtobufJSON.h"
@@ -74,7 +75,7 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields)
DisplayName = Name;
}
-Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()),
+Battlenet::Session::Session(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket), SslContext::instance()),
_accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(),
_os(), _build(0), _clientInfo(), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0)
{
@@ -83,54 +84,24 @@ Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSo
Battlenet::Session::~Session() = default;
-void Battlenet::Session::AsyncHandshake()
-{
- underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
- [sess = shared_from_this()](boost::system::error_code const& error) { sess->HandshakeHandler(error); });
-}
-
void Battlenet::Session::Start()
{
- std::string ip_address = GetRemoteIpAddress().to_string();
TC_LOG_TRACE("session", "{} Accepted connection", GetClientInfo());
- // Verify that this IP is not in the ip_banned table
- LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS));
+ // build initializer chain
+ std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
+ { {
+ std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<Session>>(this),
+ std::make_shared<Trinity::Net::SslHandshakeConnectionInitializer<Session>>(this),
+ std::make_shared<Trinity::Net::ReadConnectionInitializer<Session>>(this),
+ } };
- LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
- stmt->setString(0, ip_address);
-
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
- .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); }));
-}
-
-void Battlenet::Session::CheckIpCallback(PreparedQueryResult result)
-{
- if (result)
- {
- bool banned = false;
- do
- {
- Field* fields = result->Fetch();
- if (fields[0].GetUInt64() != 0)
- banned = true;
-
- } while (result->NextRow());
-
- if (banned)
- {
- TC_LOG_DEBUG("session", "{} tries to log in using banned IP!", GetClientInfo());
- CloseSocket();
- return;
- }
- }
-
- AsyncHandshake();
+ Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
}
bool Battlenet::Session::Update()
{
- if (!BattlenetSocket::Update())
+ if (!BaseSocket::Update())
return false;
_queryProcessor.ProcessReadyCallbacks();
@@ -211,6 +182,11 @@ void Battlenet::Session::SendRequest(uint32 serviceHash, uint32 methodId, pb::Me
AsyncWrite(&packet);
}
+void Battlenet::Session::QueueQuery(QueryCallback&& queryCallback)
+{
+ _queryProcessor.AddCallback(std::move(queryCallback));
+}
+
uint32 Battlenet::Session::HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation)
{
if (logonRequest->program() != "WoW")
@@ -292,7 +268,7 @@ uint32 Battlenet::Session::HandleGenerateWebCredentials(authentication::v1::Gene
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION_BY_ID);
stmt->setUInt32(0, _accountInfo->Id);
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
+ QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, asyncContinuation = std::move(continuation)](PreparedQueryResult result)
{
// just send existing credentials back (not the best but it works for now with them being stored in db)
Battlenet::Services::Authentication asyncContinuationService(this);
@@ -314,7 +290,7 @@ uint32 Battlenet::Session::VerifyWebCredentials(std::string const& webCredential
std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)> asyncContinuation = std::move(continuation);
std::shared_ptr<AccountInfo> accountInfo = std::make_shared<AccountInfo>();
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
+ QueueQuery(LoginDatabase.AsyncQuery(stmt).WithChainingPreparedCallback([this, accountInfo, asyncContinuation](QueryCallback& callback, PreparedQueryResult result)
{
Battlenet::Services::Authentication asyncContinuationService(this);
NoData response;
@@ -713,20 +689,8 @@ uint32 Battlenet::Session::HandleGetAllValuesForAttribute(game_utilities::v1::Ge
return ERROR_RPC_NOT_IMPLEMENTED;
}
-void Battlenet::Session::HandshakeHandler(boost::system::error_code const& error)
-{
- if (error)
- {
- TC_LOG_ERROR("session", "{} SSL Handshake failed {}", GetClientInfo(), error.message());
- CloseSocket();
- return;
- }
-
- AsyncRead();
-}
-
template<bool(Battlenet::Session::*processMethod)(), MessageBuffer Battlenet::Session::*outputBuffer>
-inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
+static inline Optional<Trinity::Net::SocketReadCallbackResult> PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inputBuffer)
{
MessageBuffer& buffer = session->*outputBuffer;
@@ -738,46 +702,45 @@ inline bool PartialProcessPacket(Battlenet::Session* session, MessageBuffer& inp
buffer.Write(inputBuffer.GetReadPointer(), readDataSize);
inputBuffer.ReadCompleted(readDataSize);
}
+ else
+ return { }; // go to next buffer
if (buffer.GetRemainingSpace() > 0)
{
// Couldn't receive the whole data this time.
ASSERT(inputBuffer.GetActiveSize() == 0);
- return false;
+ return Trinity::Net::SocketReadCallbackResult::KeepReading;
}
// just received fresh new payload
if (!(session->*processMethod)())
{
session->CloseSocket();
- return false;
+ return Trinity::Net::SocketReadCallbackResult::Stop;
}
- return true;
+ return { }; // go to next buffer
}
-void Battlenet::Session::ReadHandler()
+Trinity::Net::SocketReadCallbackResult Battlenet::Session::ReadHandler()
{
- if (!IsOpen())
- return;
-
MessageBuffer& packet = GetReadBuffer();
while (packet.GetActiveSize() > 0)
{
- if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderLengthHandler, &Battlenet::Session::_headerLengthBuffer>(this, packet))
- break;
+ if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderLengthHandler, &Session::_headerLengthBuffer>(this, packet))
+ return *partialResult;
- if (!PartialProcessPacket<&Battlenet::Session::ReadHeaderHandler, &Battlenet::Session::_headerBuffer>(this, packet))
- break;
+ if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadHeaderHandler, &Session::_headerBuffer>(this, packet))
+ return *partialResult;
- if (!PartialProcessPacket<&Battlenet::Session::ReadDataHandler, &Battlenet::Session::_packetBuffer>(this, packet))
- break;
+ if (Optional<Trinity::Net::SocketReadCallbackResult> partialResult = PartialProcessPacket<&Session::ReadDataHandler, &Session::_packetBuffer>(this, packet))
+ return *partialResult;
_headerLengthBuffer.Reset();
_headerBuffer.Reset();
}
- AsyncRead();
+ return Trinity::Net::SocketReadCallbackResult::KeepReading;
}
bool Battlenet::Session::ReadHeaderLengthHandler()
@@ -835,5 +798,5 @@ std::string Battlenet::Session::GetClientInfo() const
stream << ']';
- return stream.str();
+ return std::move(stream).str();
}
diff --git a/src/server/bnetserver/Server/Session.h b/src/server/bnetserver/Server/Session.h
index 8a8b1f0a15b..faf82115f18 100644
--- a/src/server/bnetserver/Server/Session.h
+++ b/src/server/bnetserver/Server/Session.h
@@ -15,8 +15,8 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef Session_h__
-#define Session_h__
+#ifndef TRINITYCORE_SESSION_H
+#define TRINITYCORE_SESSION_H
#include "AsyncCallbackProcessor.h"
#include "ClientBuildInfo.h"
@@ -24,7 +24,7 @@
#include "QueryResult.h"
#include "Realm.h"
#include "Socket.h"
-#include "SslSocket.h"
+#include "SslStream.h"
#include <boost/asio/ip/tcp.hpp>
#include <google/protobuf/message.h>
#include <memory>
@@ -65,9 +65,9 @@ using namespace bgs::protocol;
namespace Battlenet
{
- class Session : public Socket<Session, SslSocket<>>
+ class Session final : public Trinity::Net::Socket<Trinity::Net::SslStream<>>
{
- typedef Socket<Session, SslSocket<>> BattlenetSocket;
+ using BaseSocket = Socket<Trinity::Net::SslStream<>>;
public:
struct LastPlayedCharacterInfo
@@ -110,7 +110,7 @@ namespace Battlenet
std::unordered_map<uint32, GameAccountInfo> GameAccounts;
};
- explicit Session(boost::asio::ip::tcp::socket&& socket);
+ explicit Session(Trinity::Net::IoContextTcpSocket&& socket);
~Session();
void Start() override;
@@ -130,6 +130,8 @@ namespace Battlenet
void SendRequest(uint32 serviceHash, uint32 methodId, pb::Message const* request);
+ void QueueQuery(QueryCallback&& queryCallback);
+
uint32 HandleLogon(authentication::v1::LogonRequest const* logonRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
uint32 HandleVerifyWebCredentials(authentication::v1::VerifyWebCredentialsRequest const* verifyWebCredentialsRequest, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
uint32 HandleGenerateWebCredentials(authentication::v1::GenerateWebCredentialsRequest const* request, std::function<void(ServiceBase*, uint32, google::protobuf::Message const*)>& continuation);
@@ -140,9 +142,9 @@ namespace Battlenet
std::string GetClientInfo() const;
+ Trinity::Net::SocketReadCallbackResult ReadHandler() override;
+
protected:
- void HandshakeHandler(boost::system::error_code const& error);
- void ReadHandler() override;
bool ReadHeaderLengthHandler();
bool ReadHeaderHandler();
bool ReadDataHandler();
@@ -150,10 +152,6 @@ namespace Battlenet
private:
void AsyncWrite(MessageBuffer* packet);
- void AsyncHandshake();
-
- void CheckIpCallback(PreparedQueryResult result);
-
uint32 VerifyWebCredentials(std::string const& webCredentials, std::function<void(ServiceBase*, uint32, ::google::protobuf::Message const*)>& continuation);
typedef uint32(Session::*ClientRequestHandler)(std::unordered_map<std::string, Variant const*> const&, game_utilities::v1::ClientResponse*);
@@ -190,4 +188,4 @@ namespace Battlenet
};
}
-#endif // Session_h__
+#endif // TRINITYCORE_SESSION_H
diff --git a/src/server/bnetserver/Server/SessionManager.cpp b/src/server/bnetserver/Server/SessionManager.cpp
index cb56972a9c2..4c5b532ee60 100644
--- a/src/server/bnetserver/Server/SessionManager.cpp
+++ b/src/server/bnetserver/Server/SessionManager.cpp
@@ -16,7 +16,6 @@
*/
#include "SessionManager.h"
-#include "DatabaseEnv.h"
#include "Util.h"
bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
@@ -24,18 +23,16 @@ bool Battlenet::SessionManager::StartNetwork(Trinity::Asio::IoContext& ioContext
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
- _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
+ _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
+ {
+ OnSocketOpen(std::move(sock), threadIndex);
+ });
return true;
}
-NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const
+Trinity::Net::NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() const
{
- return new NetworkThread<Session>[GetNetworkThreadCount()];
-}
-
-void Battlenet::SessionManager::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
-{
- sSessionMgr.OnSocketOpen(std::move(sock), threadIndex);
+ return new Trinity::Net::NetworkThread<Session>[GetNetworkThreadCount()];
}
Battlenet::SessionManager& Battlenet::SessionManager::Instance()
diff --git a/src/server/bnetserver/Server/SessionManager.h b/src/server/bnetserver/Server/SessionManager.h
index c635122c977..528ece8739e 100644
--- a/src/server/bnetserver/Server/SessionManager.h
+++ b/src/server/bnetserver/Server/SessionManager.h
@@ -15,15 +15,15 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef SessionManager_h__
-#define SessionManager_h__
+#ifndef TRINITYCORE_SESSION_MANAGER_H
+#define TRINITYCORE_SESSION_MANAGER_H
#include "SocketMgr.h"
#include "Session.h"
namespace Battlenet
{
- class SessionManager : public SocketMgr<Session>
+ class SessionManager : public Trinity::Net::SocketMgr<Session>
{
typedef SocketMgr<Session> BaseSocketMgr;
@@ -33,13 +33,10 @@ namespace Battlenet
bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount = 1) override;
protected:
- NetworkThread<Session>* CreateThreads() const override;
-
- private:
- static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex);
+ Trinity::Net::NetworkThread<Session>* CreateThreads() const override;
};
}
#define sSessionMgr Battlenet::SessionManager::Instance()
-#endif // SessionManager_h__
+#endif // TRINITYCORE_SESSION_MANAGER_H
diff --git a/src/server/game/Scripting/ScriptMgr.cpp b/src/server/game/Scripting/ScriptMgr.cpp
index 99f25980a0d..ac1cb69b6a1 100644
--- a/src/server/game/Scripting/ScriptMgr.cpp
+++ b/src/server/game/Scripting/ScriptMgr.cpp
@@ -1474,14 +1474,14 @@ void ScriptMgr::OnNetworkStop()
FOREACH_SCRIPT(ServerScript)->OnNetworkStop();
}
-void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> socket)
+void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> const& socket)
{
ASSERT(socket);
FOREACH_SCRIPT(ServerScript)->OnSocketOpen(socket);
}
-void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> socket)
+void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> const& socket)
{
ASSERT(socket);
diff --git a/src/server/game/Scripting/ScriptMgr.h b/src/server/game/Scripting/ScriptMgr.h
index d7746b01e04..e3b7dc63981 100644
--- a/src/server/game/Scripting/ScriptMgr.h
+++ b/src/server/game/Scripting/ScriptMgr.h
@@ -1067,8 +1067,8 @@ class TC_GAME_API ScriptMgr
void OnNetworkStart();
void OnNetworkStop();
- void OnSocketOpen(std::shared_ptr<WorldSocket> socket);
- void OnSocketClose(std::shared_ptr<WorldSocket> socket);
+ void OnSocketOpen(std::shared_ptr<WorldSocket> const& socket);
+ void OnSocketClose(std::shared_ptr<WorldSocket> const& socket);
void OnPacketReceive(WorldSession* session, WorldPacket const& packet);
void OnPacketSend(WorldSession* session, WorldPacket const& packet);
diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp
index 4b2ad6c3c5c..2b868ff3884 100644
--- a/src/server/game/Server/WorldSocket.cpp
+++ b/src/server/game/Server/WorldSocket.cpp
@@ -26,6 +26,7 @@
#include "GameTime.h"
#include "HMAC.h"
#include "IPLocation.h"
+#include "IpBanCheckConnectionInitializer.h"
#include "PacketLog.h"
#include "ProtobufJSON.h"
#include "RealmList.h"
@@ -33,6 +34,7 @@
#include "RealmList.pb.h"
#include "ScriptMgr.h"
#include "SessionKeyGenerator.h"
+#include "SslStream.h"
#include "World.h"
#include "WorldPacket.h"
#include "WorldSession.h"
@@ -49,8 +51,6 @@ struct CompressedWorldPacket
#pragma pack(pop)
-std::string const WorldSocket::ServerConnectionInitialize("WORLD OF WARCRAFT CONNECTION - SERVER TO CLIENT - V2");
-std::string const WorldSocket::ClientConnectionInitialize("WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER - V2");
uint32 const WorldSocket::MinSizeForCompression = 0x400;
std::array<uint8, 32> const WorldSocket::AuthCheckSeed = { 0xDE, 0x3A, 0x2A, 0x8E, 0x6B, 0x89, 0x52, 0x66, 0x88, 0x9D, 0x7E, 0x7A, 0x77, 0x1D, 0x5D, 0x1F,
@@ -62,7 +62,7 @@ std::array<uint8, 32> const WorldSocket::ContinuedSessionSeed = { 0x56, 0x5C, 0x
std::array<uint8, 32> const WorldSocket::EncryptionKeySeed = { 0x71, 0xC9, 0xED, 0x5A, 0xA7, 0x0E, 0x4D, 0xFF, 0x4C, 0x36, 0xA6, 0x5A, 0x3E, 0x46, 0x8A, 0x4A,
0x5D, 0xA1, 0x48, 0xC8, 0x30, 0x47, 0x4A, 0xDE, 0xF6, 0x0D, 0x6C, 0xBE, 0x6F, 0xE4, 0x55, 0x73 };
-WorldSocket::WorldSocket(boost::asio::ip::tcp::socket&& socket) : Socket(std::move(socket)),
+WorldSocket::WorldSocket(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket)),
_type(CONNECTION_TYPE_REALM), _key(0), _serverChallenge(), _sessionKey(), _encryptKey(), _OverSpeedPings(0),
_worldSession(nullptr), _authed(false), _canRequestHotfixes(true), _headerBuffer(sizeof(IncomingPacketHeader)), _sendBufferSize(4096), _compressionStream(nullptr)
{
@@ -77,127 +77,127 @@ WorldSocket::~WorldSocket()
}
}
-void WorldSocket::Start()
+struct WorldSocketProtocolInitializer final : Trinity::Net::SocketConnectionInitializer
{
- std::string ip_address = GetRemoteIpAddress().to_string();
- LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
- stmt->setString(0, ip_address);
+ static constexpr std::string_view ServerConnectionInitialize = "WORLD OF WARCRAFT CONNECTION - SERVER TO CLIENT - V2\n";
+ static constexpr std::string_view ClientConnectionInitialize = "WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER - V2\n";
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([self = shared_from_this()](PreparedQueryResult result)
- {
- self->CheckIpCallback(std::move(result));
- }));
-}
+ explicit WorldSocketProtocolInitializer(WorldSocket* socket) : _socket(socket) { }
-void WorldSocket::CheckIpCallback(PreparedQueryResult result)
-{
- if (result)
+ void Start() override
{
- bool banned = false;
- do
- {
- Field* fields = result->Fetch();
- if (fields[0].GetUInt64() != 0)
- banned = true;
+ _packetBuffer.Resize(ClientConnectionInitialize.length());
- } while (result->NextRow());
+ AsyncRead();
- if (banned)
- {
- TC_LOG_ERROR("network", "WorldSocket::CheckIpCallback: Sent Auth Response (IP {} banned).", GetRemoteIpAddress().to_string());
- DelayedCloseSocket();
- return;
- }
+ MessageBuffer initializer;
+ initializer.Write(ServerConnectionInitialize.data(), ServerConnectionInitialize.length());
+
+ // - IoContext.run thread, safe.
+ _socket->QueuePacket(std::move(initializer));
}
- _packetBuffer.Resize(ClientConnectionInitialize.length() + 1);
+ void AsyncRead()
+ {
+ _socket->AsyncRead(
+ [socketRef = _socket->weak_from_this(), self = static_pointer_cast<WorldSocketProtocolInitializer>(this->shared_from_this())]
+ {
+ if (!socketRef.expired())
+ return self->ReadHandler();
+
+ return Trinity::Net::SocketReadCallbackResult::Stop;
+ });
+ }
- AsyncReadWithCallback(&WorldSocket::InitializeHandler);
+ Trinity::Net::SocketReadCallbackResult ReadHandler();
- MessageBuffer initializer;
- initializer.Write(ServerConnectionInitialize.c_str(), ServerConnectionInitialize.length());
- initializer.Write("\n", 1);
+ void HandleDataReady();
- // - IoContext.run thread, safe.
- QueuePacket(std::move(initializer));
+private:
+ WorldSocket* _socket;
+ MessageBuffer _packetBuffer;
+};
+
+void WorldSocket::Start()
+{
+ // build initializer chain
+ std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
+ { {
+ std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<WorldSocket>>(this),
+ std::make_shared<WorldSocketProtocolInitializer>(this),
+ std::make_shared<Trinity::Net::ReadConnectionInitializer<WorldSocket>>(this),
+ } };
+
+ Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
}
-void WorldSocket::InitializeHandler(boost::system::error_code const& error, std::size_t transferedBytes)
+Trinity::Net::SocketReadCallbackResult WorldSocketProtocolInitializer::ReadHandler()
{
- if (error)
+ MessageBuffer& packet = _socket->GetReadBuffer();
+ if (packet.GetActiveSize() > 0 && _packetBuffer.GetRemainingSpace() > 0)
{
- CloseSocket();
- return;
+ // need to receive the header
+ std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace());
+ _packetBuffer.Write(packet.GetReadPointer(), readHeaderSize);
+ packet.ReadCompleted(readHeaderSize);
+
+ if (_packetBuffer.GetRemainingSpace() == 0)
+ {
+ HandleDataReady();
+ return Trinity::Net::SocketReadCallbackResult::Stop;
+ }
+
+ // Couldn't receive the whole header this time.
+ ASSERT(packet.GetActiveSize() == 0);
}
- GetReadBuffer().WriteCompleted(transferedBytes);
+ return Trinity::Net::SocketReadCallbackResult::KeepReading;
+}
- MessageBuffer& packet = GetReadBuffer();
- if (packet.GetActiveSize() > 0)
+void WorldSocketProtocolInitializer::HandleDataReady()
+{
+ try
{
- if (_packetBuffer.GetRemainingSpace() > 0)
+ ByteBuffer buffer(std::move(_packetBuffer));
+ if (buffer.ReadString(ClientConnectionInitialize.length()) != ClientConnectionInitialize)
{
- // need to receive the header
- std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace());
- _packetBuffer.Write(packet.GetReadPointer(), readHeaderSize);
- packet.ReadCompleted(readHeaderSize);
-
- if (_packetBuffer.GetRemainingSpace() > 0)
- {
- // Couldn't receive the whole header this time.
- ASSERT(packet.GetActiveSize() == 0);
- AsyncReadWithCallback(&WorldSocket::InitializeHandler);
- return;
- }
-
- try
- {
- ByteBuffer buffer(std::move(_packetBuffer));
- std::string initializer(buffer.ReadString(ClientConnectionInitialize.length()));
- if (initializer != ClientConnectionInitialize)
- {
- CloseSocket();
- return;
- }
+ _socket->CloseSocket();
+ return;
+ }
+ }
+ catch (ByteBufferException const& ex)
+ {
+ TC_LOG_ERROR("network", "WorldSocket::InitializeHandler ByteBufferException {} occured while parsing initial packet from {}",
+ ex.what(), _socket->GetRemoteIpAddress().to_string());
+ _socket->CloseSocket();
+ return;
+ }
- uint8 terminator;
- buffer >> terminator;
- if (terminator != '\n')
- {
- CloseSocket();
- return;
- }
- }
- catch (ByteBufferException const& ex)
- {
- TC_LOG_ERROR("network", "WorldSocket::InitializeHandler ByteBufferException {} occured while parsing initial packet from {}",
- ex.what(), GetRemoteIpAddress().to_string());
- CloseSocket();
- return;
- }
+ if (!_socket->InitializeCompression())
+ return;
- _compressionStream = new z_stream();
- _compressionStream->zalloc = (alloc_func)nullptr;
- _compressionStream->zfree = (free_func)nullptr;
- _compressionStream->opaque = (voidpf)nullptr;
- _compressionStream->avail_in = 0;
- _compressionStream->next_in = nullptr;
- int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
- if (z_res != Z_OK)
- {
- CloseSocket();
- TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: {} ({})", z_res, zError(z_res));
- return;
- }
+ _socket->SendAuthSession();
+ if (next)
+ next->Start();
+}
- _packetBuffer.Reset();
- HandleSendAuthSession();
- AsyncRead();
- return;
- }
+bool WorldSocket::InitializeCompression()
+{
+ _compressionStream = new z_stream();
+ _compressionStream->zalloc = (alloc_func)nullptr;
+ _compressionStream->zfree = (free_func)nullptr;
+ _compressionStream->opaque = (voidpf)nullptr;
+ _compressionStream->avail_in = 0;
+ _compressionStream->next_in = nullptr;
+ int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
+ if (z_res != Z_OK)
+ {
+ CloseSocket();
+ TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: {} ({})", z_res, zError(z_res));
+ return false;
}
- AsyncReadWithCallback(&WorldSocket::InitializeHandler);
+ return true;
}
bool WorldSocket::Update()
@@ -206,7 +206,7 @@ bool WorldSocket::Update()
MessageBuffer buffer(_sendBufferSize);
while (_bufferQueue.Dequeue(queued))
{
- uint32 packetSize = queued->size() + 2 /*opcode*/;
+ uint32 packetSize = queued->size() + 4 /*opcode*/;
if (packetSize > MinSizeForCompression && queued->NeedsEncryption())
packetSize = deflateBound(_compressionStream, packetSize) + sizeof(CompressedWorldPacket);
@@ -240,7 +240,7 @@ bool WorldSocket::Update()
return true;
}
-void WorldSocket::HandleSendAuthSession()
+void WorldSocket::SendAuthSession()
{
Trinity::Crypto::GetRandomBytes(_serverChallenge);
@@ -260,11 +260,8 @@ void WorldSocket::OnClose()
}
}
-void WorldSocket::ReadHandler()
+Trinity::Net::SocketReadCallbackResult WorldSocket::ReadHandler()
{
- if (!IsOpen())
- return;
-
MessageBuffer& packet = GetReadBuffer();
while (packet.GetActiveSize() > 0)
{
@@ -286,7 +283,7 @@ void WorldSocket::ReadHandler()
if (!ReadHeaderHandler())
{
CloseSocket();
- return;
+ return Trinity::Net::SocketReadCallbackResult::Stop;
}
}
@@ -314,11 +311,16 @@ void WorldSocket::ReadHandler()
if (result != ReadDataHandlerResult::WaitingForQuery)
CloseSocket();
- return;
+ return Trinity::Net::SocketReadCallbackResult::Stop;
}
}
- AsyncRead();
+ return Trinity::Net::SocketReadCallbackResult::KeepReading;
+}
+
+void WorldSocket::QueueQuery(QueryCallback&& queryCallback)
+{
+ _queryProcessor.AddCallback(std::move(queryCallback));
}
void WorldSocket::SetWorldSession(WorldSession* session)
@@ -510,14 +512,13 @@ WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler()
void WorldSocket::LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const
{
- if (!guard)
+ if (!guard || !_worldSession)
{
TC_LOG_TRACE("network.opcode", "C->S: {} {}", GetRemoteIpAddress().to_string(), GetOpcodeNameForLogging(opcode));
}
else
{
- TC_LOG_TRACE("network.opcode", "C->S: {} {}", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()),
- GetOpcodeNameForLogging(opcode));
+ TC_LOG_TRACE("network.opcode", "C->S: {} {}", _worldSession->GetPlayerInfo(), GetOpcodeNameForLogging(opcode));
}
}
@@ -688,7 +689,7 @@ void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSess
stmt->setInt32(0, int32(sRealmList->GetCurrentRealmId().Realm));
stmt->setString(1, joinTicket->gameaccount());
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession), joinTicket = std::move(joinTicket)](PreparedQueryResult result) mutable
+ QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession), joinTicket = std::move(joinTicket)](PreparedQueryResult result) mutable
{
HandleAuthSessionCallback(std::move(authSession), std::move(joinTicket), std::move(result));
}));
@@ -898,19 +899,20 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth::
sScriptMgr->OnAccountLogin(account.Game.Id);
_authed = true;
- _worldSession = new WorldSession(account.Game.Id, std::move(*joinTicket->mutable_gameaccount()), account.BattleNet.Id, shared_from_this(), account.Game.Security,
- account.Game.Expansion, mutetime, account.Game.OS, account.Game.TimezoneOffset, account.Game.Build, buildVariant, account.Game.Locale,
+ _worldSession = new WorldSession(account.Game.Id, std::move(*joinTicket->mutable_gameaccount()), account.BattleNet.Id,
+ static_pointer_cast<WorldSocket>(shared_from_this()), account.Game.Security, account.Game.Expansion, mutetime,
+ account.Game.OS, account.Game.TimezoneOffset, account.Game.Build, buildVariant, account.Game.Locale,
account.Game.Recruiter, account.Game.IsRectuiter);
// Initialize Warden system only if it is enabled by config
if (wardenActive)
_worldSession->InitWarden(_sessionKey);
- _queryProcessor.AddCallback(_worldSession->LoadPermissionsAsync().WithPreparedCallback([this](PreparedQueryResult result)
+ QueueQuery(_worldSession->LoadPermissionsAsync().WithPreparedCallback([this](PreparedQueryResult result)
{
LoadSessionPermissionsCallback(std::move(result));
}));
- AsyncRead();
+ AsyncRead(Trinity::Net::InvokeReadHandlerCallback<WorldSocket>{ .Socket = this });
}
void WorldSocket::LoadSessionPermissionsCallback(PreparedQueryResult result)
@@ -938,7 +940,7 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth:
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION);
stmt->setUInt32(0, accountId);
- _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable
+ QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable
{
HandleAuthContinuedSessionCallback(std::move(authSession), std::move(result));
}));
@@ -985,7 +987,7 @@ void WorldSocket::HandleAuthContinuedSessionCallback(std::shared_ptr<WorldPacket
memcpy(_encryptKey.data(), encryptKeyGen.GetDigest().data(), 32);
SendPacketAndLogOpcode(*WorldPackets::Auth::EnterEncryptedMode(_encryptKey, true).Write());
- AsyncRead();
+ AsyncRead(Trinity::Net::InvokeReadHandlerCallback<WorldSocket>{ .Socket = this });
}
void WorldSocket::HandleConnectToFailed(WorldPackets::Auth::ConnectToFailed& connectToFailed)
@@ -1032,7 +1034,7 @@ void WorldSocket::HandleEnterEncryptedModeAck()
if (_type == CONNECTION_TYPE_REALM)
sWorld->AddSession(_worldSession);
else
- sWorld->AddInstanceSocket(shared_from_this(), _key);
+ sWorld->AddInstanceSocket(static_pointer_cast<WorldSocket>(shared_from_this()), _key);
}
void WorldSocket::SendAuthResponseError(uint32 code)
diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h
index 04451dc26e8..b5388532b4d 100644
--- a/src/server/game/Server/WorldSocket.h
+++ b/src/server/game/Server/WorldSocket.h
@@ -15,8 +15,8 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef __WORLDSOCKET_H__
-#define __WORLDSOCKET_H__
+#ifndef TRINITYCORE_WORLD_SOCKET_H
+#define TRINITYCORE_WORLD_SOCKET_H
#include "AsyncCallbackProcessor.h"
#include "AuthDefines.h"
@@ -77,7 +77,7 @@ struct PacketHeader
uint32 Size;
uint8 Tag[12];
- bool IsValidSize() { return Size < 0x10000; }
+ bool IsValidSize() const { return Size < 0x10000; }
};
struct IncomingPacketHeader : PacketHeader
@@ -87,10 +87,8 @@ struct IncomingPacketHeader : PacketHeader
#pragma pack(pop)
-class TC_GAME_API WorldSocket : public Socket<WorldSocket>
+class TC_GAME_API WorldSocket final : public Trinity::Net::Socket<>
{
- static std::string const ServerConnectionInitialize;
- static std::string const ClientConnectionInitialize;
static uint32 const MinSizeForCompression;
static std::array<uint8, 32> const AuthCheckSeed;
@@ -98,14 +96,16 @@ class TC_GAME_API WorldSocket : public Socket<WorldSocket>
static std::array<uint8, 32> const ContinuedSessionSeed;
static std::array<uint8, 32> const EncryptionKeySeed;
- typedef Socket<WorldSocket> BaseSocket;
+ using BaseSocket = Socket;
public:
- WorldSocket(boost::asio::ip::tcp::socket&& socket);
+ WorldSocket(Trinity::Net::IoContextTcpSocket&& socket);
~WorldSocket();
WorldSocket(WorldSocket const& right) = delete;
+ WorldSocket(WorldSocket&& right) = delete;
WorldSocket& operator=(WorldSocket const& right) = delete;
+ WorldSocket& operator=(WorldSocket&& right) = delete;
void Start() override;
bool Update() override;
@@ -118,9 +118,15 @@ public:
void SetWorldSession(WorldSession* session);
void SetSendBufferSize(std::size_t sendBufferSize) { _sendBufferSize = sendBufferSize; }
-protected:
void OnClose() override;
- void ReadHandler() override;
+ Trinity::Net::SocketReadCallbackResult ReadHandler() override;
+
+ void QueueQuery(QueryCallback&& queryCallback);
+
+ void SendAuthSession();
+ bool InitializeCompression();
+
+protected:
bool ReadHeaderHandler();
enum class ReadDataHandlerResult
@@ -132,9 +138,6 @@ protected:
ReadDataHandlerResult ReadDataHandler();
private:
- void CheckIpCallback(PreparedQueryResult result);
- void InitializeHandler(boost::system::error_code const& error, std::size_t transferedBytes);
-
/// writes network.opcode log
/// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock
void LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const;
@@ -143,7 +146,6 @@ private:
void WritePacketToBuffer(EncryptablePacket const& packet, MessageBuffer& buffer);
uint32 CompressPacket(uint8* buffer, WorldPacket const& packet);
- void HandleSendAuthSession();
void HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession);
void HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession,
std::shared_ptr<JSON::RealmList::RealmJoinTicket> joinTicket, PreparedQueryResult result);
diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp
index 58f242b52f8..cc44e5bad4b 100644
--- a/src/server/game/Server/WorldSocketMgr.cpp
+++ b/src/server/game/Server/WorldSocketMgr.cpp
@@ -22,21 +22,16 @@
#include "WorldSocket.h"
#include <boost/system/error_code.hpp>
-static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
-{
- sWorldSocketMgr.OnSocketOpen(std::move(sock), threadIndex);
-}
-
-class WorldSocketThread : public NetworkThread<WorldSocket>
+class WorldSocketThread : public Trinity::Net::NetworkThread<WorldSocket>
{
public:
- void SocketAdded(std::shared_ptr<WorldSocket> sock) override
+ void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override
{
sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize());
sScriptMgr->OnSocketOpen(sock);
}
- void SocketRemoved(std::shared_ptr<WorldSocket> sock) override
+ void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override
{
sScriptMgr->OnSocketClose(sock);
}
@@ -75,7 +70,10 @@ bool WorldSocketMgr::StartNetwork(Trinity::Asio::IoContext& ioContext, std::stri
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
- _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
+ _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
+ {
+ OnSocketOpen(std::move(sock), threadIndex);
+ });
sScriptMgr->OnNetworkStart();
return true;
@@ -88,7 +86,7 @@ void WorldSocketMgr::StopNetwork()
sScriptMgr->OnNetworkStop();
}
-void WorldSocketMgr::OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
+void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
{
// set some options here
if (_socketSystemSendBufferSize >= 0)
@@ -117,7 +115,7 @@ void WorldSocketMgr::OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 th
BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex);
}
-NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
+Trinity::Net::NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
{
return new WorldSocketThread[GetNetworkThreadCount()];
}
diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h
index 84b190575a4..8859da81074 100644
--- a/src/server/game/Server/WorldSocketMgr.h
+++ b/src/server/game/Server/WorldSocketMgr.h
@@ -15,21 +15,15 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-/** \addtogroup u2w User to World Communication
- * @{
- * \file WorldSocketMgr.h
- * \author Derex <derex101@gmail.com>
- */
-
-#ifndef __WORLDSOCKETMGR_H
-#define __WORLDSOCKETMGR_H
+#ifndef TRINITYCORE_WORLD_SOCKET_MGR_H
+#define TRINITYCORE_WORLD_SOCKET_MGR_H
#include "SocketMgr.h"
class WorldSocket;
/// Manages all sockets connected to peers and network threads
-class TC_GAME_API WorldSocketMgr : public SocketMgr<WorldSocket>
+class TC_GAME_API WorldSocketMgr : public Trinity::Net::SocketMgr<WorldSocket>
{
typedef SocketMgr<WorldSocket> BaseSocketMgr;
@@ -44,14 +38,14 @@ public:
/// Stops all network threads, It will wait for all running threads .
void StopNetwork() override;
- void OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) override;
+ void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override;
std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; }
protected:
WorldSocketMgr();
- NetworkThread<WorldSocket>* CreateThreads() const override;
+ Trinity::Net::NetworkThread<WorldSocket>* CreateThreads() const override;
private:
int32 _socketSystemSendBufferSize;
@@ -62,4 +56,3 @@ private:
#define sWorldSocketMgr WorldSocketMgr::Instance()
#endif
-/// @}
diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h
index 95e25e02166..dd0857c2b38 100644
--- a/src/server/shared/Networking/AsyncAcceptor.h
+++ b/src/server/shared/Networking/AsyncAcceptor.h
@@ -15,12 +15,13 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef __ASYNCACCEPT_H_
-#define __ASYNCACCEPT_H_
+#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H
+#define TRINITYCORE_ASYNC_ACCEPTOR_H
#include "IoContext.h"
#include "IpAddress.h"
#include "Log.h"
+#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ip/v6_only.hpp>
#include <atomic>
@@ -28,28 +29,28 @@
#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
+namespace Trinity::Net
+{
+template <typename Callable>
+concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>;
+
class AsyncAcceptor
{
public:
- typedef void(*AcceptCallback)(boost::asio::ip::tcp::socket&& newSocket, uint32 threadIndex);
-
- AsyncAcceptor(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
- _acceptor(ioContext), _endpoint(Trinity::Net::make_address(bindIp), port),
+ AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
+ _acceptor(ioContext), _endpoint(make_address(bindIp), port),
_socket(ioContext), _closed(false), _socketFactory([this] { return DefeaultSocketFactory(); })
{
}
- template<class T>
- void AsyncAccept();
-
- template<AcceptCallback acceptCallback>
- void AsyncAcceptWithCallback()
+ template <AcceptCallback Callback>
+ void AsyncAccept(Callback&& acceptCallback)
{
auto [tmpSocket, tmpThreadIndex] = _socketFactory();
// TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures)
- boost::asio::ip::tcp::socket* socket = tmpSocket;
+ IoContextTcpSocket* socket = tmpSocket;
uint32 threadIndex = tmpThreadIndex;
- _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
+ _acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable
{
if (!error)
{
@@ -66,7 +67,7 @@ public:
}
if (!_closed)
- this->AsyncAcceptWithCallback<acceptCallback>();
+ this->AsyncAccept(std::move(acceptCallback));
});
}
@@ -120,40 +121,17 @@ public:
_acceptor.close(err);
}
- void SetSocketFactory(std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> func) { _socketFactory = std::move(func); }
+ void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
private:
- std::pair<boost::asio::ip::tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
+ std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
- boost::asio::ip::tcp::acceptor _acceptor;
+ boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
boost::asio::ip::tcp::endpoint _endpoint;
- boost::asio::ip::tcp::socket _socket;
+ IoContextTcpSocket _socket;
std::atomic<bool> _closed;
- std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> _socketFactory;
+ std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
};
-
-template<class T>
-void AsyncAcceptor::AsyncAccept()
-{
- _acceptor.async_accept(_socket, [this](boost::system::error_code error)
- {
- if (!error)
- {
- try
- {
- // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class
- std::make_shared<T>(std::move(this->_socket))->Start();
- }
- catch (boost::system::system_error const& err)
- {
- TC_LOG_INFO("network", "Failed to retrieve client's remote address {}", err.what());
- }
- }
-
- // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
- if (!_closed)
- this->AsyncAccept<T>();
- });
}
-#endif /* __ASYNCACCEPT_H_ */
+#endif // TRINITYCORE_ASYNC_ACCEPTOR_H
diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp
new file mode 100644
index 00000000000..5996c40faee
--- /dev/null
+++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp
@@ -0,0 +1,42 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "IpBanCheckConnectionInitializer.h"
+#include "DatabaseEnv.h"
+
+QueryCallback Trinity::Net::IpBanCheckHelpers::AsyncQuery(std::string_view ipAddress)
+{
+ LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
+ stmt->setString(0, ipAddress);
+ return LoginDatabase.AsyncQuery(stmt);
+}
+
+bool Trinity::Net::IpBanCheckHelpers::IsBanned(PreparedQueryResult const& result)
+{
+ if (result)
+ {
+ do
+ {
+ Field* fields = result->Fetch();
+ if (fields[0].GetUInt64() != 0)
+ return true;
+
+ } while (result->NextRow());
+ }
+
+ return false;
+}
diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h
new file mode 100644
index 00000000000..ff8210a7f69
--- /dev/null
+++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h
@@ -0,0 +1,64 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
+#define TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
+
+#include "DatabaseEnvFwd.h"
+#include "Log.h"
+#include "QueryCallback.h"
+#include "SocketConnectionInitializer.h"
+
+namespace Trinity::Net
+{
+namespace IpBanCheckHelpers
+{
+TC_SHARED_API QueryCallback AsyncQuery(std::string_view ipAddress);
+TC_SHARED_API bool IsBanned(PreparedQueryResult const& result);
+}
+
+template <typename SocketImpl>
+struct IpBanCheckConnectionInitializer final : SocketConnectionInitializer
+{
+ explicit IpBanCheckConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
+
+ void Start() override
+ {
+ _socket->QueueQuery(IpBanCheckHelpers::AsyncQuery(_socket->GetRemoteIpAddress().to_string()).WithPreparedCallback([socketRef = _socket->weak_from_this(), self = this->shared_from_this()](PreparedQueryResult const& result)
+ {
+ std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
+ if (!socket)
+ return;
+
+ if (IpBanCheckHelpers::IsBanned(result))
+ {
+ TC_LOG_ERROR("network", "IpBanCheckConnectionInitializer: IP {} is banned.", socket->GetRemoteIpAddress().to_string());
+ socket->DelayedCloseSocket();
+ return;
+ }
+
+ if (self->next)
+ self->next->Start();
+ }));
+ }
+
+private:
+ SocketImpl* _socket;
+};
+}
+
+#endif // TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
diff --git a/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h
new file mode 100644
index 00000000000..d3f0bb16dbf
--- /dev/null
+++ b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h
@@ -0,0 +1,51 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
+#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
+
+#include <memory>
+#include <span>
+
+namespace Trinity::Net
+{
+struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer>
+{
+ SocketConnectionInitializer() = default;
+
+ SocketConnectionInitializer(SocketConnectionInitializer const&) = delete;
+ SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default;
+ SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete;
+ SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default;
+
+ virtual ~SocketConnectionInitializer() = default;
+
+ virtual void Start() = 0;
+
+ std::shared_ptr<SocketConnectionInitializer> next;
+
+ static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers)
+ {
+ for (std::size_t i = initializers.size(); i > 1; --i)
+ initializers[i - 2]->next.swap(initializers[i - 1]);
+
+ return initializers[0];
+ }
+};
+}
+
+#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.cpp b/src/server/shared/Networking/Http/BaseHttpSocket.cpp
index ca92c442aa2..33053399c14 100644
--- a/src/server/shared/Networking/Http/BaseHttpSocket.cpp
+++ b/src/server/shared/Networking/Http/BaseHttpSocket.cpp
@@ -16,6 +16,7 @@
*/
#include "BaseHttpSocket.h"
+#include <boost/asio/buffers_iterator.hpp>
#include <boost/beast/http/serializer.hpp>
namespace Trinity::Net::Http
@@ -112,4 +113,34 @@ MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response
return buffer;
}
+
+void AbstractSocket::LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const
+{
+ if (Logger const* logger = sLog->GetEnabledLogger("server.http", LOG_LEVEL_DEBUG))
+ {
+ std::string clientInfo = GetClientInfo();
+ sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_DEBUG, "{} Request {} {} done, status {}", clientInfo,
+ ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int());
+ if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
+ {
+ sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo,
+ CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
+ sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo,
+ CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
+ }
+ }
+}
+
+std::string AbstractSocket::GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state)
+{
+ std::string info = StringFormat("[{}:{}", address.to_string(), port);
+ if (state)
+ {
+ info.append(", Session Id: ");
+ info.append(boost::uuids::to_string(state->Id));
+ }
+
+ info += ']';
+ return info;
+}
}
diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.h b/src/server/shared/Networking/Http/BaseHttpSocket.h
index c02b13e6463..4b7c3bd9dd1 100644
--- a/src/server/shared/Networking/Http/BaseHttpSocket.h
+++ b/src/server/shared/Networking/Http/BaseHttpSocket.h
@@ -25,13 +25,46 @@
#include "Optional.h"
#include "QueryCallback.h"
#include "Socket.h"
-#include <boost/asio/buffers_iterator.hpp>
+#include "SocketConnectionInitializer.h"
+#include <boost/beast/core/basic_stream.hpp>
#include <boost/beast/http/parser.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace Trinity::Net::Http
{
+using IoContextHttpSocket = boost::beast::basic_stream<boost::asio::ip::tcp, boost::asio::io_context::executor_type, boost::beast::unlimited_rate_policy>;
+
+namespace Impl
+{
+class BoostBeastSocketWrapper : public IoContextHttpSocket
+{
+public:
+ using IoContextHttpSocket::basic_stream;
+
+ void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
+ {
+ socket().shutdown(what, shutdownError);
+ }
+
+ void close(boost::system::error_code& /*error*/)
+ {
+ IoContextHttpSocket::close();
+ }
+
+ template<typename WaitHandlerType>
+ void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
+ {
+ socket().async_wait(type, std::forward<WaitHandlerType>(handler));
+ }
+
+ IoContextTcpSocket::endpoint_type remote_endpoint() const
+ {
+ return socket().remote_endpoint();
+ }
+};
+}
+
using RequestParser = boost::beast::http::request_parser<RequestBody>;
class TC_SHARED_API AbstractSocket
@@ -51,22 +84,59 @@ public:
virtual void SendResponse(RequestContext& context) = 0;
+ void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const;
+
virtual void QueueQuery(QueryCallback&& queryCallback) = 0;
virtual std::string GetClientInfo() const = 0;
- virtual Optional<boost::uuids::uuid> GetSessionId() const = 0;
+ static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state);
+
+ virtual SessionState* GetSessionState() const = 0;
+
+ Optional<boost::uuids::uuid> GetSessionId() const
+ {
+ if (SessionState* state = this->GetSessionState())
+ return state->Id;
+
+ return {};
+ }
+
+ virtual void Start() = 0;
+
+ virtual bool Update() = 0;
+
+ virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0;
+
+ virtual bool IsOpen() const = 0;
+
+ virtual void CloseSocket() = 0;
};
-template<typename Derived, typename Stream>
-class BaseSocket : public ::Socket<Derived, Stream>, public AbstractSocket
+template <typename SocketImpl>
+struct HttpConnectionInitializer final : SocketConnectionInitializer
{
- using Base = ::Socket<Derived, Stream>;
+ explicit HttpConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
+
+ void Start() override
+ {
+ _socket->ResetHttpParser();
+
+ if (this->next)
+ this->next->Start();
+ }
+
+private:
+ SocketImpl* _socket;
+};
+
+template<typename Stream>
+class BaseSocket : public Trinity::Net::Socket<Stream>, public AbstractSocket
+{
+ using Base = Trinity::Net::Socket<Stream>;
public:
- template<typename... Args>
- explicit BaseSocket(boost::asio::ip::tcp::socket&& socket, Args&&... args)
- : Base(std::move(socket), std::forward<Args>(args)...) { }
+ using Base::Base;
BaseSocket(BaseSocket const& other) = delete;
BaseSocket(BaseSocket&& other) = delete;
@@ -75,11 +145,8 @@ public:
~BaseSocket() = default;
- void ReadHandler() override
+ SocketReadCallbackResult ReadHandler() final
{
- if (!this->IsOpen())
- return;
-
MessageBuffer& packet = this->GetReadBuffer();
while (packet.GetActiveSize() > 0)
{
@@ -92,13 +159,13 @@ public:
if (!HandleMessage(_httpParser->get()))
{
this->CloseSocket();
- break;
+ return SocketReadCallbackResult::Stop;
}
this->ResetHttpParser();
}
- this->AsyncRead();
+ return SocketReadCallbackResult::KeepReading;
}
bool HandleMessage(Request& request)
@@ -118,19 +185,11 @@ public:
virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0;
- void SendResponse(RequestContext& context) override
+ void SendResponse(RequestContext& context) final
{
MessageBuffer buffer = SerializeResponse(context.request, context.response);
- TC_LOG_DEBUG("server.http", "{} Request {} {} done, status {}", this->GetClientInfo(), ToStdStringView(context.request.method_string()),
- ToStdStringView(context.request.target()), context.response.result_int());
- if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
- {
- sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Request: {}", this->GetClientInfo(),
- CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
- sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Response: {}", this->GetClientInfo(),
- CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
- }
+ this->LogRequestAndResponse(context, buffer);
this->QueuePacket(std::move(buffer));
@@ -138,11 +197,13 @@ public:
this->DelayedCloseSocket();
}
- void QueueQuery(QueryCallback&& queryCallback) override
+ void QueueQuery(QueryCallback&& queryCallback) final
{
this->_queryProcessor.AddCallback(std::move(queryCallback));
}
+ void Start() override { return this->Base::Start(); }
+
bool Update() override
{
if (!this->Base::Update())
@@ -152,27 +213,19 @@ public:
return true;
}
+ boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); }
+
+ bool IsOpen() const final { return this->Base::IsOpen(); }
+
+ void CloseSocket() final { return this->Base::CloseSocket(); }
+
std::string GetClientInfo() const override
{
- std::string info;
- info.reserve(500);
- auto itr = StringFormatTo(std::back_inserter(info), "[{}:{}", this->GetRemoteIpAddress().to_string(), this->GetRemotePort());
- if (_state)
- itr = StringFormatTo(itr, ", Session Id: {}", boost::uuids::to_string(_state->Id));
-
- StringFormatTo(itr, "]");
- return info;
+ return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get());
}
- Optional<boost::uuids::uuid> GetSessionId() const final
- {
- if (this->_state)
- return this->_state->Id;
-
- return {};
- }
+ SessionState* GetSessionState() const override { return _state.get(); }
-protected:
void ResetHttpParser()
{
this->_httpParser.reset();
@@ -180,6 +233,7 @@ protected:
this->_httpParser->eager(true);
}
+protected:
virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0;
QueryCallbackProcessor _queryProcessor;
diff --git a/src/server/shared/Networking/Http/HttpService.h b/src/server/shared/Networking/Http/HttpService.h
index e68f8d2795f..2a377c734da 100644
--- a/src/server/shared/Networking/Http/HttpService.h
+++ b/src/server/shared/Networking/Http/HttpService.h
@@ -160,7 +160,7 @@ protected:
class Thread : public NetworkThread<SessionImpl>
{
protected:
- void SocketRemoved(std::shared_ptr<SessionImpl> session) override
+ void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
{
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
_service->MarkSessionInactive(*id);
diff --git a/src/server/shared/Networking/Http/HttpSocket.h b/src/server/shared/Networking/Http/HttpSocket.h
index 9a333a8e779..2cfc3ba8ed8 100644
--- a/src/server/shared/Networking/Http/HttpSocket.h
+++ b/src/server/shared/Networking/Http/HttpSocket.h
@@ -19,43 +19,16 @@
#define TRINITYCORE_HTTP_SOCKET_H
#include "BaseHttpSocket.h"
-#include <boost/beast/core/tcp_stream.hpp>
+#include <array>
namespace Trinity::Net::Http
{
-namespace Impl
+class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper>
{
-class BoostBeastSocketWrapper : public boost::beast::tcp_stream
-{
-public:
- using boost::beast::tcp_stream::tcp_stream;
-
- void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
- {
- socket().shutdown(what, shutdownError);
- }
-
- void close(boost::system::error_code& /*error*/)
- {
- boost::beast::tcp_stream::close();
- }
-
- boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
- {
- return socket().remote_endpoint();
- }
-};
-}
-
-template <typename Derived>
-class Socket : public BaseSocket<Derived, Impl::BoostBeastSocketWrapper>
-{
- using SocketBase = BaseSocket<Derived, Impl::BoostBeastSocketWrapper>;
+ using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>;
public:
- template<typename... Args>
- explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&...)
- : SocketBase(std::move(socket)) { }
+ using SocketBase::SocketBase;
Socket(Socket const& other) = delete;
Socket(Socket&& other) = delete;
@@ -66,9 +39,13 @@ public:
void Start() override
{
- this->ResetHttpParser();
+ std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers =
+ { {
+ std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
+ std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
+ } };
- this->AsyncRead();
+ SocketConnectionInitializer::SetupChain(initializers)->Start();
}
};
}
diff --git a/src/server/shared/Networking/Http/HttpSslSocket.h b/src/server/shared/Networking/Http/HttpSslSocket.h
index cdb70645e05..c789cbfefaf 100644
--- a/src/server/shared/Networking/Http/HttpSslSocket.h
+++ b/src/server/shared/Networking/Http/HttpSslSocket.h
@@ -19,47 +19,21 @@
#define TRINITYCORE_HTTP_SSL_SOCKET_H
#include "BaseHttpSocket.h"
-#include "SslSocket.h"
-#include <boost/beast/core/stream_traits.hpp>
-#include <boost/beast/core/tcp_stream.hpp>
-#include <boost/beast/ssl/ssl_stream.hpp>
+#include "SslStream.h"
namespace Trinity::Net::Http
{
-namespace Impl
+class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>
{
-class BoostBeastSslSocketWrapper : public ::SslSocket<boost::beast::ssl_stream<boost::beast::tcp_stream>>
-{
-public:
- using SslSocket::SslSocket;
-
- void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
- {
- _sslSocket.shutdown(shutdownError);
- boost::beast::get_lowest_layer(_sslSocket).socket().shutdown(what, shutdownError);
- }
-
- void close(boost::system::error_code& /*error*/)
- {
- boost::beast::get_lowest_layer(_sslSocket).close();
- }
-
- boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
- {
- return boost::beast::get_lowest_layer(_sslSocket).socket().remote_endpoint();
- }
-};
-}
-
-template <typename Derived>
-class SslSocket : public BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>
-{
- using SocketBase = BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>;
+ using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>;
public:
- explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext)
+ explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext)
: SocketBase(std::move(socket), sslContext) { }
+ explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext)
+ : SocketBase(context, sslContext) { }
+
SslSocket(SslSocket const& other) = delete;
SslSocket(SslSocket&& other) = delete;
SslSocket& operator=(SslSocket const& other) = delete;
@@ -69,27 +43,14 @@ public:
void Start() override
{
- this->AsyncHandshake();
- }
-
- void AsyncHandshake()
- {
- this->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
- [self = this->shared_from_this()](boost::system::error_code const& error) { self->HandshakeHandler(error); });
- }
-
- void HandshakeHandler(boost::system::error_code const& error)
- {
- if (error)
- {
- TC_LOG_ERROR("server.http.session.ssl", "{} SSL Handshake failed {}", this->GetClientInfo(), error.message());
- this->CloseSocket();
- return;
- }
-
- this->ResetHttpParser();
-
- this->AsyncRead();
+ std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers =
+ { {
+ std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this),
+ std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
+ std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
+ } };
+
+ SocketConnectionInitializer::SetupChain(initializers)->Start();
}
};
}
diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h
index fc7d9647fc1..d16da442149 100644
--- a/src/server/shared/Networking/NetworkThread.h
+++ b/src/server/shared/Networking/NetworkThread.h
@@ -15,20 +15,24 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef NetworkThread_h__
-#define NetworkThread_h__
+#ifndef TRINITYCORE_NETWORK_THREAD_H
+#define TRINITYCORE_NETWORK_THREAD_H
-#include "Define.h"
+#include "Containers.h"
#include "DeadlineTimer.h"
+#include "Define.h"
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
+#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
#include <mutex>
#include <thread>
+namespace Trinity::Net
+{
template<class SocketType>
class NetworkThread
{
@@ -38,14 +42,16 @@ public:
{
}
+ NetworkThread(NetworkThread const&) = delete;
+ NetworkThread(NetworkThread&&) = delete;
+ NetworkThread& operator=(NetworkThread const&) = delete;
+ NetworkThread& operator=(NetworkThread&&) = delete;
+
virtual ~NetworkThread()
{
Stop();
if (_thread)
- {
Wait();
- delete _thread;
- }
}
void Stop()
@@ -59,7 +65,7 @@ public:
if (_thread)
return false;
- _thread = new std::thread(&NetworkThread::Run, this);
+ _thread = std::make_unique<std::thread>(&NetworkThread::Run, this);
return true;
}
@@ -68,7 +74,6 @@ public:
ASSERT(_thread);
_thread->join();
- delete _thread;
_thread = nullptr;
}
@@ -82,15 +87,14 @@ public:
std::lock_guard<std::mutex> lock(_newSocketsLock);
++_connections;
- _newSockets.push_back(sock);
- SocketAdded(sock);
+ SocketAdded(_newSockets.emplace_back(std::move(sock)));
}
- boost::asio::ip::tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
+ Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
protected:
- virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
- virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
+ virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
+ virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { }
void AddNewSockets()
{
@@ -99,7 +103,7 @@ protected:
if (_newSockets.empty())
return;
- for (std::shared_ptr<SocketType> sock : _newSockets)
+ for (std::shared_ptr<SocketType>& sock : _newSockets)
{
if (!sock->IsOpen())
{
@@ -107,7 +111,7 @@ protected:
--_connections;
}
else
- _sockets.push_back(sock);
+ _sockets.emplace_back(std::move(sock));
}
_newSockets.clear();
@@ -136,7 +140,7 @@ protected:
AddNewSockets();
- _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock)
+ Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock)
{
if (!sock->Update())
{
@@ -150,7 +154,7 @@ protected:
}
return false;
- }), _sockets.end());
+ });
}
private:
@@ -159,7 +163,7 @@ private:
std::atomic<int32> _connections;
std::atomic<bool> _stopped;
- std::thread* _thread;
+ std::unique_ptr<std::thread> _thread;
SocketContainer _sockets;
@@ -167,8 +171,9 @@ private:
SocketContainer _newSockets;
Trinity::Asio::IoContext _ioContext;
- boost::asio::ip::tcp::socket _acceptSocket;
+ Trinity::Net::IoContextTcpSocket _acceptSocket;
Trinity::Asio::DeadlineTimer _updateTimer;
};
+}
-#endif // NetworkThread_h__
+#endif // TRINITYCORE_NETWORK_THREAD_H
diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h
index 40f5820da92..565cc175318 100644
--- a/src/server/shared/Networking/Socket.h
+++ b/src/server/shared/Networking/Socket.h
@@ -15,11 +15,14 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef __SOCKET_H__
-#define __SOCKET_H__
+#ifndef TRINITYCORE_SOCKET_H
+#define TRINITYCORE_SOCKET_H
+#include "Concepts.h"
#include "Log.h"
#include "MessageBuffer.h"
+#include "SocketConnectionInitializer.h"
+#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
@@ -31,12 +34,51 @@
#define TC_SOCKET_USE_IOCP
#endif
+namespace Trinity::Net
+{
+using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
+
+enum class SocketReadCallbackResult
+{
+ KeepReading,
+ Stop
+};
+
+template <typename Callable>
+concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
+
+template <typename SocketType>
+struct InvokeReadHandlerCallback
+{
+ SocketReadCallbackResult operator()() const
+ {
+ return this->Socket->ReadHandler();
+ }
+
+ SocketType* Socket;
+};
+
+template <typename SocketType>
+struct ReadConnectionInitializer final : SocketConnectionInitializer
+{
+ explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { }
+
+ void Start() override
+ {
+ ReadCallback.Socket->AsyncRead(std::move(ReadCallback));
+
+ if (this->next)
+ this->next->Start();
+ }
+
+ InvokeReadHandlerCallback<SocketType> ReadCallback;
+};
+
/**
@class Socket
Base async socket implementation
- @tparam T derived class type (CRTP)
@tparam Stream stream type used for operations on socket
Stream must implement the following methods:
@@ -53,21 +95,27 @@
template<typename ConstBufferSequence>
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error);
+ template<typename WaitHandlerType>
+ void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler);
+
template<typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& error);
tcp::socket::endpoint_type remote_endpoint() const;
*/
-template<class T, class Stream = boost::asio::ip::tcp::socket>
-class Socket : public std::enable_shared_from_this<T>
+template<class Stream = IoContextTcpSocket>
+class Socket : public std::enable_shared_from_this<Socket<Stream>>
{
public:
template<typename... Args>
- explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
- _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()),
- _closed(false), _closing(false), _isWritingAsync(false)
+ explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
+ _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
+ {
+ }
+
+ template<typename... Args>
+ explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed)
{
- _readBuffer.Resize(READ_BLOCK_SIZE);
}
Socket(Socket const& other) = delete;
@@ -77,20 +125,20 @@ public:
virtual ~Socket()
{
- _closed = true;
+ _openState = OpenState_Closed;
boost::system::error_code error;
_socket.close(error);
}
- virtual void Start() = 0;
+ virtual void Start() { }
virtual bool Update()
{
- if (_closed)
+ if (_openState == OpenState_Closed)
return false;
#ifndef TC_SOCKET_USE_IOCP
- if (_isWritingAsync || (_writeQueue.empty() && !_closing))
+ if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open))
return true;
for (; HandleQueue();)
@@ -100,7 +148,7 @@ public:
return true;
}
- boost::asio::ip::address GetRemoteIpAddress() const
+ boost::asio::ip::address const& GetRemoteIpAddress() const
{
return _remoteAddress;
}
@@ -110,7 +158,8 @@ public:
return _remotePort;
}
- void AsyncRead()
+ template <SocketReadCallback Callback>
+ void AsyncRead(Callback&& callback)
{
if (!IsOpen())
return;
@@ -118,23 +167,11 @@ public:
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
- [self = this->shared_from_this()](boost::system::error_code const& error, size_t transferredBytes)
+ [self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
{
- self->ReadHandlerInternal(error, transferredBytes);
- });
- }
-
- void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code const&, std::size_t))
- {
- if (!IsOpen())
- return;
-
- _readBuffer.Normalize();
- _readBuffer.EnsureFreeSpace();
- _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
- [self = this->shared_from_this(), callback](boost::system::error_code const& error, size_t transferredBytes)
- {
- (self.get()->*callback)(error, transferredBytes);
+ if (self->ReadHandlerInternal(error, transferredBytes))
+ if (callback() == SocketReadCallbackResult::KeepReading)
+ self->AsyncRead(std::forward<Callback>(callback));
});
}
@@ -147,11 +184,11 @@ public:
#endif
}
- bool IsOpen() const { return !_closed && !_closing; }
+ bool IsOpen() const { return _openState == OpenState_Open; }
void CloseSocket()
{
- if (_closed.exchange(true))
+ if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
return;
boost::system::error_code shutdownError;
@@ -160,13 +197,13 @@ public:
TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(),
shutdownError.value(), shutdownError.message());
- OnClose();
+ this->OnClose();
}
/// Marks the socket for closing after write buffer becomes empty
void DelayedCloseSocket()
{
- if (_closing.exchange(true))
+ if (_openState.fetch_or(OpenState_Closing) != 0)
return;
if (_writeQueue.empty())
@@ -175,10 +212,15 @@ public:
MessageBuffer& GetReadBuffer() { return _readBuffer; }
+ Stream& underlying_stream()
+ {
+ return _socket;
+ }
+
protected:
virtual void OnClose() { }
- virtual void ReadHandler() = 0;
+ virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; }
bool AsyncProcessQueue()
{
@@ -195,10 +237,10 @@ protected:
self->WriteHandler(error, transferedBytes);
});
#else
- _socket.async_write_some(boost::asio::null_buffers(),
- [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes)
+ _socket.async_wait(boost::asio::socket_base::wait_type::wait_write,
+ [self = this->shared_from_this()](boost::system::error_code const& error)
{
- self->WriteHandlerWrapper(error, transferedBytes);
+ self->WriteHandlerWrapper(error);
});
#endif
@@ -214,22 +256,17 @@ protected:
GetRemoteIpAddress().to_string(), err.value(), err.message());
}
- Stream& underlying_stream()
- {
- return _socket;
- }
-
private:
- void ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
+ bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
{
if (error)
{
CloseSocket();
- return;
+ return false;
}
_readBuffer.WriteCompleted(transferredBytes);
- ReadHandler();
+ return IsOpen();
}
#ifdef TC_SOCKET_USE_IOCP
@@ -245,7 +282,7 @@ private:
if (!_writeQueue.empty())
AsyncProcessQueue();
- else if (_closing)
+ else if (_openState == OpenState_Closing)
CloseSocket();
}
else
@@ -254,7 +291,7 @@ private:
#else
- void WriteHandlerWrapper(boost::system::error_code const& /*error*/, std::size_t /*transferedBytes*/)
+ void WriteHandlerWrapper(boost::system::error_code const& /*error*/)
{
_isWritingAsync = false;
HandleQueue();
@@ -278,14 +315,14 @@ private:
return AsyncProcessQueue();
_writeQueue.pop();
- if (_closing && _writeQueue.empty())
+ if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return false;
}
else if (bytesSent == 0)
{
_writeQueue.pop();
- if (_closing && _writeQueue.empty())
+ if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return false;
}
@@ -296,7 +333,7 @@ private:
}
_writeQueue.pop();
- if (_closing && _writeQueue.empty())
+ if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return !_writeQueue.empty();
}
@@ -306,15 +343,20 @@ private:
Stream _socket;
boost::asio::ip::address _remoteAddress;
- uint16 _remotePort;
+ uint16 _remotePort = 0;
- MessageBuffer _readBuffer;
+ MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE);
std::queue<MessageBuffer> _writeQueue;
- std::atomic<bool> _closed;
- std::atomic<bool> _closing;
+ // Socket open state "enum" (not enum to enable integral std::atomic api)
+ static constexpr uint8 OpenState_Open = 0x0;
+ static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data
+ static constexpr uint8 OpenState_Closed = 0x2;
+
+ std::atomic<uint8> _openState;
- bool _isWritingAsync;
+ bool _isWritingAsync = false;
};
+}
-#endif // __SOCKET_H__
+#endif // TRINITYCORE_SOCKET_H
diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h
index 0b2d03e0944..07252355308 100644
--- a/src/server/shared/Networking/SocketMgr.h
+++ b/src/server/shared/Networking/SocketMgr.h
@@ -15,32 +15,40 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef SocketMgr_h__
-#define SocketMgr_h__
+#ifndef TRINITYCORE_SOCKET_MGR_H
+#define TRINITYCORE_SOCKET_MGR_H
#include "AsyncAcceptor.h"
#include "Errors.h"
#include "NetworkThread.h"
+#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <memory>
+namespace Trinity::Net
+{
template<class SocketType>
class SocketMgr
{
public:
+ SocketMgr(SocketMgr const&) = delete;
+ SocketMgr(SocketMgr&&) = delete;
+ SocketMgr& operator=(SocketMgr const&) = delete;
+ SocketMgr& operator=(SocketMgr&&) = delete;
+
virtual ~SocketMgr()
{
ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
}
- virtual bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
+ virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
ASSERT(threadCount > 0);
- AsyncAcceptor* acceptor = nullptr;
+ std::unique_ptr<AsyncAcceptor> acceptor = nullptr;
try
{
- acceptor = new AsyncAcceptor(ioContext, bindIp, port);
+ acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port);
}
catch (boost::system::system_error const& err)
{
@@ -51,13 +59,12 @@ public:
if (!acceptor->Bind())
{
TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
- delete acceptor;
return false;
}
- _acceptor = acceptor;
+ _acceptor = std::move(acceptor);
_threadCount = threadCount;
- _threads = CreateThreads();
+ _threads.reset(CreateThreads());
ASSERT(_threads);
@@ -73,27 +80,23 @@ public:
{
_acceptor->Close();
- if (_threadCount != 0)
- for (int32 i = 0; i < _threadCount; ++i)
- _threads[i].Stop();
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Stop();
Wait();
- delete _acceptor;
_acceptor = nullptr;
- delete[] _threads;
_threads = nullptr;
_threadCount = 0;
}
void Wait()
{
- if (_threadCount != 0)
- for (int32 i = 0; i < _threadCount; ++i)
- _threads[i].Wait();
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Wait();
}
- virtual void OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
+ virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
{
try
{
@@ -121,22 +124,23 @@ public:
return min;
}
- std::pair<boost::asio::ip::tcp::socket*, uint32> GetSocketForAccept()
+ std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
{
uint32 threadIndex = SelectThreadWithMinConnections();
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
}
protected:
- SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(0)
+ SocketMgr() : _threadCount(0)
{
}
virtual NetworkThread<SocketType>* CreateThreads() const = 0;
- AsyncAcceptor* _acceptor;
- NetworkThread<SocketType>* _threads;
+ std::unique_ptr<AsyncAcceptor> _acceptor;
+ std::unique_ptr<NetworkThread<SocketType>[]> _threads;
int32 _threadCount;
};
+}
-#endif // SocketMgr_h__
+#endif // TRINITYCORE_SOCKET_MGR_H
diff --git a/src/server/shared/Networking/SslSocket.h b/src/server/shared/Networking/SslSocket.h
deleted file mode 100644
index c19c8612edf..00000000000
--- a/src/server/shared/Networking/SslSocket.h
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation; either version 2 of the License, or (at your
- * option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- * more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef SslSocket_h__
-#define SslSocket_h__
-
-#include <boost/asio/ip/tcp.hpp>
-#include <boost/asio/ssl/stream.hpp>
-#include <boost/system/error_code.hpp>
-
-namespace boostssl = boost::asio::ssl;
-
-template<class Stream = boostssl::stream<boost::asio::ip::tcp::socket>>
-class SslSocket
-{
-public:
- explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
- {
- _sslSocket.set_verify_mode(boostssl::verify_none);
- }
-
- // adapting tcp::socket api
- void close(boost::system::error_code& error)
- {
- _sslSocket.lowest_layer().close(error);
- }
-
- void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
- {
- _sslSocket.shutdown(shutdownError);
- _sslSocket.lowest_layer().shutdown(what, shutdownError);
- }
-
- template<typename MutableBufferSequence, typename ReadHandlerType>
- void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
- {
- _sslSocket.async_read_some(buffers, std::move(handler));
- }
-
- template<typename ConstBufferSequence, typename WriteHandlerType>
- void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
- {
- _sslSocket.async_write_some(buffers, std::move(handler));
- }
-
- template<typename ConstBufferSequence>
- std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error)
- {
- return _sslSocket.write_some(buffers, error);
- }
-
- template<typename SettableSocketOption>
- void set_option(SettableSocketOption const& option, boost::system::error_code& error)
- {
- _sslSocket.lowest_layer().set_option(option, error);
- }
-
- boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
- {
- return _sslSocket.lowest_layer().remote_endpoint();
- }
-
- // ssl api
- template<typename HandshakeHandlerType>
- void async_handshake(boostssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
- {
- _sslSocket.async_handshake(type, std::move(handler));
- }
-
-protected:
- Stream _sslSocket;
-};
-
-#endif // SslSocket_h__
diff --git a/src/server/shared/Networking/SslStream.h b/src/server/shared/Networking/SslStream.h
new file mode 100644
index 00000000000..2cced44e5ff
--- /dev/null
+++ b/src/server/shared/Networking/SslStream.h
@@ -0,0 +1,131 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_SSL_STREAM_H
+#define TRINITYCORE_SSL_STREAM_H
+
+#include "SocketConnectionInitializer.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <boost/asio/ssl/stream.hpp>
+#include <boost/system/error_code.hpp>
+
+namespace Trinity::Net
+{
+template <typename SocketImpl>
+struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer
+{
+ explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
+
+ void Start() override
+ {
+ _socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
+ [socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error)
+ {
+ std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
+ if (!socket)
+ return;
+
+ if (error)
+ {
+ TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message());
+ socket->CloseSocket();
+ return;
+ }
+
+ if (self->next)
+ self->next->Start();
+ });
+ }
+
+private:
+ SocketImpl* _socket;
+};
+
+template<class WrappedStream = IoContextTcpSocket>
+class SslStream
+{
+public:
+ explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
+ {
+ _sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
+ }
+
+ explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext)
+ {
+ _sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
+ }
+
+ // adapting tcp::socket api
+ void close(boost::system::error_code& error)
+ {
+ _sslSocket.next_layer().close(error);
+ }
+
+ void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
+ {
+ _sslSocket.shutdown(shutdownError);
+ _sslSocket.next_layer().shutdown(what, shutdownError);
+ }
+
+ template<typename MutableBufferSequence, typename ReadHandlerType>
+ void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
+ {
+ _sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler));
+ }
+
+ template<typename ConstBufferSequence, typename WriteHandlerType>
+ void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
+ {
+ _sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler));
+ }
+
+ template<typename ConstBufferSequence>
+ std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error)
+ {
+ return _sslSocket.write_some(buffers, error);
+ }
+
+ template<typename WaitHandlerType>
+ void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
+ {
+ _sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler));
+ }
+
+ template<typename SettableSocketOption>
+ void set_option(SettableSocketOption const& option, boost::system::error_code& error)
+ {
+ _sslSocket.next_layer().set_option(option, error);
+ }
+
+ IoContextTcpSocket::endpoint_type remote_endpoint() const
+ {
+ return _sslSocket.next_layer().remote_endpoint();
+ }
+
+ // ssl api
+ template<typename HandshakeHandlerType>
+ void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
+ {
+ _sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler));
+ }
+
+protected:
+ boost::asio::ssl::stream<WrappedStream> _sslSocket;
+};
+}
+
+#endif // TRINITYCORE_SSL_STREAM_H
diff --git a/src/server/worldserver/Main.cpp b/src/server/worldserver/Main.cpp
index a482d109890..ba5c8f9c758 100644
--- a/src/server/worldserver/Main.cpp
+++ b/src/server/worldserver/Main.cpp
@@ -15,10 +15,6 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-/// \addtogroup Trinityd Trinity Daemon
-/// @{
-/// \file
-
#include "Common.h"
#include "AppenderDB.h"
#include "AsyncAcceptor.h"
@@ -122,7 +118,7 @@ private:
};
void SignalHandler(boost::system::error_code const& error, int signalNumber);
-AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext);
+std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext);
bool StartDB();
void StopDB();
void WorldUpdateLoop();
@@ -372,9 +368,9 @@ int main(int argc, char** argv)
auto battlegroundMgrHandle = Trinity::make_unique_ptr_with_deleter<&BattlegroundMgr::DeleteAllBattlegrounds>(sBattlegroundMgr);
// Start the Remote Access port (acceptor) if enabled
- std::unique_ptr<AsyncAcceptor> raAcceptor;
+ std::unique_ptr<Trinity::Net::AsyncAcceptor> raAcceptor;
if (sConfigMgr->GetBoolDefault("Ra.Enable", false))
- raAcceptor.reset(StartRaSocketAcceptor(*ioContext));
+ raAcceptor = StartRaSocketAcceptor(*ioContext);
// Start soap serving thread if enabled
std::unique_ptr<std::thread, ShutdownTCSoapThread> soapThread;
@@ -632,20 +628,23 @@ void FreezeDetector::Handler(std::weak_ptr<FreezeDetector> freezeDetectorRef, bo
}
}
-AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext)
+std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext)
{
uint16 raPort = uint16(sConfigMgr->GetIntDefault("Ra.Port", 3443));
std::string raListener = sConfigMgr->GetStringDefault("Ra.IP", "0.0.0.0");
- AsyncAcceptor* acceptor = new AsyncAcceptor(ioContext, raListener, raPort);
+ std::unique_ptr<Trinity::Net::AsyncAcceptor> acceptor = std::make_unique<Trinity::Net::AsyncAcceptor>(ioContext, raListener, raPort);
if (!acceptor->Bind())
{
TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor");
- delete acceptor;
return nullptr;
}
- acceptor->AsyncAccept<RASession>();
+ acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/)
+ {
+ std::make_shared<RASession>(std::move(sock))->Start();
+
+ });
return acceptor;
}
@@ -696,7 +695,6 @@ void ClearOnlineAccounts(uint32 realmId)
// Battleground instance ids reset at server restart
CharacterDatabase.DirectExecute("UPDATE character_battleground_data SET instanceId = 0");
}
-/// @}
variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, [[maybe_unused]] std::string& winServiceAction)
{
diff --git a/src/server/worldserver/RemoteAccess/RASession.cpp b/src/server/worldserver/RemoteAccess/RASession.cpp
index b4e9e6317be..910cfdf5e4f 100644
--- a/src/server/worldserver/RemoteAccess/RASession.cpp
+++ b/src/server/worldserver/RemoteAccess/RASession.cpp
@@ -29,9 +29,11 @@
void RASession::Start()
{
+ _socket.non_blocking(false);
+
// wait 1 second for active connections to send negotiation request
for (int counter = 0; counter < 10 && _socket.available() == 0; counter++)
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ std::this_thread::sleep_for(100ms);
// Check if there are bytes available, if they are, then the client is requesting the negotiation
if (_socket.available() > 0)
diff --git a/src/server/worldserver/RemoteAccess/RASession.h b/src/server/worldserver/RemoteAccess/RASession.h
index e0f4b373f74..23fd4e70c55 100644
--- a/src/server/worldserver/RemoteAccess/RASession.h
+++ b/src/server/worldserver/RemoteAccess/RASession.h
@@ -15,10 +15,11 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef __RASESSION_H__
-#define __RASESSION_H__
+#ifndef TRINITYCORE_RA_SESSION_H
+#define TRINITYCORE_RA_SESSION_H
#include "Define.h"
+#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/streambuf.hpp>
#include <future>
@@ -29,7 +30,7 @@ const size_t bufferSize = 4096;
class RASession : public std::enable_shared_from_this <RASession>
{
public:
- RASession(boost::asio::ip::tcp::socket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr)
+ RASession(Trinity::Net::IoContextTcpSocket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr)
{
}
@@ -47,7 +48,7 @@ private:
static void CommandPrint(void* callbackArg, std::string_view text);
static void CommandFinished(void* callbackArg, bool);
- boost::asio::ip::tcp::socket _socket;
+ Trinity::Net::IoContextTcpSocket _socket;
boost::asio::streambuf _readBuffer;
boost::asio::streambuf _writeBuffer;
std::promise<void>* _commandExecuting;