mirror of
https://github.com/TrinityCore/TrinityCore.git
synced 2026-01-16 07:30:42 +01:00
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final
* Removed derived class template argument
* Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor
* Make socket initialization easier composable (before entering Read loop)
* Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
(cherry picked from commit e8b2be3527)
This commit is contained in:
@@ -207,7 +207,7 @@ namespace Trinity
|
||||
if (!p(*rpos))
|
||||
{
|
||||
if (rpos != wpos)
|
||||
std::swap(*rpos, *wpos);
|
||||
std::ranges::swap(*rpos, *wpos);
|
||||
++wpos;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "DatabaseEnv.h"
|
||||
#include "IPLocation.h"
|
||||
#include "IoContext.h"
|
||||
#include "IpBanCheckConnectionInitializer.h"
|
||||
#include "Log.h"
|
||||
#include "RealmList.h"
|
||||
#include "SecretMgr.h"
|
||||
@@ -199,21 +200,23 @@ void AccountInfo::LoadResult(Field* fields)
|
||||
Utf8ToUpperOnlyLatin(Login);
|
||||
}
|
||||
|
||||
AuthSession::AuthSession(tcp::socket&& socket) : Socket(std::move(socket)),
|
||||
_timeout(*underlying_stream().get_executor().target<boost::asio::io_context::executor_type>()),
|
||||
AuthSession::AuthSession(Trinity::Net::IoContextTcpSocket&& socket) : Socket(std::move(socket)),
|
||||
_timeout(underlying_stream().get_executor()),
|
||||
_status(STATUS_CHALLENGE), _locale(LOCALE_enUS), _os(0), _build(0), _expversion(0), _timezoneOffset(0min)
|
||||
{
|
||||
}
|
||||
|
||||
void AuthSession::Start()
|
||||
{
|
||||
std::string ip_address = GetRemoteIpAddress().to_string();
|
||||
TC_LOG_TRACE("session", "Accepted connection from {}", ip_address);
|
||||
// build initializer chain
|
||||
std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
|
||||
{ {
|
||||
std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<AuthSession>>(this),
|
||||
std::make_shared<Trinity::Net::ReadConnectionInitializer<AuthSession>>(this),
|
||||
} };
|
||||
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
|
||||
stmt->setString(0, ip_address);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::CheckIpCallback, this, std::placeholders::_1)));
|
||||
Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
|
||||
SetTimeout();
|
||||
}
|
||||
|
||||
bool AuthSession::Update()
|
||||
@@ -226,36 +229,7 @@ bool AuthSession::Update()
|
||||
return true;
|
||||
}
|
||||
|
||||
void AuthSession::CheckIpCallback(PreparedQueryResult result)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
bool banned = false;
|
||||
do
|
||||
{
|
||||
Field* fields = result->Fetch();
|
||||
if (fields[0].GetUInt64() != 0)
|
||||
banned = true;
|
||||
|
||||
} while (result->NextRow());
|
||||
|
||||
if (banned)
|
||||
{
|
||||
ByteBuffer pkt;
|
||||
pkt << uint8(AUTH_LOGON_CHALLENGE);
|
||||
pkt << uint8(0x00);
|
||||
pkt << uint8(WOW_FAIL_BANNED);
|
||||
SendPacket(pkt);
|
||||
TC_LOG_DEBUG("session", "[AuthSession::CheckIpCallback] Banned ip '{}:{}' tries to login!", GetRemoteIpAddress().to_string(), GetRemotePort());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
SetTimeout();
|
||||
}
|
||||
|
||||
void AuthSession::ReadHandler()
|
||||
Trinity::Net::SocketReadCallbackResult AuthSession::ReadHandler()
|
||||
{
|
||||
MessageBuffer& packet = GetReadBuffer();
|
||||
while (packet.GetActiveSize())
|
||||
@@ -265,7 +239,7 @@ void AuthSession::ReadHandler()
|
||||
if (!itr || _status != itr->status)
|
||||
{
|
||||
CloseSocket();
|
||||
return;
|
||||
return Trinity::Net::SocketReadCallbackResult::Stop;
|
||||
}
|
||||
|
||||
std::size_t size = itr->packetSize;
|
||||
@@ -279,7 +253,7 @@ void AuthSession::ReadHandler()
|
||||
if (size > MAX_ACCEPTED_CHALLENGE_SIZE)
|
||||
{
|
||||
CloseSocket();
|
||||
return;
|
||||
return Trinity::Net::SocketReadCallbackResult::Stop;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,14 +263,19 @@ void AuthSession::ReadHandler()
|
||||
if (!itr->handler(this))
|
||||
{
|
||||
CloseSocket();
|
||||
return;
|
||||
return Trinity::Net::SocketReadCallbackResult::Stop;
|
||||
}
|
||||
|
||||
packet.ReadCompleted(size);
|
||||
SetTimeout();
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
return Trinity::Net::SocketReadCallbackResult::KeepReading;
|
||||
}
|
||||
|
||||
void AuthSession::QueueQuery(QueryCallback&& queryCallback)
|
||||
{
|
||||
_queryProcessor.AddCallback(std::move(queryCallback));
|
||||
}
|
||||
|
||||
void AuthSession::SendPacket(ByteBuffer& packet)
|
||||
@@ -334,7 +313,7 @@ bool AuthSession::HandleLogonChallenge()
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_LOGONCHALLENGE);
|
||||
stmt->setStringView(0, login);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt)
|
||||
.WithPreparedCallback([this](PreparedQueryResult result) { LogonChallengeCallback(std::move(result)); }));
|
||||
return true;
|
||||
}
|
||||
@@ -546,7 +525,7 @@ bool AuthSession::HandleLogonProof()
|
||||
stmt->setStringView(3, ClientBuild::ToCharArray(_os).data());
|
||||
stmt->setInt16(4, _timezoneOffset.count());
|
||||
stmt->setString(5, _accountInfo.Login);
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt)
|
||||
.WithPreparedCallback([this, M2 = Trinity::Crypto::SRP6::GetSessionVerifier(logonProof->A, logonProof->clientM, _sessionKey)](PreparedQueryResult const&)
|
||||
{
|
||||
// Finish SRP6 and send the final result to the client
|
||||
@@ -665,7 +644,7 @@ bool AuthSession::HandleReconnectChallenge()
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_RECONNECTCHALLENGE);
|
||||
stmt->setStringView(0, login);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt)
|
||||
.WithPreparedCallback([this](PreparedQueryResult result) { ReconnectChallengeCallback(std::move(result)); }));
|
||||
return true;
|
||||
}
|
||||
@@ -748,7 +727,7 @@ bool AuthSession::HandleRealmList()
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_REALM_CHARACTER_COUNTS);
|
||||
stmt->setUInt32(0, _accountInfo.Id);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::RealmListCallback, this, std::placeholders::_1)));
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::RealmListCallback, this, std::placeholders::_1)));
|
||||
_status = STATUS_WAITING_FOR_REALM_LIST;
|
||||
return true;
|
||||
}
|
||||
@@ -922,7 +901,7 @@ void AuthSession::SetTimeout()
|
||||
|
||||
_timeout.async_wait([selfRef = weak_from_this()](boost::system::error_code const& error)
|
||||
{
|
||||
std::shared_ptr<AuthSession> self = selfRef.lock();
|
||||
std::shared_ptr<AuthSession> self = static_pointer_cast<AuthSession>(selfRef.lock());
|
||||
if (!self)
|
||||
return;
|
||||
|
||||
|
||||
@@ -61,20 +61,21 @@ struct AccountInfo
|
||||
AccountTypes SecurityLevel = SEC_PLAYER;
|
||||
};
|
||||
|
||||
class AuthSession : public Socket<AuthSession>
|
||||
class AuthSession final : public Trinity::Net::Socket<>
|
||||
{
|
||||
typedef Socket<AuthSession> AuthSocket;
|
||||
using AuthSocket = Socket;
|
||||
|
||||
public:
|
||||
AuthSession(tcp::socket&& socket);
|
||||
AuthSession(Trinity::Net::IoContextTcpSocket&& socket);
|
||||
|
||||
void Start() override;
|
||||
bool Update() override;
|
||||
|
||||
void SendPacket(ByteBuffer& packet);
|
||||
|
||||
protected:
|
||||
void ReadHandler() override;
|
||||
Trinity::Net::SocketReadCallbackResult ReadHandler() override;
|
||||
|
||||
void QueueQuery(QueryCallback&& queryCallback);
|
||||
|
||||
private:
|
||||
friend AuthHandlerTable;
|
||||
@@ -87,7 +88,6 @@ private:
|
||||
bool HandleXferResume();
|
||||
bool HandleXferCancel();
|
||||
|
||||
void CheckIpCallback(PreparedQueryResult result);
|
||||
void LogonChallengeCallback(PreparedQueryResult result);
|
||||
void ReconnectChallengeCallback(PreparedQueryResult result);
|
||||
void RealmListCallback(PreparedQueryResult result);
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
#include "SocketMgr.h"
|
||||
#include "AuthSession.h"
|
||||
|
||||
class AuthSocketMgr : public SocketMgr<AuthSession>
|
||||
class AuthSocketMgr : public Trinity::Net::SocketMgr<AuthSession>
|
||||
{
|
||||
typedef SocketMgr<AuthSession> BaseSocketMgr;
|
||||
|
||||
@@ -37,19 +37,17 @@ public:
|
||||
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
|
||||
return false;
|
||||
|
||||
_acceptor->AsyncAcceptWithCallback<&AuthSocketMgr::OnSocketAccept>();
|
||||
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
OnSocketOpen(std::move(sock), threadIndex);
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
NetworkThread<AuthSession>* CreateThreads() const override
|
||||
Trinity::Net::NetworkThread<AuthSession>* CreateThreads() const override
|
||||
{
|
||||
return new NetworkThread<AuthSession>[1];
|
||||
}
|
||||
|
||||
static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
Instance().OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
|
||||
return new Trinity::Net::NetworkThread<AuthSession>[1];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1274,14 +1274,14 @@ void ScriptMgr::OnNetworkStop()
|
||||
FOREACH_SCRIPT(ServerScript)->OnNetworkStop();
|
||||
}
|
||||
|
||||
void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> socket)
|
||||
void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> const& socket)
|
||||
{
|
||||
ASSERT(socket);
|
||||
|
||||
FOREACH_SCRIPT(ServerScript)->OnSocketOpen(socket);
|
||||
}
|
||||
|
||||
void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> socket)
|
||||
void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> const& socket)
|
||||
{
|
||||
ASSERT(socket);
|
||||
|
||||
|
||||
@@ -884,8 +884,8 @@ class TC_GAME_API ScriptMgr
|
||||
|
||||
void OnNetworkStart();
|
||||
void OnNetworkStop();
|
||||
void OnSocketOpen(std::shared_ptr<WorldSocket> socket);
|
||||
void OnSocketClose(std::shared_ptr<WorldSocket> socket);
|
||||
void OnSocketOpen(std::shared_ptr<WorldSocket> const& socket);
|
||||
void OnSocketClose(std::shared_ptr<WorldSocket> const& socket);
|
||||
void OnPacketReceive(WorldSession* session, WorldPacket const& packet);
|
||||
void OnPacketSend(WorldSession* session, WorldPacket const& packet);
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
#include "CryptoHash.h"
|
||||
#include "CryptoRandom.h"
|
||||
#include "IPLocation.h"
|
||||
#include "Opcodes.h"
|
||||
#include "IpBanCheckConnectionInitializer.h"
|
||||
#include "PacketLog.h"
|
||||
#include "Random.h"
|
||||
#include "RBAC.h"
|
||||
@@ -33,10 +33,7 @@
|
||||
#include "WorldSession.h"
|
||||
#include <memory>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
|
||||
WorldSocket::WorldSocket(tcp::socket&& socket)
|
||||
: Socket(std::move(socket)), _OverSpeedPings(0), _worldSession(nullptr), _authed(false), _sendBufferSize(4096)
|
||||
WorldSocket::WorldSocket(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket)), _OverSpeedPings(0), _worldSession(nullptr), _authed(false), _sendBufferSize(4096)
|
||||
{
|
||||
Trinity::Crypto::GetRandomBytes(_authSeed);
|
||||
_headerBuffer.Resize(sizeof(ClientPktHeader));
|
||||
@@ -44,39 +41,33 @@ WorldSocket::WorldSocket(tcp::socket&& socket)
|
||||
|
||||
WorldSocket::~WorldSocket() = default;
|
||||
|
||||
void WorldSocket::Start()
|
||||
struct WorldSocketProtocolInitializer final : Trinity::Net::SocketConnectionInitializer
|
||||
{
|
||||
std::string ip_address = GetRemoteIpAddress().to_string();
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
|
||||
stmt->setString(0, ip_address);
|
||||
explicit WorldSocketProtocolInitializer(WorldSocket* socket) : _socket(socket) { }
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1)));
|
||||
}
|
||||
|
||||
void WorldSocket::CheckIpCallback(PreparedQueryResult result)
|
||||
{
|
||||
if (result)
|
||||
void Start() override
|
||||
{
|
||||
bool banned = false;
|
||||
do
|
||||
{
|
||||
Field* fields = result->Fetch();
|
||||
if (fields[0].GetUInt64() != 0)
|
||||
banned = true;
|
||||
_socket->SendAuthSession();
|
||||
|
||||
} while (result->NextRow());
|
||||
|
||||
if (banned)
|
||||
{
|
||||
SendAuthResponseError(AUTH_REJECT);
|
||||
TC_LOG_ERROR("network", "WorldSocket::CheckIpCallback: Sent Auth Response (IP {} banned).", GetRemoteIpAddress().to_string());
|
||||
DelayedCloseSocket();
|
||||
return;
|
||||
}
|
||||
if (this->next)
|
||||
this->next->Start();
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
HandleSendAuthSession();
|
||||
private:
|
||||
WorldSocket* _socket;
|
||||
};
|
||||
|
||||
void WorldSocket::Start()
|
||||
{
|
||||
// build initializer chain
|
||||
std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
|
||||
{ {
|
||||
std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<WorldSocket>>(this),
|
||||
std::make_shared<WorldSocketProtocolInitializer>(this),
|
||||
std::make_shared<Trinity::Net::ReadConnectionInitializer<WorldSocket>>(this),
|
||||
} };
|
||||
|
||||
Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
|
||||
}
|
||||
|
||||
bool WorldSocket::Update()
|
||||
@@ -129,7 +120,7 @@ bool WorldSocket::Update()
|
||||
return true;
|
||||
}
|
||||
|
||||
void WorldSocket::HandleSendAuthSession()
|
||||
void WorldSocket::SendAuthSession()
|
||||
{
|
||||
WorldPacket packet(SMSG_AUTH_CHALLENGE, 40);
|
||||
packet << uint32(1); // 1...31
|
||||
@@ -148,11 +139,8 @@ void WorldSocket::OnClose()
|
||||
}
|
||||
}
|
||||
|
||||
void WorldSocket::ReadHandler()
|
||||
Trinity::Net::SocketReadCallbackResult WorldSocket::ReadHandler()
|
||||
{
|
||||
if (!IsOpen())
|
||||
return;
|
||||
|
||||
MessageBuffer& packet = GetReadBuffer();
|
||||
while (packet.GetActiveSize() > 0)
|
||||
{
|
||||
@@ -174,7 +162,7 @@ void WorldSocket::ReadHandler()
|
||||
if (!ReadHeaderHandler())
|
||||
{
|
||||
CloseSocket();
|
||||
return;
|
||||
return Trinity::Net::SocketReadCallbackResult::Stop;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -202,11 +190,16 @@ void WorldSocket::ReadHandler()
|
||||
if (result != ReadDataHandlerResult::WaitingForQuery)
|
||||
CloseSocket();
|
||||
|
||||
return;
|
||||
return Trinity::Net::SocketReadCallbackResult::Stop;
|
||||
}
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
return Trinity::Net::SocketReadCallbackResult::KeepReading;
|
||||
}
|
||||
|
||||
void WorldSocket::QueueQuery(QueryCallback&& queryCallback)
|
||||
{
|
||||
_queryProcessor.AddCallback(std::move(queryCallback));
|
||||
}
|
||||
|
||||
bool WorldSocket::ReadHeaderHandler()
|
||||
@@ -398,14 +391,13 @@ WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler()
|
||||
|
||||
void WorldSocket::LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const
|
||||
{
|
||||
if (!guard)
|
||||
if (!guard || !_worldSession)
|
||||
{
|
||||
TC_LOG_TRACE("network.opcode", "C->S: {} {}", GetRemoteIpAddress().to_string(), GetOpcodeNameForLogging(opcode));
|
||||
}
|
||||
else
|
||||
{
|
||||
TC_LOG_TRACE("network.opcode", "C->S: {} {}", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()),
|
||||
GetOpcodeNameForLogging(opcode));
|
||||
TC_LOG_TRACE("network.opcode", "C->S: {} {}", _worldSession->GetPlayerInfo(), GetOpcodeNameForLogging(opcode));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -449,7 +441,7 @@ void WorldSocket::HandleAuthSession(WorldPacket& recvPacket)
|
||||
stmt->setInt32(0, int32(realm.Id.Realm));
|
||||
stmt->setString(1, authSession->Account);
|
||||
|
||||
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable
|
||||
QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable
|
||||
{
|
||||
HandleAuthSessionCallback(std::move(authSession), std::move(result));
|
||||
}));
|
||||
@@ -613,16 +605,18 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<AuthSession> authSes
|
||||
sScriptMgr->OnAccountLogin(account.Id);
|
||||
|
||||
_authed = true;
|
||||
_worldSession = new WorldSession(account.Id, std::move(authSession->Account), shared_from_this(), account.Security,
|
||||
account.Expansion, mutetime, account.TimezoneOffset, account.Locale, account.Recruiter, account.IsRectuiter);
|
||||
_worldSession = new WorldSession(account.Id, std::move(authSession->Account),
|
||||
static_pointer_cast<WorldSocket>(shared_from_this()), account.Security, account.Expansion, mutetime,
|
||||
account.TimezoneOffset, account.Locale,
|
||||
account.Recruiter, account.IsRectuiter);
|
||||
_worldSession->ReadAddonsInfo(authSession->AddonInfo);
|
||||
|
||||
// Initialize Warden system only if it is enabled by config
|
||||
if (wardenActive)
|
||||
_worldSession->InitWarden(account.SessionKey, account.OS);
|
||||
|
||||
_queryProcessor.AddCallback(_worldSession->LoadPermissionsAsync().WithPreparedCallback(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1)));
|
||||
AsyncRead();
|
||||
QueueQuery(_worldSession->LoadPermissionsAsync().WithPreparedCallback(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1)));
|
||||
AsyncRead(Trinity::Net::InvokeReadHandlerCallback<WorldSocket>{ .Socket = this });
|
||||
}
|
||||
|
||||
void WorldSocket::LoadSessionPermissionsCallback(PreparedQueryResult result)
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef __WORLDSOCKET_H__
|
||||
#define __WORLDSOCKET_H__
|
||||
#ifndef TRINITYCORE_WORLD_SOCKET_H
|
||||
#define TRINITYCORE_WORLD_SOCKET_H
|
||||
|
||||
#include "Common.h"
|
||||
#include "ServerPktHeader.h"
|
||||
@@ -64,16 +64,18 @@ struct ClientPktHeader
|
||||
|
||||
struct AuthSession;
|
||||
|
||||
class TC_GAME_API WorldSocket : public Socket<WorldSocket>
|
||||
class TC_GAME_API WorldSocket final : public Trinity::Net::Socket<>
|
||||
{
|
||||
typedef Socket<WorldSocket> BaseSocket;
|
||||
using BaseSocket = Socket;
|
||||
|
||||
public:
|
||||
WorldSocket(tcp::socket&& socket);
|
||||
WorldSocket(Trinity::Net::IoContextTcpSocket&& socket);
|
||||
~WorldSocket();
|
||||
|
||||
WorldSocket(WorldSocket const& right) = delete;
|
||||
WorldSocket(WorldSocket&& right) = delete;
|
||||
WorldSocket& operator=(WorldSocket const& right) = delete;
|
||||
WorldSocket& operator=(WorldSocket&& right) = delete;
|
||||
|
||||
void Start() override;
|
||||
bool Update() override;
|
||||
@@ -82,9 +84,14 @@ public:
|
||||
|
||||
void SetSendBufferSize(std::size_t sendBufferSize) { _sendBufferSize = sendBufferSize; }
|
||||
|
||||
protected:
|
||||
void OnClose() override;
|
||||
void ReadHandler() override;
|
||||
Trinity::Net::SocketReadCallbackResult ReadHandler() override;
|
||||
|
||||
void QueueQuery(QueryCallback&& queryCallback);
|
||||
|
||||
void SendAuthSession();
|
||||
|
||||
protected:
|
||||
bool ReadHeaderHandler();
|
||||
|
||||
enum class ReadDataHandlerResult
|
||||
@@ -97,14 +104,11 @@ protected:
|
||||
ReadDataHandlerResult ReadDataHandler();
|
||||
|
||||
private:
|
||||
void CheckIpCallback(PreparedQueryResult result);
|
||||
|
||||
/// writes network.opcode log
|
||||
/// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock
|
||||
void LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const;
|
||||
/// sends and logs network.opcode without accessing WorldSession
|
||||
void SendPacketAndLogOpcode(WorldPacket const& packet);
|
||||
void HandleSendAuthSession();
|
||||
void HandleAuthSession(WorldPacket& recvPacket);
|
||||
void HandleAuthSessionCallback(std::shared_ptr<AuthSession> authSession, PreparedQueryResult result);
|
||||
void LoadSessionPermissionsCallback(PreparedQueryResult result);
|
||||
|
||||
@@ -23,21 +23,16 @@
|
||||
|
||||
#include <boost/system/error_code.hpp>
|
||||
|
||||
static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
|
||||
}
|
||||
|
||||
class WorldSocketThread : public NetworkThread<WorldSocket>
|
||||
class WorldSocketThread : public Trinity::Net::NetworkThread<WorldSocket>
|
||||
{
|
||||
public:
|
||||
void SocketAdded(std::shared_ptr<WorldSocket> sock) override
|
||||
void SocketAdded(std::shared_ptr<WorldSocket> const& sock) override
|
||||
{
|
||||
sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize());
|
||||
sScriptMgr->OnSocketOpen(sock);
|
||||
}
|
||||
|
||||
void SocketRemoved(std::shared_ptr<WorldSocket> sock) override
|
||||
void SocketRemoved(std::shared_ptr<WorldSocket>const& sock) override
|
||||
{
|
||||
sScriptMgr->OnSocketClose(sock);
|
||||
}
|
||||
@@ -74,7 +69,10 @@ bool WorldSocketMgr::StartWorldNetwork(Trinity::Asio::IoContext& ioContext, std:
|
||||
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
|
||||
return false;
|
||||
|
||||
_acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
|
||||
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
OnSocketOpen(std::move(sock), threadIndex);
|
||||
});
|
||||
|
||||
sScriptMgr->OnNetworkStart();
|
||||
return true;
|
||||
@@ -87,7 +85,7 @@ void WorldSocketMgr::StopNetwork()
|
||||
sScriptMgr->OnNetworkStop();
|
||||
}
|
||||
|
||||
void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
|
||||
void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
// set some options here
|
||||
if (_socketSystemSendBufferSize >= 0)
|
||||
@@ -115,10 +113,10 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
|
||||
|
||||
//sock->m_OutBufferSize = static_cast<size_t> (m_SockOutUBuff);
|
||||
|
||||
BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
|
||||
BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex);
|
||||
}
|
||||
|
||||
NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
|
||||
Trinity::Net::NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
|
||||
{
|
||||
return new WorldSocketThread[GetNetworkThreadCount()];
|
||||
}
|
||||
|
||||
@@ -15,21 +15,15 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
/** \addtogroup u2w User to World Communication
|
||||
* @{
|
||||
* \file WorldSocketMgr.h
|
||||
* \author Derex <derex101@gmail.com>
|
||||
*/
|
||||
|
||||
#ifndef __WORLDSOCKETMGR_H
|
||||
#define __WORLDSOCKETMGR_H
|
||||
#ifndef TRINITYCORE_WORLD_SOCKET_MGR_H
|
||||
#define TRINITYCORE_WORLD_SOCKET_MGR_H
|
||||
|
||||
#include "SocketMgr.h"
|
||||
|
||||
class WorldSocket;
|
||||
|
||||
/// Manages all sockets connected to peers and network threads
|
||||
class TC_GAME_API WorldSocketMgr : public SocketMgr<WorldSocket>
|
||||
class TC_GAME_API WorldSocketMgr : public Trinity::Net::SocketMgr<WorldSocket>
|
||||
{
|
||||
typedef SocketMgr<WorldSocket> BaseSocketMgr;
|
||||
|
||||
@@ -42,14 +36,14 @@ public:
|
||||
/// Stops all network threads, It will wait for all running threads .
|
||||
void StopNetwork() override;
|
||||
|
||||
void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) override;
|
||||
void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override;
|
||||
|
||||
std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; }
|
||||
|
||||
protected:
|
||||
WorldSocketMgr();
|
||||
|
||||
NetworkThread<WorldSocket>* CreateThreads() const override;
|
||||
Trinity::Net::NetworkThread<WorldSocket>* CreateThreads() const override;
|
||||
|
||||
private:
|
||||
int32 _socketSystemSendBufferSize;
|
||||
@@ -60,4 +54,3 @@ private:
|
||||
#define sWorldSocketMgr WorldSocketMgr::Instance()
|
||||
|
||||
#endif
|
||||
/// @}
|
||||
|
||||
@@ -15,12 +15,13 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef __ASYNCACCEPT_H_
|
||||
#define __ASYNCACCEPT_H_
|
||||
#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H
|
||||
#define TRINITYCORE_ASYNC_ACCEPTOR_H
|
||||
|
||||
#include "IoContext.h"
|
||||
#include "IpAddress.h"
|
||||
#include "Log.h"
|
||||
#include "Socket.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
@@ -29,27 +30,28 @@ using boost::asio::ip::tcp;
|
||||
|
||||
#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
template <typename Callable>
|
||||
concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>;
|
||||
|
||||
class AsyncAcceptor
|
||||
{
|
||||
public:
|
||||
typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex);
|
||||
|
||||
AsyncAcceptor(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
|
||||
_acceptor(ioContext), _endpoint(Trinity::Net::make_address(bindIp), port),
|
||||
AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
|
||||
_acceptor(ioContext), _endpoint(make_address(bindIp), port),
|
||||
_socket(ioContext), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this))
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void AsyncAccept();
|
||||
|
||||
template<AcceptCallback acceptCallback>
|
||||
void AsyncAcceptWithCallback()
|
||||
template <AcceptCallback Callback>
|
||||
void AsyncAccept(Callback&& acceptCallback)
|
||||
{
|
||||
tcp::socket* socket;
|
||||
uint32 threadIndex;
|
||||
std::tie(socket, threadIndex) = _socketFactory();
|
||||
_acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
|
||||
auto [tmpSocket, tmpThreadIndex] = _socketFactory();
|
||||
// TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures)
|
||||
IoContextTcpSocket* socket = tmpSocket;
|
||||
uint32 threadIndex = tmpThreadIndex;
|
||||
_acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable
|
||||
{
|
||||
if (!error)
|
||||
{
|
||||
@@ -66,7 +68,7 @@ public:
|
||||
}
|
||||
|
||||
if (!_closed)
|
||||
this->AsyncAcceptWithCallback<acceptCallback>();
|
||||
this->AsyncAccept(std::move(acceptCallback));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -115,40 +117,17 @@ public:
|
||||
_acceptor.close(err);
|
||||
}
|
||||
|
||||
void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; }
|
||||
void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
|
||||
|
||||
private:
|
||||
std::pair<tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
|
||||
std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
|
||||
|
||||
tcp::acceptor _acceptor;
|
||||
tcp::endpoint _endpoint;
|
||||
tcp::socket _socket;
|
||||
boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
|
||||
boost::asio::ip::tcp::endpoint _endpoint;
|
||||
IoContextTcpSocket _socket;
|
||||
std::atomic<bool> _closed;
|
||||
std::function<std::pair<tcp::socket*, uint32>()> _socketFactory;
|
||||
std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
void AsyncAcceptor::AsyncAccept()
|
||||
{
|
||||
_acceptor.async_accept(_socket, [this](boost::system::error_code error)
|
||||
{
|
||||
if (!error)
|
||||
{
|
||||
try
|
||||
{
|
||||
// this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class
|
||||
std::make_shared<T>(std::move(this->_socket))->Start();
|
||||
}
|
||||
catch (boost::system::system_error const& err)
|
||||
{
|
||||
TC_LOG_INFO("network", "Failed to retrieve client's remote address {}", err.what());
|
||||
}
|
||||
}
|
||||
|
||||
// lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
|
||||
if (!_closed)
|
||||
this->AsyncAccept<T>();
|
||||
});
|
||||
}
|
||||
|
||||
#endif /* __ASYNCACCEPT_H_ */
|
||||
#endif // TRINITYCORE_ASYNC_ACCEPTOR_H
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License as published by the
|
||||
* Free Software Foundation; either version 2 of the License, or (at your
|
||||
* option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
* more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License along
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#include "IpBanCheckConnectionInitializer.h"
|
||||
#include "DatabaseEnv.h"
|
||||
|
||||
QueryCallback Trinity::Net::IpBanCheckHelpers::AsyncQuery(std::string_view ipAddress)
|
||||
{
|
||||
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
|
||||
stmt->setStringView(0, ipAddress);
|
||||
return LoginDatabase.AsyncQuery(stmt);
|
||||
}
|
||||
|
||||
bool Trinity::Net::IpBanCheckHelpers::IsBanned(PreparedQueryResult const& result)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
do
|
||||
{
|
||||
Field* fields = result->Fetch();
|
||||
if (fields[0].GetUInt64() != 0)
|
||||
return true;
|
||||
|
||||
} while (result->NextRow());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License as published by the
|
||||
* Free Software Foundation; either version 2 of the License, or (at your
|
||||
* option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
* more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License along
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
|
||||
#define TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
|
||||
|
||||
#include "DatabaseEnvFwd.h"
|
||||
#include "Log.h"
|
||||
#include "QueryCallback.h"
|
||||
#include "SocketConnectionInitializer.h"
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
namespace IpBanCheckHelpers
|
||||
{
|
||||
TC_SHARED_API QueryCallback AsyncQuery(std::string_view ipAddress);
|
||||
TC_SHARED_API bool IsBanned(PreparedQueryResult const& result);
|
||||
}
|
||||
|
||||
template <typename SocketImpl>
|
||||
struct IpBanCheckConnectionInitializer final : SocketConnectionInitializer
|
||||
{
|
||||
explicit IpBanCheckConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
|
||||
|
||||
void Start() override
|
||||
{
|
||||
_socket->QueueQuery(IpBanCheckHelpers::AsyncQuery(_socket->GetRemoteIpAddress().to_string()).WithPreparedCallback([socketRef = _socket->weak_from_this(), self = this->shared_from_this()](PreparedQueryResult const& result)
|
||||
{
|
||||
std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
|
||||
if (!socket)
|
||||
return;
|
||||
|
||||
if (IpBanCheckHelpers::IsBanned(result))
|
||||
{
|
||||
TC_LOG_ERROR("network", "IpBanCheckConnectionInitializer: IP {} is banned.", socket->GetRemoteIpAddress().to_string());
|
||||
socket->DelayedCloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
if (self->next)
|
||||
self->next->Start();
|
||||
}));
|
||||
}
|
||||
|
||||
private:
|
||||
SocketImpl* _socket;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
|
||||
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License as published by the
|
||||
* Free Software Foundation; either version 2 of the License, or (at your
|
||||
* option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
* more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License along
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
|
||||
#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
|
||||
|
||||
#include <memory>
|
||||
#include <span>
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer>
|
||||
{
|
||||
SocketConnectionInitializer() = default;
|
||||
|
||||
SocketConnectionInitializer(SocketConnectionInitializer const&) = delete;
|
||||
SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default;
|
||||
SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete;
|
||||
SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default;
|
||||
|
||||
virtual ~SocketConnectionInitializer() = default;
|
||||
|
||||
virtual void Start() = 0;
|
||||
|
||||
std::shared_ptr<SocketConnectionInitializer> next;
|
||||
|
||||
static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers)
|
||||
{
|
||||
for (std::size_t i = initializers.size(); i > 1; --i)
|
||||
initializers[i - 2]->next.swap(initializers[i - 1]);
|
||||
|
||||
return initializers[0];
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
|
||||
@@ -15,14 +15,16 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef NetworkThread_h__
|
||||
#define NetworkThread_h__
|
||||
#ifndef TRINITYCORE_NETWORK_THREAD_H
|
||||
#define TRINITYCORE_NETWORK_THREAD_H
|
||||
|
||||
#include "Define.h"
|
||||
#include "Containers.h"
|
||||
#include "DeadlineTimer.h"
|
||||
#include "Define.h"
|
||||
#include "Errors.h"
|
||||
#include "IoContext.h"
|
||||
#include "Log.h"
|
||||
#include "Socket.h"
|
||||
#include "Timer.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <atomic>
|
||||
@@ -32,8 +34,8 @@
|
||||
#include <set>
|
||||
#include <thread>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
template<class SocketType>
|
||||
class NetworkThread
|
||||
{
|
||||
@@ -43,14 +45,16 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
NetworkThread(NetworkThread const&) = delete;
|
||||
NetworkThread(NetworkThread&&) = delete;
|
||||
NetworkThread& operator=(NetworkThread const&) = delete;
|
||||
NetworkThread& operator=(NetworkThread&&) = delete;
|
||||
|
||||
virtual ~NetworkThread()
|
||||
{
|
||||
Stop();
|
||||
if (_thread)
|
||||
{
|
||||
Wait();
|
||||
delete _thread;
|
||||
}
|
||||
}
|
||||
|
||||
void Stop()
|
||||
@@ -64,7 +68,7 @@ public:
|
||||
if (_thread)
|
||||
return false;
|
||||
|
||||
_thread = new std::thread(&NetworkThread::Run, this);
|
||||
_thread = std::make_unique<std::thread>(&NetworkThread::Run, this);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -73,7 +77,6 @@ public:
|
||||
ASSERT(_thread);
|
||||
|
||||
_thread->join();
|
||||
delete _thread;
|
||||
_thread = nullptr;
|
||||
}
|
||||
|
||||
@@ -87,15 +90,14 @@ public:
|
||||
std::lock_guard<std::mutex> lock(_newSocketsLock);
|
||||
|
||||
++_connections;
|
||||
_newSockets.push_back(sock);
|
||||
SocketAdded(sock);
|
||||
SocketAdded(_newSockets.emplace_back(std::move(sock)));
|
||||
}
|
||||
|
||||
tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
|
||||
Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
|
||||
|
||||
protected:
|
||||
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
|
||||
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
|
||||
virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
|
||||
virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { }
|
||||
|
||||
void AddNewSockets()
|
||||
{
|
||||
@@ -104,7 +106,7 @@ protected:
|
||||
if (_newSockets.empty())
|
||||
return;
|
||||
|
||||
for (std::shared_ptr<SocketType> sock : _newSockets)
|
||||
for (std::shared_ptr<SocketType>& sock : _newSockets)
|
||||
{
|
||||
if (!sock->IsOpen())
|
||||
{
|
||||
@@ -112,7 +114,7 @@ protected:
|
||||
--_connections;
|
||||
}
|
||||
else
|
||||
_sockets.push_back(sock);
|
||||
_sockets.emplace_back(std::move(sock));
|
||||
}
|
||||
|
||||
_newSockets.clear();
|
||||
@@ -141,7 +143,7 @@ protected:
|
||||
|
||||
AddNewSockets();
|
||||
|
||||
_sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock)
|
||||
Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock)
|
||||
{
|
||||
if (!sock->Update())
|
||||
{
|
||||
@@ -155,7 +157,7 @@ protected:
|
||||
}
|
||||
|
||||
return false;
|
||||
}), _sockets.end());
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -164,7 +166,7 @@ private:
|
||||
std::atomic<int32> _connections;
|
||||
std::atomic<bool> _stopped;
|
||||
|
||||
std::thread* _thread;
|
||||
std::unique_ptr<std::thread> _thread;
|
||||
|
||||
SocketContainer _sockets;
|
||||
|
||||
@@ -172,8 +174,9 @@ private:
|
||||
SocketContainer _newSockets;
|
||||
|
||||
Trinity::Asio::IoContext _ioContext;
|
||||
tcp::socket _acceptSocket;
|
||||
Trinity::Net::IoContextTcpSocket _acceptSocket;
|
||||
Trinity::Asio::DeadlineTimer _updateTimer;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // NetworkThread_h__
|
||||
#endif // TRINITYCORE_NETWORK_THREAD_H
|
||||
|
||||
@@ -15,17 +15,19 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef __SOCKET_H__
|
||||
#define __SOCKET_H__
|
||||
#ifndef TRINITYCORE_SOCKET_H
|
||||
#define TRINITYCORE_SOCKET_H
|
||||
|
||||
#include "MessageBuffer.h"
|
||||
#include "Concepts.h"
|
||||
#include "Log.h"
|
||||
#include <atomic>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
#include "MessageBuffer.h"
|
||||
#include "SocketConnectionInitializer.h"
|
||||
#include <boost/asio/io_context.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <type_traits>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
|
||||
@@ -34,32 +36,111 @@ using boost::asio::ip::tcp;
|
||||
#define TC_SOCKET_USE_IOCP
|
||||
#endif
|
||||
|
||||
template<class T>
|
||||
class Socket : public std::enable_shared_from_this<T>
|
||||
namespace Trinity::Net
|
||||
{
|
||||
using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
|
||||
|
||||
enum class SocketReadCallbackResult
|
||||
{
|
||||
KeepReading,
|
||||
Stop
|
||||
};
|
||||
|
||||
template <typename Callable>
|
||||
concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
|
||||
|
||||
template <typename SocketType>
|
||||
struct InvokeReadHandlerCallback
|
||||
{
|
||||
SocketReadCallbackResult operator()() const
|
||||
{
|
||||
return this->Socket->ReadHandler();
|
||||
}
|
||||
|
||||
SocketType* Socket;
|
||||
};
|
||||
|
||||
template <typename SocketType>
|
||||
struct ReadConnectionInitializer final : SocketConnectionInitializer
|
||||
{
|
||||
explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { }
|
||||
|
||||
void Start() override
|
||||
{
|
||||
ReadCallback.Socket->AsyncRead(std::move(ReadCallback));
|
||||
|
||||
if (this->next)
|
||||
this->next->Start();
|
||||
}
|
||||
|
||||
InvokeReadHandlerCallback<SocketType> ReadCallback;
|
||||
};
|
||||
|
||||
/**
|
||||
@class Socket
|
||||
|
||||
Base async socket implementation
|
||||
|
||||
@tparam Stream stream type used for operations on socket
|
||||
Stream must implement the following methods:
|
||||
|
||||
void close(boost::system::error_code& error);
|
||||
|
||||
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
|
||||
|
||||
template<typename MutableBufferSequence, typename ReadHandlerType>
|
||||
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler);
|
||||
|
||||
template<typename ConstBufferSequence, typename WriteHandlerType>
|
||||
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler);
|
||||
|
||||
template<typename ConstBufferSequence>
|
||||
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error);
|
||||
|
||||
template<typename WaitHandlerType>
|
||||
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler);
|
||||
|
||||
template<typename SettableSocketOption>
|
||||
void set_option(SettableSocketOption const& option, boost::system::error_code& error);
|
||||
|
||||
tcp::socket::endpoint_type remote_endpoint() const;
|
||||
*/
|
||||
template<class Stream = IoContextTcpSocket>
|
||||
class Socket : public std::enable_shared_from_this<Socket<Stream>>
|
||||
{
|
||||
public:
|
||||
explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
|
||||
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
|
||||
template<typename... Args>
|
||||
explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
|
||||
_remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
|
||||
{
|
||||
_readBuffer.Resize(READ_BLOCK_SIZE);
|
||||
}
|
||||
|
||||
template<typename... Args>
|
||||
explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed)
|
||||
{
|
||||
}
|
||||
|
||||
Socket(Socket const& other) = delete;
|
||||
Socket(Socket&& other) = delete;
|
||||
Socket& operator=(Socket const& other) = delete;
|
||||
Socket& operator=(Socket&& other) = delete;
|
||||
|
||||
virtual ~Socket()
|
||||
{
|
||||
_closed = true;
|
||||
_openState = OpenState_Closed;
|
||||
boost::system::error_code error;
|
||||
_socket.close(error);
|
||||
}
|
||||
|
||||
virtual void Start() = 0;
|
||||
virtual void Start() { }
|
||||
|
||||
virtual bool Update()
|
||||
{
|
||||
if (_closed)
|
||||
if (_openState == OpenState_Closed)
|
||||
return false;
|
||||
|
||||
#ifndef TC_SOCKET_USE_IOCP
|
||||
if (_isWritingAsync || (_writeQueue.empty() && !_closing))
|
||||
if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open))
|
||||
return true;
|
||||
|
||||
for (; HandleQueue();)
|
||||
@@ -69,7 +150,7 @@ public:
|
||||
return true;
|
||||
}
|
||||
|
||||
boost::asio::ip::address GetRemoteIpAddress() const
|
||||
boost::asio::ip::address const& GetRemoteIpAddress() const
|
||||
{
|
||||
return _remoteAddress;
|
||||
}
|
||||
@@ -79,7 +160,8 @@ public:
|
||||
return _remotePort;
|
||||
}
|
||||
|
||||
void AsyncRead()
|
||||
template <SocketReadCallback Callback>
|
||||
void AsyncRead(Callback&& callback)
|
||||
{
|
||||
if (!IsOpen())
|
||||
return;
|
||||
@@ -87,18 +169,12 @@ public:
|
||||
_readBuffer.Normalize();
|
||||
_readBuffer.EnsureFreeSpace();
|
||||
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
|
||||
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
|
||||
}
|
||||
|
||||
void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
|
||||
{
|
||||
if (!IsOpen())
|
||||
return;
|
||||
|
||||
_readBuffer.Normalize();
|
||||
_readBuffer.EnsureFreeSpace();
|
||||
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
|
||||
std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
|
||||
[self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
|
||||
{
|
||||
if (self->ReadHandlerInternal(error, transferredBytes))
|
||||
if (callback() == SocketReadCallbackResult::KeepReading)
|
||||
self->AsyncRead(std::forward<Callback>(callback));
|
||||
});
|
||||
}
|
||||
|
||||
void QueuePacket(MessageBuffer&& buffer)
|
||||
@@ -110,11 +186,11 @@ public:
|
||||
#endif
|
||||
}
|
||||
|
||||
bool IsOpen() const { return !_closed && !_closing; }
|
||||
bool IsOpen() const { return _openState == OpenState_Open; }
|
||||
|
||||
void CloseSocket()
|
||||
{
|
||||
if (_closed.exchange(true))
|
||||
if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
|
||||
return;
|
||||
|
||||
boost::system::error_code shutdownError;
|
||||
@@ -123,13 +199,13 @@ public:
|
||||
TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(),
|
||||
shutdownError.value(), shutdownError.message());
|
||||
|
||||
OnClose();
|
||||
this->OnClose();
|
||||
}
|
||||
|
||||
/// Marks the socket for closing after write buffer becomes empty
|
||||
void DelayedCloseSocket()
|
||||
{
|
||||
if (_closing.exchange(true))
|
||||
if (_openState.fetch_or(OpenState_Closing) != 0)
|
||||
return;
|
||||
|
||||
if (_writeQueue.empty())
|
||||
@@ -138,7 +214,7 @@ public:
|
||||
|
||||
MessageBuffer& GetReadBuffer() { return _readBuffer; }
|
||||
|
||||
tcp::socket& underlying_stream()
|
||||
Stream& underlying_stream()
|
||||
{
|
||||
return _socket;
|
||||
}
|
||||
@@ -146,7 +222,7 @@ public:
|
||||
protected:
|
||||
virtual void OnClose() { }
|
||||
|
||||
virtual void ReadHandler() = 0;
|
||||
virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; }
|
||||
|
||||
bool AsyncProcessQueue()
|
||||
{
|
||||
@@ -157,11 +233,17 @@ protected:
|
||||
|
||||
#ifdef TC_SOCKET_USE_IOCP
|
||||
MessageBuffer& buffer = _writeQueue.front();
|
||||
_socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket<T>::WriteHandler,
|
||||
this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
|
||||
_socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()),
|
||||
[self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes)
|
||||
{
|
||||
self->WriteHandler(error, transferedBytes);
|
||||
});
|
||||
#else
|
||||
_socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket<T>::WriteHandlerWrapper,
|
||||
this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
|
||||
_socket.async_wait(boost::asio::socket_base::wait_type::wait_write,
|
||||
[self = this->shared_from_this()](boost::system::error_code const& error)
|
||||
{
|
||||
self->WriteHandlerWrapper(error);
|
||||
});
|
||||
#endif
|
||||
|
||||
return false;
|
||||
@@ -177,16 +259,16 @@ protected:
|
||||
}
|
||||
|
||||
private:
|
||||
void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes)
|
||||
bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
|
||||
{
|
||||
if (error)
|
||||
{
|
||||
CloseSocket();
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
_readBuffer.WriteCompleted(transferredBytes);
|
||||
ReadHandler();
|
||||
return IsOpen();
|
||||
}
|
||||
|
||||
#ifdef TC_SOCKET_USE_IOCP
|
||||
@@ -202,7 +284,7 @@ private:
|
||||
|
||||
if (!_writeQueue.empty())
|
||||
AsyncProcessQueue();
|
||||
else if (_closing)
|
||||
else if (_openState == OpenState_Closing)
|
||||
CloseSocket();
|
||||
}
|
||||
else
|
||||
@@ -211,7 +293,7 @@ private:
|
||||
|
||||
#else
|
||||
|
||||
void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/)
|
||||
void WriteHandlerWrapper(boost::system::error_code const& /*error*/)
|
||||
{
|
||||
_isWritingAsync = false;
|
||||
HandleQueue();
|
||||
@@ -235,14 +317,14 @@ private:
|
||||
return AsyncProcessQueue();
|
||||
|
||||
_writeQueue.pop();
|
||||
if (_closing && _writeQueue.empty())
|
||||
if (_openState == OpenState_Closing && _writeQueue.empty())
|
||||
CloseSocket();
|
||||
return false;
|
||||
}
|
||||
else if (bytesSent == 0)
|
||||
{
|
||||
_writeQueue.pop();
|
||||
if (_closing && _writeQueue.empty())
|
||||
if (_openState == OpenState_Closing && _writeQueue.empty())
|
||||
CloseSocket();
|
||||
return false;
|
||||
}
|
||||
@@ -253,25 +335,30 @@ private:
|
||||
}
|
||||
|
||||
_writeQueue.pop();
|
||||
if (_closing && _writeQueue.empty())
|
||||
if (_openState == OpenState_Closing && _writeQueue.empty())
|
||||
CloseSocket();
|
||||
return !_writeQueue.empty();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
tcp::socket _socket;
|
||||
Stream _socket;
|
||||
|
||||
boost::asio::ip::address _remoteAddress;
|
||||
uint16 _remotePort;
|
||||
uint16 _remotePort = 0;
|
||||
|
||||
MessageBuffer _readBuffer;
|
||||
MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE);
|
||||
std::queue<MessageBuffer> _writeQueue;
|
||||
|
||||
std::atomic<bool> _closed;
|
||||
std::atomic<bool> _closing;
|
||||
// Socket open state "enum" (not enum to enable integral std::atomic api)
|
||||
static constexpr uint8 OpenState_Open = 0x0;
|
||||
static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data
|
||||
static constexpr uint8 OpenState_Closed = 0x2;
|
||||
|
||||
bool _isWritingAsync;
|
||||
std::atomic<uint8> _openState;
|
||||
|
||||
bool _isWritingAsync = false;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // __SOCKET_H__
|
||||
#endif // TRINITYCORE_SOCKET_H
|
||||
|
||||
@@ -15,34 +15,40 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef SocketMgr_h__
|
||||
#define SocketMgr_h__
|
||||
#ifndef TRINITYCORE_SOCKET_MGR_H
|
||||
#define TRINITYCORE_SOCKET_MGR_H
|
||||
|
||||
#include "AsyncAcceptor.h"
|
||||
#include "Errors.h"
|
||||
#include "NetworkThread.h"
|
||||
#include "Socket.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <memory>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
template<class SocketType>
|
||||
class SocketMgr
|
||||
{
|
||||
public:
|
||||
SocketMgr(SocketMgr const&) = delete;
|
||||
SocketMgr(SocketMgr&&) = delete;
|
||||
SocketMgr& operator=(SocketMgr const&) = delete;
|
||||
SocketMgr& operator=(SocketMgr&&) = delete;
|
||||
|
||||
virtual ~SocketMgr()
|
||||
{
|
||||
ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
|
||||
}
|
||||
|
||||
virtual bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
|
||||
virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
|
||||
{
|
||||
ASSERT(threadCount > 0);
|
||||
|
||||
AsyncAcceptor* acceptor = nullptr;
|
||||
std::unique_ptr<AsyncAcceptor> acceptor = nullptr;
|
||||
try
|
||||
{
|
||||
acceptor = new AsyncAcceptor(ioContext, bindIp, port);
|
||||
acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port);
|
||||
}
|
||||
catch (boost::system::system_error const& err)
|
||||
{
|
||||
@@ -53,13 +59,12 @@ public:
|
||||
if (!acceptor->Bind())
|
||||
{
|
||||
TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
|
||||
delete acceptor;
|
||||
return false;
|
||||
}
|
||||
|
||||
_acceptor = acceptor;
|
||||
_acceptor = std::move(acceptor);
|
||||
_threadCount = threadCount;
|
||||
_threads = CreateThreads();
|
||||
_threads.reset(CreateThreads());
|
||||
|
||||
ASSERT(_threads);
|
||||
|
||||
@@ -75,27 +80,23 @@ public:
|
||||
{
|
||||
_acceptor->Close();
|
||||
|
||||
if (_threadCount != 0)
|
||||
for (int32 i = 0; i < _threadCount; ++i)
|
||||
_threads[i].Stop();
|
||||
for (int32 i = 0; i < _threadCount; ++i)
|
||||
_threads[i].Stop();
|
||||
|
||||
Wait();
|
||||
|
||||
delete _acceptor;
|
||||
_acceptor = nullptr;
|
||||
delete[] _threads;
|
||||
_threads = nullptr;
|
||||
_threadCount = 0;
|
||||
}
|
||||
|
||||
void Wait()
|
||||
{
|
||||
if (_threadCount != 0)
|
||||
for (int32 i = 0; i < _threadCount; ++i)
|
||||
_threads[i].Wait();
|
||||
for (int32 i = 0; i < _threadCount; ++i)
|
||||
_threads[i].Wait();
|
||||
}
|
||||
|
||||
virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
|
||||
virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
|
||||
{
|
||||
try
|
||||
{
|
||||
@@ -123,22 +124,23 @@ public:
|
||||
return min;
|
||||
}
|
||||
|
||||
std::pair<tcp::socket*, uint32> GetSocketForAccept()
|
||||
std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
|
||||
{
|
||||
uint32 threadIndex = SelectThreadWithMinConnections();
|
||||
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
|
||||
}
|
||||
|
||||
protected:
|
||||
SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(0)
|
||||
SocketMgr() : _threadCount(0)
|
||||
{
|
||||
}
|
||||
|
||||
virtual NetworkThread<SocketType>* CreateThreads() const = 0;
|
||||
|
||||
AsyncAcceptor* _acceptor;
|
||||
NetworkThread<SocketType>* _threads;
|
||||
std::unique_ptr<AsyncAcceptor> _acceptor;
|
||||
std::unique_ptr<NetworkThread<SocketType>[]> _threads;
|
||||
int32 _threadCount;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // SocketMgr_h__
|
||||
#endif // TRINITYCORE_SOCKET_MGR_H
|
||||
|
||||
131
src/server/shared/Networking/SslStream.h
Normal file
131
src/server/shared/Networking/SslStream.h
Normal file
@@ -0,0 +1,131 @@
|
||||
/*
|
||||
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License as published by the
|
||||
* Free Software Foundation; either version 2 of the License, or (at your
|
||||
* option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
* more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License along
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef TRINITYCORE_SSL_STREAM_H
|
||||
#define TRINITYCORE_SSL_STREAM_H
|
||||
|
||||
#include "SocketConnectionInitializer.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/ssl/stream.hpp>
|
||||
#include <boost/system/error_code.hpp>
|
||||
|
||||
namespace Trinity::Net
|
||||
{
|
||||
template <typename SocketImpl>
|
||||
struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer
|
||||
{
|
||||
explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
|
||||
|
||||
void Start() override
|
||||
{
|
||||
_socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
|
||||
[socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error)
|
||||
{
|
||||
std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
|
||||
if (!socket)
|
||||
return;
|
||||
|
||||
if (error)
|
||||
{
|
||||
TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message());
|
||||
socket->CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
if (self->next)
|
||||
self->next->Start();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
SocketImpl* _socket;
|
||||
};
|
||||
|
||||
template<class WrappedStream = IoContextTcpSocket>
|
||||
class SslStream
|
||||
{
|
||||
public:
|
||||
explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
|
||||
{
|
||||
_sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
|
||||
}
|
||||
|
||||
explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext)
|
||||
{
|
||||
_sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
|
||||
}
|
||||
|
||||
// adapting tcp::socket api
|
||||
void close(boost::system::error_code& error)
|
||||
{
|
||||
_sslSocket.next_layer().close(error);
|
||||
}
|
||||
|
||||
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
|
||||
{
|
||||
_sslSocket.shutdown(shutdownError);
|
||||
_sslSocket.next_layer().shutdown(what, shutdownError);
|
||||
}
|
||||
|
||||
template<typename MutableBufferSequence, typename ReadHandlerType>
|
||||
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
|
||||
{
|
||||
_sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler));
|
||||
}
|
||||
|
||||
template<typename ConstBufferSequence, typename WriteHandlerType>
|
||||
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
|
||||
{
|
||||
_sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler));
|
||||
}
|
||||
|
||||
template<typename ConstBufferSequence>
|
||||
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error)
|
||||
{
|
||||
return _sslSocket.write_some(buffers, error);
|
||||
}
|
||||
|
||||
template<typename WaitHandlerType>
|
||||
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
|
||||
{
|
||||
_sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler));
|
||||
}
|
||||
|
||||
template<typename SettableSocketOption>
|
||||
void set_option(SettableSocketOption const& option, boost::system::error_code& error)
|
||||
{
|
||||
_sslSocket.next_layer().set_option(option, error);
|
||||
}
|
||||
|
||||
IoContextTcpSocket::endpoint_type remote_endpoint() const
|
||||
{
|
||||
return _sslSocket.next_layer().remote_endpoint();
|
||||
}
|
||||
|
||||
// ssl api
|
||||
template<typename HandshakeHandlerType>
|
||||
void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
|
||||
{
|
||||
_sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler));
|
||||
}
|
||||
|
||||
protected:
|
||||
boost::asio::ssl::stream<WrappedStream> _sslSocket;
|
||||
};
|
||||
}
|
||||
|
||||
#endif // TRINITYCORE_SSL_STREAM_H
|
||||
@@ -15,10 +15,6 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
/// \addtogroup Trinityd Trinity Daemon
|
||||
/// @{
|
||||
/// \file
|
||||
|
||||
#include "Common.h"
|
||||
#include "AppenderDB.h"
|
||||
#include "AsyncAcceptor.h"
|
||||
@@ -117,7 +113,7 @@ private:
|
||||
};
|
||||
|
||||
void SignalHandler(boost::system::error_code const& error, int signalNumber);
|
||||
AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext);
|
||||
std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext);
|
||||
bool StartDB();
|
||||
void StopDB();
|
||||
void WorldUpdateLoop();
|
||||
@@ -336,9 +332,9 @@ int main(int argc, char** argv)
|
||||
});
|
||||
|
||||
// Start the Remote Access port (acceptor) if enabled
|
||||
std::unique_ptr<AsyncAcceptor> raAcceptor;
|
||||
std::unique_ptr<Trinity::Net::AsyncAcceptor> raAcceptor;
|
||||
if (sConfigMgr->GetBoolDefault("Ra.Enable", false))
|
||||
raAcceptor.reset(StartRaSocketAcceptor(*ioContext));
|
||||
raAcceptor = StartRaSocketAcceptor(*ioContext);
|
||||
|
||||
// Start soap serving thread if enabled
|
||||
std::shared_ptr<std::thread> soapThread;
|
||||
@@ -585,20 +581,23 @@ void FreezeDetector::Handler(std::weak_ptr<FreezeDetector> freezeDetectorRef, bo
|
||||
}
|
||||
}
|
||||
|
||||
AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext)
|
||||
std::unique_ptr<Trinity::Net::AsyncAcceptor> StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext)
|
||||
{
|
||||
uint16 raPort = uint16(sConfigMgr->GetIntDefault("Ra.Port", 3443));
|
||||
std::string raListener = sConfigMgr->GetStringDefault("Ra.IP", "0.0.0.0");
|
||||
|
||||
AsyncAcceptor* acceptor = new AsyncAcceptor(ioContext, raListener, raPort);
|
||||
std::unique_ptr<Trinity::Net::AsyncAcceptor> acceptor = std::make_unique<Trinity::Net::AsyncAcceptor>(ioContext, raListener, raPort);
|
||||
if (!acceptor->Bind())
|
||||
{
|
||||
TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor");
|
||||
delete acceptor;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
acceptor->AsyncAccept<RASession>();
|
||||
acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/)
|
||||
{
|
||||
std::make_shared<RASession>(std::move(sock))->Start();
|
||||
|
||||
});
|
||||
return acceptor;
|
||||
}
|
||||
|
||||
@@ -708,8 +707,6 @@ void ClearOnlineAccounts()
|
||||
CharacterDatabase.DirectExecute("UPDATE character_battleground_data SET instanceId = 0");
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, [[maybe_unused]] std::string& winServiceAction)
|
||||
{
|
||||
options_description all("Allowed options");
|
||||
|
||||
@@ -33,9 +33,11 @@ using boost::asio::ip::tcp;
|
||||
|
||||
void RASession::Start()
|
||||
{
|
||||
_socket.non_blocking(false);
|
||||
|
||||
// wait 1 second for active connections to send negotiation request
|
||||
for (int counter = 0; counter < 10 && _socket.available() == 0; counter++)
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
std::this_thread::sleep_for(100ms);
|
||||
|
||||
// Check if there are bytes available, if they are, then the client is requesting the negotiation
|
||||
if (_socket.available() > 0)
|
||||
|
||||
@@ -15,10 +15,11 @@
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef __RASESSION_H__
|
||||
#define __RASESSION_H__
|
||||
#ifndef TRINITYCORE_RA_SESSION_H
|
||||
#define TRINITYCORE_RA_SESSION_H
|
||||
|
||||
#include <memory>
|
||||
#include "Socket.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/streambuf.hpp>
|
||||
#include "Common.h"
|
||||
@@ -32,7 +33,7 @@ const size_t bufferSize = 4096;
|
||||
class RASession : public std::enable_shared_from_this <RASession>
|
||||
{
|
||||
public:
|
||||
RASession(tcp::socket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr)
|
||||
RASession(Trinity::Net::IoContextTcpSocket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -51,7 +52,7 @@ private:
|
||||
static void CommandPrint(void* callbackArg, std::string_view text);
|
||||
static void CommandFinished(void* callbackArg, bool);
|
||||
|
||||
tcp::socket _socket;
|
||||
Trinity::Net::IoContextTcpSocket _socket;
|
||||
boost::asio::streambuf _readBuffer;
|
||||
boost::asio::streambuf _writeBuffer;
|
||||
std::promise<void>* _commandExecuting;
|
||||
|
||||
Reference in New Issue
Block a user