diff --git a/src/common/Utilities/Containers.h b/src/common/Utilities/Containers.h index a55602a3163..64c1e4e7958 100644 --- a/src/common/Utilities/Containers.h +++ b/src/common/Utilities/Containers.h @@ -207,7 +207,7 @@ namespace Trinity if (!p(*rpos)) { if (rpos != wpos) - std::swap(*rpos, *wpos); + std::ranges::swap(*rpos, *wpos); ++wpos; } } diff --git a/src/server/authserver/Server/AuthSession.cpp b/src/server/authserver/Server/AuthSession.cpp index 54b0cbc15d9..24044f64601 100644 --- a/src/server/authserver/Server/AuthSession.cpp +++ b/src/server/authserver/Server/AuthSession.cpp @@ -27,6 +27,7 @@ #include "DatabaseEnv.h" #include "IPLocation.h" #include "IoContext.h" +#include "IpBanCheckConnectionInitializer.h" #include "Log.h" #include "RealmList.h" #include "SecretMgr.h" @@ -199,21 +200,23 @@ void AccountInfo::LoadResult(Field* fields) Utf8ToUpperOnlyLatin(Login); } -AuthSession::AuthSession(tcp::socket&& socket) : Socket(std::move(socket)), - _timeout(*underlying_stream().get_executor().target()), +AuthSession::AuthSession(Trinity::Net::IoContextTcpSocket&& socket) : Socket(std::move(socket)), + _timeout(underlying_stream().get_executor()), _status(STATUS_CHALLENGE), _locale(LOCALE_enUS), _os(0), _build(0), _expversion(0), _timezoneOffset(0min) { } void AuthSession::Start() { - std::string ip_address = GetRemoteIpAddress().to_string(); - TC_LOG_TRACE("session", "Accepted connection from {}", ip_address); + // build initializer chain + std::array, 3> initializers = + { { + std::make_shared>(this), + std::make_shared>(this), + } }; - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); - - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::CheckIpCallback, this, std::placeholders::_1))); + Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start(); + SetTimeout(); } bool AuthSession::Update() @@ -226,36 +229,7 @@ bool AuthSession::Update() return true; } -void AuthSession::CheckIpCallback(PreparedQueryResult result) -{ - if (result) - { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; - - } while (result->NextRow()); - - if (banned) - { - ByteBuffer pkt; - pkt << uint8(AUTH_LOGON_CHALLENGE); - pkt << uint8(0x00); - pkt << uint8(WOW_FAIL_BANNED); - SendPacket(pkt); - TC_LOG_DEBUG("session", "[AuthSession::CheckIpCallback] Banned ip '{}:{}' tries to login!", GetRemoteIpAddress().to_string(), GetRemotePort()); - return; - } - } - - AsyncRead(); - SetTimeout(); -} - -void AuthSession::ReadHandler() +Trinity::Net::SocketReadCallbackResult AuthSession::ReadHandler() { MessageBuffer& packet = GetReadBuffer(); while (packet.GetActiveSize()) @@ -265,7 +239,7 @@ void AuthSession::ReadHandler() if (!itr || _status != itr->status) { CloseSocket(); - return; + return Trinity::Net::SocketReadCallbackResult::Stop; } std::size_t size = itr->packetSize; @@ -279,7 +253,7 @@ void AuthSession::ReadHandler() if (size > MAX_ACCEPTED_CHALLENGE_SIZE) { CloseSocket(); - return; + return Trinity::Net::SocketReadCallbackResult::Stop; } } @@ -289,14 +263,19 @@ void AuthSession::ReadHandler() if (!itr->handler(this)) { CloseSocket(); - return; + return Trinity::Net::SocketReadCallbackResult::Stop; } packet.ReadCompleted(size); SetTimeout(); } - AsyncRead(); + return Trinity::Net::SocketReadCallbackResult::KeepReading; +} + +void AuthSession::QueueQuery(QueryCallback&& queryCallback) +{ + _queryProcessor.AddCallback(std::move(queryCallback)); } void AuthSession::SendPacket(ByteBuffer& packet) @@ -334,7 +313,7 @@ bool AuthSession::HandleLogonChallenge() LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_LOGONCHALLENGE); stmt->setStringView(0, login); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) + QueueQuery(LoginDatabase.AsyncQuery(stmt) .WithPreparedCallback([this](PreparedQueryResult result) { LogonChallengeCallback(std::move(result)); })); return true; } @@ -546,7 +525,7 @@ bool AuthSession::HandleLogonProof() stmt->setStringView(3, ClientBuild::ToCharArray(_os).data()); stmt->setInt16(4, _timezoneOffset.count()); stmt->setString(5, _accountInfo.Login); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) + QueueQuery(LoginDatabase.AsyncQuery(stmt) .WithPreparedCallback([this, M2 = Trinity::Crypto::SRP6::GetSessionVerifier(logonProof->A, logonProof->clientM, _sessionKey)](PreparedQueryResult const&) { // Finish SRP6 and send the final result to the client @@ -665,7 +644,7 @@ bool AuthSession::HandleReconnectChallenge() LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_RECONNECTCHALLENGE); stmt->setStringView(0, login); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) + QueueQuery(LoginDatabase.AsyncQuery(stmt) .WithPreparedCallback([this](PreparedQueryResult result) { ReconnectChallengeCallback(std::move(result)); })); return true; } @@ -748,7 +727,7 @@ bool AuthSession::HandleRealmList() LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_REALM_CHARACTER_COUNTS); stmt->setUInt32(0, _accountInfo.Id); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::RealmListCallback, this, std::placeholders::_1))); + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::RealmListCallback, this, std::placeholders::_1))); _status = STATUS_WAITING_FOR_REALM_LIST; return true; } @@ -922,7 +901,7 @@ void AuthSession::SetTimeout() _timeout.async_wait([selfRef = weak_from_this()](boost::system::error_code const& error) { - std::shared_ptr self = selfRef.lock(); + std::shared_ptr self = static_pointer_cast(selfRef.lock()); if (!self) return; diff --git a/src/server/authserver/Server/AuthSession.h b/src/server/authserver/Server/AuthSession.h index 400d08c5a16..97f7d4cd6b8 100644 --- a/src/server/authserver/Server/AuthSession.h +++ b/src/server/authserver/Server/AuthSession.h @@ -61,20 +61,21 @@ struct AccountInfo AccountTypes SecurityLevel = SEC_PLAYER; }; -class AuthSession : public Socket +class AuthSession final : public Trinity::Net::Socket<> { - typedef Socket AuthSocket; + using AuthSocket = Socket; public: - AuthSession(tcp::socket&& socket); + AuthSession(Trinity::Net::IoContextTcpSocket&& socket); void Start() override; bool Update() override; void SendPacket(ByteBuffer& packet); -protected: - void ReadHandler() override; + Trinity::Net::SocketReadCallbackResult ReadHandler() override; + + void QueueQuery(QueryCallback&& queryCallback); private: friend AuthHandlerTable; @@ -87,7 +88,6 @@ private: bool HandleXferResume(); bool HandleXferCancel(); - void CheckIpCallback(PreparedQueryResult result); void LogonChallengeCallback(PreparedQueryResult result); void ReconnectChallengeCallback(PreparedQueryResult result); void RealmListCallback(PreparedQueryResult result); diff --git a/src/server/authserver/Server/AuthSocketMgr.h b/src/server/authserver/Server/AuthSocketMgr.h index ac2cf85e4e6..dea1b41cc81 100644 --- a/src/server/authserver/Server/AuthSocketMgr.h +++ b/src/server/authserver/Server/AuthSocketMgr.h @@ -21,7 +21,7 @@ #include "SocketMgr.h" #include "AuthSession.h" -class AuthSocketMgr : public SocketMgr +class AuthSocketMgr : public Trinity::Net::SocketMgr { typedef SocketMgr BaseSocketMgr; @@ -37,19 +37,17 @@ public: if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) return false; - _acceptor->AsyncAcceptWithCallback<&AuthSocketMgr::OnSocketAccept>(); + _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) + { + OnSocketOpen(std::move(sock), threadIndex); + }); return true; } protected: - NetworkThread* CreateThreads() const override + Trinity::Net::NetworkThread* CreateThreads() const override { - return new NetworkThread[1]; - } - - static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex) - { - Instance().OnSocketOpen(std::forward(sock), threadIndex); + return new Trinity::Net::NetworkThread[1]; } }; diff --git a/src/server/game/Scripting/ScriptMgr.cpp b/src/server/game/Scripting/ScriptMgr.cpp index 7761e1bbef3..e1f23a2494f 100644 --- a/src/server/game/Scripting/ScriptMgr.cpp +++ b/src/server/game/Scripting/ScriptMgr.cpp @@ -1274,14 +1274,14 @@ void ScriptMgr::OnNetworkStop() FOREACH_SCRIPT(ServerScript)->OnNetworkStop(); } -void ScriptMgr::OnSocketOpen(std::shared_ptr socket) +void ScriptMgr::OnSocketOpen(std::shared_ptr const& socket) { ASSERT(socket); FOREACH_SCRIPT(ServerScript)->OnSocketOpen(socket); } -void ScriptMgr::OnSocketClose(std::shared_ptr socket) +void ScriptMgr::OnSocketClose(std::shared_ptr const& socket) { ASSERT(socket); diff --git a/src/server/game/Scripting/ScriptMgr.h b/src/server/game/Scripting/ScriptMgr.h index 42ffefbf7f5..c50e63e79fc 100644 --- a/src/server/game/Scripting/ScriptMgr.h +++ b/src/server/game/Scripting/ScriptMgr.h @@ -884,8 +884,8 @@ class TC_GAME_API ScriptMgr void OnNetworkStart(); void OnNetworkStop(); - void OnSocketOpen(std::shared_ptr socket); - void OnSocketClose(std::shared_ptr socket); + void OnSocketOpen(std::shared_ptr const& socket); + void OnSocketClose(std::shared_ptr const& socket); void OnPacketReceive(WorldSession* session, WorldPacket const& packet); void OnPacketSend(WorldSession* session, WorldPacket const& packet); diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index 64307b823f3..4fab67d923a 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -23,7 +23,7 @@ #include "CryptoHash.h" #include "CryptoRandom.h" #include "IPLocation.h" -#include "Opcodes.h" +#include "IpBanCheckConnectionInitializer.h" #include "PacketLog.h" #include "Random.h" #include "RBAC.h" @@ -33,10 +33,7 @@ #include "WorldSession.h" #include -using boost::asio::ip::tcp; - -WorldSocket::WorldSocket(tcp::socket&& socket) - : Socket(std::move(socket)), _OverSpeedPings(0), _worldSession(nullptr), _authed(false), _sendBufferSize(4096) +WorldSocket::WorldSocket(Trinity::Net::IoContextTcpSocket&& socket) : BaseSocket(std::move(socket)), _OverSpeedPings(0), _worldSession(nullptr), _authed(false), _sendBufferSize(4096) { Trinity::Crypto::GetRandomBytes(_authSeed); _headerBuffer.Resize(sizeof(ClientPktHeader)); @@ -44,39 +41,33 @@ WorldSocket::WorldSocket(tcp::socket&& socket) WorldSocket::~WorldSocket() = default; -void WorldSocket::Start() +struct WorldSocketProtocolInitializer final : Trinity::Net::SocketConnectionInitializer { - std::string ip_address = GetRemoteIpAddress().to_string(); - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); - stmt->setString(0, ip_address); + explicit WorldSocketProtocolInitializer(WorldSocket* socket) : _socket(socket) { } - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1))); -} - -void WorldSocket::CheckIpCallback(PreparedQueryResult result) -{ - if (result) + void Start() override { - bool banned = false; - do - { - Field* fields = result->Fetch(); - if (fields[0].GetUInt64() != 0) - banned = true; + _socket->SendAuthSession(); - } while (result->NextRow()); - - if (banned) - { - SendAuthResponseError(AUTH_REJECT); - TC_LOG_ERROR("network", "WorldSocket::CheckIpCallback: Sent Auth Response (IP {} banned).", GetRemoteIpAddress().to_string()); - DelayedCloseSocket(); - return; - } + if (this->next) + this->next->Start(); } - AsyncRead(); - HandleSendAuthSession(); +private: + WorldSocket* _socket; +}; + +void WorldSocket::Start() +{ + // build initializer chain + std::array, 3> initializers = + { { + std::make_shared>(this), + std::make_shared(this), + std::make_shared>(this), + } }; + + Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start(); } bool WorldSocket::Update() @@ -129,7 +120,7 @@ bool WorldSocket::Update() return true; } -void WorldSocket::HandleSendAuthSession() +void WorldSocket::SendAuthSession() { WorldPacket packet(SMSG_AUTH_CHALLENGE, 40); packet << uint32(1); // 1...31 @@ -148,11 +139,8 @@ void WorldSocket::OnClose() } } -void WorldSocket::ReadHandler() +Trinity::Net::SocketReadCallbackResult WorldSocket::ReadHandler() { - if (!IsOpen()) - return; - MessageBuffer& packet = GetReadBuffer(); while (packet.GetActiveSize() > 0) { @@ -174,7 +162,7 @@ void WorldSocket::ReadHandler() if (!ReadHeaderHandler()) { CloseSocket(); - return; + return Trinity::Net::SocketReadCallbackResult::Stop; } } @@ -202,11 +190,16 @@ void WorldSocket::ReadHandler() if (result != ReadDataHandlerResult::WaitingForQuery) CloseSocket(); - return; + return Trinity::Net::SocketReadCallbackResult::Stop; } } - AsyncRead(); + return Trinity::Net::SocketReadCallbackResult::KeepReading; +} + +void WorldSocket::QueueQuery(QueryCallback&& queryCallback) +{ + _queryProcessor.AddCallback(std::move(queryCallback)); } bool WorldSocket::ReadHeaderHandler() @@ -398,14 +391,13 @@ WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler() void WorldSocket::LogOpcodeText(OpcodeClient opcode, std::unique_lock const& guard) const { - if (!guard) + if (!guard || !_worldSession) { TC_LOG_TRACE("network.opcode", "C->S: {} {}", GetRemoteIpAddress().to_string(), GetOpcodeNameForLogging(opcode)); } else { - TC_LOG_TRACE("network.opcode", "C->S: {} {}", (_worldSession ? _worldSession->GetPlayerInfo() : GetRemoteIpAddress().to_string()), - GetOpcodeNameForLogging(opcode)); + TC_LOG_TRACE("network.opcode", "C->S: {} {}", _worldSession->GetPlayerInfo(), GetOpcodeNameForLogging(opcode)); } } @@ -449,7 +441,7 @@ void WorldSocket::HandleAuthSession(WorldPacket& recvPacket) stmt->setInt32(0, int32(realm.Id.Realm)); stmt->setString(1, authSession->Account); - _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable + QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback([this, authSession = std::move(authSession)](PreparedQueryResult result) mutable { HandleAuthSessionCallback(std::move(authSession), std::move(result)); })); @@ -613,16 +605,18 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr authSes sScriptMgr->OnAccountLogin(account.Id); _authed = true; - _worldSession = new WorldSession(account.Id, std::move(authSession->Account), shared_from_this(), account.Security, - account.Expansion, mutetime, account.TimezoneOffset, account.Locale, account.Recruiter, account.IsRectuiter); + _worldSession = new WorldSession(account.Id, std::move(authSession->Account), + static_pointer_cast(shared_from_this()), account.Security, account.Expansion, mutetime, + account.TimezoneOffset, account.Locale, + account.Recruiter, account.IsRectuiter); _worldSession->ReadAddonsInfo(authSession->AddonInfo); // Initialize Warden system only if it is enabled by config if (wardenActive) _worldSession->InitWarden(account.SessionKey, account.OS); - _queryProcessor.AddCallback(_worldSession->LoadPermissionsAsync().WithPreparedCallback(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1))); - AsyncRead(); + QueueQuery(_worldSession->LoadPermissionsAsync().WithPreparedCallback(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1))); + AsyncRead(Trinity::Net::InvokeReadHandlerCallback{ .Socket = this }); } void WorldSocket::LoadSessionPermissionsCallback(PreparedQueryResult result) diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index 1cec1e5cd1f..2e7928f754d 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -15,8 +15,8 @@ * with this program. If not, see . */ -#ifndef __WORLDSOCKET_H__ -#define __WORLDSOCKET_H__ +#ifndef TRINITYCORE_WORLD_SOCKET_H +#define TRINITYCORE_WORLD_SOCKET_H #include "Common.h" #include "ServerPktHeader.h" @@ -64,16 +64,18 @@ struct ClientPktHeader struct AuthSession; -class TC_GAME_API WorldSocket : public Socket +class TC_GAME_API WorldSocket final : public Trinity::Net::Socket<> { - typedef Socket BaseSocket; + using BaseSocket = Socket; public: - WorldSocket(tcp::socket&& socket); + WorldSocket(Trinity::Net::IoContextTcpSocket&& socket); ~WorldSocket(); WorldSocket(WorldSocket const& right) = delete; + WorldSocket(WorldSocket&& right) = delete; WorldSocket& operator=(WorldSocket const& right) = delete; + WorldSocket& operator=(WorldSocket&& right) = delete; void Start() override; bool Update() override; @@ -82,9 +84,14 @@ public: void SetSendBufferSize(std::size_t sendBufferSize) { _sendBufferSize = sendBufferSize; } -protected: void OnClose() override; - void ReadHandler() override; + Trinity::Net::SocketReadCallbackResult ReadHandler() override; + + void QueueQuery(QueryCallback&& queryCallback); + + void SendAuthSession(); + +protected: bool ReadHeaderHandler(); enum class ReadDataHandlerResult @@ -97,14 +104,11 @@ protected: ReadDataHandlerResult ReadDataHandler(); private: - void CheckIpCallback(PreparedQueryResult result); - /// writes network.opcode log /// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock void LogOpcodeText(OpcodeClient opcode, std::unique_lock const& guard) const; /// sends and logs network.opcode without accessing WorldSession void SendPacketAndLogOpcode(WorldPacket const& packet); - void HandleSendAuthSession(); void HandleAuthSession(WorldPacket& recvPacket); void HandleAuthSessionCallback(std::shared_ptr authSession, PreparedQueryResult result); void LoadSessionPermissionsCallback(PreparedQueryResult result); diff --git a/src/server/game/Server/WorldSocketMgr.cpp b/src/server/game/Server/WorldSocketMgr.cpp index 61bbfb3a818..8a51a443131 100644 --- a/src/server/game/Server/WorldSocketMgr.cpp +++ b/src/server/game/Server/WorldSocketMgr.cpp @@ -23,21 +23,16 @@ #include -static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex) -{ - sWorldSocketMgr.OnSocketOpen(std::forward(sock), threadIndex); -} - -class WorldSocketThread : public NetworkThread +class WorldSocketThread : public Trinity::Net::NetworkThread { public: - void SocketAdded(std::shared_ptr sock) override + void SocketAdded(std::shared_ptr const& sock) override { sock->SetSendBufferSize(sWorldSocketMgr.GetApplicationSendBufferSize()); sScriptMgr->OnSocketOpen(sock); } - void SocketRemoved(std::shared_ptr sock) override + void SocketRemoved(std::shared_ptrconst& sock) override { sScriptMgr->OnSocketClose(sock); } @@ -74,7 +69,10 @@ bool WorldSocketMgr::StartWorldNetwork(Trinity::Asio::IoContext& ioContext, std: if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount)) return false; - _acceptor->AsyncAcceptWithCallback<&OnSocketAccept>(); + _acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) + { + OnSocketOpen(std::move(sock), threadIndex); + }); sScriptMgr->OnNetworkStart(); return true; @@ -87,7 +85,7 @@ void WorldSocketMgr::StopNetwork() sScriptMgr->OnNetworkStop(); } -void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) +void WorldSocketMgr::OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) { // set some options here if (_socketSystemSendBufferSize >= 0) @@ -115,10 +113,10 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) //sock->m_OutBufferSize = static_cast (m_SockOutUBuff); - BaseSocketMgr::OnSocketOpen(std::forward(sock), threadIndex); + BaseSocketMgr::OnSocketOpen(std::move(sock), threadIndex); } -NetworkThread* WorldSocketMgr::CreateThreads() const +Trinity::Net::NetworkThread* WorldSocketMgr::CreateThreads() const { return new WorldSocketThread[GetNetworkThreadCount()]; } diff --git a/src/server/game/Server/WorldSocketMgr.h b/src/server/game/Server/WorldSocketMgr.h index da3f62af0d0..986a94e6a22 100644 --- a/src/server/game/Server/WorldSocketMgr.h +++ b/src/server/game/Server/WorldSocketMgr.h @@ -15,21 +15,15 @@ * with this program. If not, see . */ -/** \addtogroup u2w User to World Communication - * @{ - * \file WorldSocketMgr.h - * \author Derex - */ - -#ifndef __WORLDSOCKETMGR_H -#define __WORLDSOCKETMGR_H +#ifndef TRINITYCORE_WORLD_SOCKET_MGR_H +#define TRINITYCORE_WORLD_SOCKET_MGR_H #include "SocketMgr.h" class WorldSocket; /// Manages all sockets connected to peers and network threads -class TC_GAME_API WorldSocketMgr : public SocketMgr +class TC_GAME_API WorldSocketMgr : public Trinity::Net::SocketMgr { typedef SocketMgr BaseSocketMgr; @@ -42,14 +36,14 @@ public: /// Stops all network threads, It will wait for all running threads . void StopNetwork() override; - void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) override; + void OnSocketOpen(Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex) override; std::size_t GetApplicationSendBufferSize() const { return _socketApplicationSendBufferSize; } protected: WorldSocketMgr(); - NetworkThread* CreateThreads() const override; + Trinity::Net::NetworkThread* CreateThreads() const override; private: int32 _socketSystemSendBufferSize; @@ -60,4 +54,3 @@ private: #define sWorldSocketMgr WorldSocketMgr::Instance() #endif -/// @} diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h index e88330ebdba..efbb3c9a40d 100644 --- a/src/server/shared/Networking/AsyncAcceptor.h +++ b/src/server/shared/Networking/AsyncAcceptor.h @@ -15,12 +15,13 @@ * with this program. If not, see . */ -#ifndef __ASYNCACCEPT_H_ -#define __ASYNCACCEPT_H_ +#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H +#define TRINITYCORE_ASYNC_ACCEPTOR_H #include "IoContext.h" #include "IpAddress.h" #include "Log.h" +#include "Socket.h" #include #include #include @@ -29,27 +30,28 @@ using boost::asio::ip::tcp; #define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections +namespace Trinity::Net +{ +template +concept AcceptCallback = std::invocable; + class AsyncAcceptor { public: - typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex); - - AsyncAcceptor(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : - _acceptor(ioContext), _endpoint(Trinity::Net::make_address(bindIp), port), + AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : + _acceptor(ioContext), _endpoint(make_address(bindIp), port), _socket(ioContext), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this)) { } - template - void AsyncAccept(); - - template - void AsyncAcceptWithCallback() + template + void AsyncAccept(Callback&& acceptCallback) { - tcp::socket* socket; - uint32 threadIndex; - std::tie(socket, threadIndex) = _socketFactory(); - _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error) + auto [tmpSocket, tmpThreadIndex] = _socketFactory(); + // TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures) + IoContextTcpSocket* socket = tmpSocket; + uint32 threadIndex = tmpThreadIndex; + _acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward(acceptCallback)](boost::system::error_code const& error) mutable { if (!error) { @@ -66,7 +68,7 @@ public: } if (!_closed) - this->AsyncAcceptWithCallback(); + this->AsyncAccept(std::move(acceptCallback)); }); } @@ -115,40 +117,17 @@ public: _acceptor.close(err); } - void SetSocketFactory(std::function()> func) { _socketFactory = func; } + void SetSocketFactory(std::function()> func) { _socketFactory = std::move(func); } private: - std::pair DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } + std::pair DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } - tcp::acceptor _acceptor; - tcp::endpoint _endpoint; - tcp::socket _socket; + boost::asio::basic_socket_acceptor _acceptor; + boost::asio::ip::tcp::endpoint _endpoint; + IoContextTcpSocket _socket; std::atomic _closed; - std::function()> _socketFactory; + std::function()> _socketFactory; }; - -template -void AsyncAcceptor::AsyncAccept() -{ - _acceptor.async_accept(_socket, [this](boost::system::error_code error) - { - if (!error) - { - try - { - // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class - std::make_shared(std::move(this->_socket))->Start(); - } - catch (boost::system::system_error const& err) - { - TC_LOG_INFO("network", "Failed to retrieve client's remote address {}", err.what()); - } - } - - // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face - if (!_closed) - this->AsyncAccept(); - }); } -#endif /* __ASYNCACCEPT_H_ */ +#endif // TRINITYCORE_ASYNC_ACCEPTOR_H diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp new file mode 100644 index 00000000000..696945e2ded --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp @@ -0,0 +1,42 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#include "IpBanCheckConnectionInitializer.h" +#include "DatabaseEnv.h" + +QueryCallback Trinity::Net::IpBanCheckHelpers::AsyncQuery(std::string_view ipAddress) +{ + LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); + stmt->setStringView(0, ipAddress); + return LoginDatabase.AsyncQuery(stmt); +} + +bool Trinity::Net::IpBanCheckHelpers::IsBanned(PreparedQueryResult const& result) +{ + if (result) + { + do + { + Field* fields = result->Fetch(); + if (fields[0].GetUInt64() != 0) + return true; + + } while (result->NextRow()); + } + + return false; +} diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h new file mode 100644 index 00000000000..ff8210a7f69 --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h @@ -0,0 +1,64 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#ifndef TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H +#define TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H + +#include "DatabaseEnvFwd.h" +#include "Log.h" +#include "QueryCallback.h" +#include "SocketConnectionInitializer.h" + +namespace Trinity::Net +{ +namespace IpBanCheckHelpers +{ +TC_SHARED_API QueryCallback AsyncQuery(std::string_view ipAddress); +TC_SHARED_API bool IsBanned(PreparedQueryResult const& result); +} + +template +struct IpBanCheckConnectionInitializer final : SocketConnectionInitializer +{ + explicit IpBanCheckConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->QueueQuery(IpBanCheckHelpers::AsyncQuery(_socket->GetRemoteIpAddress().to_string()).WithPreparedCallback([socketRef = _socket->weak_from_this(), self = this->shared_from_this()](PreparedQueryResult const& result) + { + std::shared_ptr socket = static_pointer_cast(socketRef.lock()); + if (!socket) + return; + + if (IpBanCheckHelpers::IsBanned(result)) + { + TC_LOG_ERROR("network", "IpBanCheckConnectionInitializer: IP {} is banned.", socket->GetRemoteIpAddress().to_string()); + socket->DelayedCloseSocket(); + return; + } + + if (self->next) + self->next->Start(); + })); + } + +private: + SocketImpl* _socket; +}; +} + +#endif // TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H diff --git a/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h new file mode 100644 index 00000000000..d3f0bb16dbf --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h @@ -0,0 +1,51 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H +#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H + +#include +#include + +namespace Trinity::Net +{ +struct SocketConnectionInitializer : public std::enable_shared_from_this +{ + SocketConnectionInitializer() = default; + + SocketConnectionInitializer(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default; + SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default; + + virtual ~SocketConnectionInitializer() = default; + + virtual void Start() = 0; + + std::shared_ptr next; + + static std::shared_ptr& SetupChain(std::span> initializers) + { + for (std::size_t i = initializers.size(); i > 1; --i) + initializers[i - 2]->next.swap(initializers[i - 1]); + + return initializers[0]; + } +}; +} + +#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h index e149867c962..a4783a67525 100644 --- a/src/server/shared/Networking/NetworkThread.h +++ b/src/server/shared/Networking/NetworkThread.h @@ -15,14 +15,16 @@ * with this program. If not, see . */ -#ifndef NetworkThread_h__ -#define NetworkThread_h__ +#ifndef TRINITYCORE_NETWORK_THREAD_H +#define TRINITYCORE_NETWORK_THREAD_H -#include "Define.h" +#include "Containers.h" #include "DeadlineTimer.h" +#include "Define.h" #include "Errors.h" #include "IoContext.h" #include "Log.h" +#include "Socket.h" #include "Timer.h" #include #include @@ -32,8 +34,8 @@ #include #include -using boost::asio::ip::tcp; - +namespace Trinity::Net +{ template class NetworkThread { @@ -43,14 +45,16 @@ public: { } + NetworkThread(NetworkThread const&) = delete; + NetworkThread(NetworkThread&&) = delete; + NetworkThread& operator=(NetworkThread const&) = delete; + NetworkThread& operator=(NetworkThread&&) = delete; + virtual ~NetworkThread() { Stop(); if (_thread) - { Wait(); - delete _thread; - } } void Stop() @@ -64,7 +68,7 @@ public: if (_thread) return false; - _thread = new std::thread(&NetworkThread::Run, this); + _thread = std::make_unique(&NetworkThread::Run, this); return true; } @@ -73,7 +77,6 @@ public: ASSERT(_thread); _thread->join(); - delete _thread; _thread = nullptr; } @@ -87,15 +90,14 @@ public: std::lock_guard lock(_newSocketsLock); ++_connections; - _newSockets.push_back(sock); - SocketAdded(sock); + SocketAdded(_newSockets.emplace_back(std::move(sock))); } - tcp::socket* GetSocketForAccept() { return &_acceptSocket; } + Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; } protected: - virtual void SocketAdded(std::shared_ptr /*sock*/) { } - virtual void SocketRemoved(std::shared_ptr /*sock*/) { } + virtual void SocketAdded(std::shared_ptr const& /*sock*/) { } + virtual void SocketRemoved(std::shared_ptr const& /*sock*/) { } void AddNewSockets() { @@ -104,7 +106,7 @@ protected: if (_newSockets.empty()) return; - for (std::shared_ptr sock : _newSockets) + for (std::shared_ptr& sock : _newSockets) { if (!sock->IsOpen()) { @@ -112,7 +114,7 @@ protected: --_connections; } else - _sockets.push_back(sock); + _sockets.emplace_back(std::move(sock)); } _newSockets.clear(); @@ -141,7 +143,7 @@ protected: AddNewSockets(); - _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr sock) + Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr const& sock) { if (!sock->Update()) { @@ -155,7 +157,7 @@ protected: } return false; - }), _sockets.end()); + }); } private: @@ -164,7 +166,7 @@ private: std::atomic _connections; std::atomic _stopped; - std::thread* _thread; + std::unique_ptr _thread; SocketContainer _sockets; @@ -172,8 +174,9 @@ private: SocketContainer _newSockets; Trinity::Asio::IoContext _ioContext; - tcp::socket _acceptSocket; + Trinity::Net::IoContextTcpSocket _acceptSocket; Trinity::Asio::DeadlineTimer _updateTimer; }; +} -#endif // NetworkThread_h__ +#endif // TRINITYCORE_NETWORK_THREAD_H diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index 62edfa4d8cd..76e5947d7e8 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -15,17 +15,19 @@ * with this program. If not, see . */ -#ifndef __SOCKET_H__ -#define __SOCKET_H__ +#ifndef TRINITYCORE_SOCKET_H +#define TRINITYCORE_SOCKET_H -#include "MessageBuffer.h" +#include "Concepts.h" #include "Log.h" -#include -#include -#include -#include -#include +#include "MessageBuffer.h" +#include "SocketConnectionInitializer.h" +#include #include +#include +#include +#include +#include using boost::asio::ip::tcp; @@ -34,32 +36,111 @@ using boost::asio::ip::tcp; #define TC_SOCKET_USE_IOCP #endif -template -class Socket : public std::enable_shared_from_this +namespace Trinity::Net +{ +using IoContextTcpSocket = boost::asio::basic_stream_socket; + +enum class SocketReadCallbackResult +{ + KeepReading, + Stop +}; + +template +concept SocketReadCallback = Trinity::invocable_r; + +template +struct InvokeReadHandlerCallback +{ + SocketReadCallbackResult operator()() const + { + return this->Socket->ReadHandler(); + } + + SocketType* Socket; +}; + +template +struct ReadConnectionInitializer final : SocketConnectionInitializer +{ + explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { } + + void Start() override + { + ReadCallback.Socket->AsyncRead(std::move(ReadCallback)); + + if (this->next) + this->next->Start(); + } + + InvokeReadHandlerCallback ReadCallback; +}; + +/** + @class Socket + + Base async socket implementation + + @tparam Stream stream type used for operations on socket + Stream must implement the following methods: + + void close(boost::system::error_code& error); + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError); + + template + void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler); + + template + void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler); + + template + std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error); + + template + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler); + + template + void set_option(SettableSocketOption const& option, boost::system::error_code& error); + + tcp::socket::endpoint_type remote_endpoint() const; +*/ +template +class Socket : public std::enable_shared_from_this> { public: - explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), - _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false) + template + explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward(args)...), + _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open) { - _readBuffer.Resize(READ_BLOCK_SIZE); } + template + explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward(args)...), _openState(OpenState_Closed) + { + } + + Socket(Socket const& other) = delete; + Socket(Socket&& other) = delete; + Socket& operator=(Socket const& other) = delete; + Socket& operator=(Socket&& other) = delete; + virtual ~Socket() { - _closed = true; + _openState = OpenState_Closed; boost::system::error_code error; _socket.close(error); } - virtual void Start() = 0; + virtual void Start() { } virtual bool Update() { - if (_closed) + if (_openState == OpenState_Closed) return false; #ifndef TC_SOCKET_USE_IOCP - if (_isWritingAsync || (_writeQueue.empty() && !_closing)) + if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open)) return true; for (; HandleQueue();) @@ -69,7 +150,7 @@ public: return true; } - boost::asio::ip::address GetRemoteIpAddress() const + boost::asio::ip::address const& GetRemoteIpAddress() const { return _remoteAddress; } @@ -79,7 +160,8 @@ public: return _remotePort; } - void AsyncRead() + template + void AsyncRead(Callback&& callback) { if (!IsOpen()) return; @@ -87,18 +169,12 @@ public: _readBuffer.Normalize(); _readBuffer.EnsureFreeSpace(); _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), - std::bind(&Socket::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); - } - - void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t)) - { - if (!IsOpen()) - return; - - _readBuffer.Normalize(); - _readBuffer.EnsureFreeSpace(); - _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), - std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + [self = this->shared_from_this(), callback = std::forward(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable + { + if (self->ReadHandlerInternal(error, transferredBytes)) + if (callback() == SocketReadCallbackResult::KeepReading) + self->AsyncRead(std::forward(callback)); + }); } void QueuePacket(MessageBuffer&& buffer) @@ -110,11 +186,11 @@ public: #endif } - bool IsOpen() const { return !_closed && !_closing; } + bool IsOpen() const { return _openState == OpenState_Open; } void CloseSocket() { - if (_closed.exchange(true)) + if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0) return; boost::system::error_code shutdownError; @@ -123,13 +199,13 @@ public: TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(), shutdownError.value(), shutdownError.message()); - OnClose(); + this->OnClose(); } /// Marks the socket for closing after write buffer becomes empty void DelayedCloseSocket() { - if (_closing.exchange(true)) + if (_openState.fetch_or(OpenState_Closing) != 0) return; if (_writeQueue.empty()) @@ -138,7 +214,7 @@ public: MessageBuffer& GetReadBuffer() { return _readBuffer; } - tcp::socket& underlying_stream() + Stream& underlying_stream() { return _socket; } @@ -146,7 +222,7 @@ public: protected: virtual void OnClose() { } - virtual void ReadHandler() = 0; + virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; } bool AsyncProcessQueue() { @@ -157,11 +233,17 @@ protected: #ifdef TC_SOCKET_USE_IOCP MessageBuffer& buffer = _writeQueue.front(); - _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket::WriteHandler, - this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), + [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes) + { + self->WriteHandler(error, transferedBytes); + }); #else - _socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket::WriteHandlerWrapper, - this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); + _socket.async_wait(boost::asio::socket_base::wait_type::wait_write, + [self = this->shared_from_this()](boost::system::error_code const& error) + { + self->WriteHandlerWrapper(error); + }); #endif return false; @@ -177,16 +259,16 @@ protected: } private: - void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes) + bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes) { if (error) { CloseSocket(); - return; + return false; } _readBuffer.WriteCompleted(transferredBytes); - ReadHandler(); + return IsOpen(); } #ifdef TC_SOCKET_USE_IOCP @@ -202,7 +284,7 @@ private: if (!_writeQueue.empty()) AsyncProcessQueue(); - else if (_closing) + else if (_openState == OpenState_Closing) CloseSocket(); } else @@ -211,7 +293,7 @@ private: #else - void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/) + void WriteHandlerWrapper(boost::system::error_code const& /*error*/) { _isWritingAsync = false; HandleQueue(); @@ -235,14 +317,14 @@ private: return AsyncProcessQueue(); _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return false; } else if (bytesSent == 0) { _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return false; } @@ -253,25 +335,30 @@ private: } _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return !_writeQueue.empty(); } #endif - tcp::socket _socket; + Stream _socket; boost::asio::ip::address _remoteAddress; - uint16 _remotePort; + uint16 _remotePort = 0; - MessageBuffer _readBuffer; + MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE); std::queue _writeQueue; - std::atomic _closed; - std::atomic _closing; + // Socket open state "enum" (not enum to enable integral std::atomic api) + static constexpr uint8 OpenState_Open = 0x0; + static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data + static constexpr uint8 OpenState_Closed = 0x2; - bool _isWritingAsync; + std::atomic _openState; + + bool _isWritingAsync = false; }; +} -#endif // __SOCKET_H__ +#endif // TRINITYCORE_SOCKET_H diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h index 31ec4e6390b..07252355308 100644 --- a/src/server/shared/Networking/SocketMgr.h +++ b/src/server/shared/Networking/SocketMgr.h @@ -15,34 +15,40 @@ * with this program. If not, see . */ -#ifndef SocketMgr_h__ -#define SocketMgr_h__ +#ifndef TRINITYCORE_SOCKET_MGR_H +#define TRINITYCORE_SOCKET_MGR_H #include "AsyncAcceptor.h" #include "Errors.h" #include "NetworkThread.h" +#include "Socket.h" #include #include -using boost::asio::ip::tcp; - +namespace Trinity::Net +{ template class SocketMgr { public: + SocketMgr(SocketMgr const&) = delete; + SocketMgr(SocketMgr&&) = delete; + SocketMgr& operator=(SocketMgr const&) = delete; + SocketMgr& operator=(SocketMgr&&) = delete; + virtual ~SocketMgr() { ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction"); } - virtual bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) + virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) { ASSERT(threadCount > 0); - AsyncAcceptor* acceptor = nullptr; + std::unique_ptr acceptor = nullptr; try { - acceptor = new AsyncAcceptor(ioContext, bindIp, port); + acceptor = std::make_unique(ioContext, bindIp, port); } catch (boost::system::system_error const& err) { @@ -53,13 +59,12 @@ public: if (!acceptor->Bind()) { TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor"); - delete acceptor; return false; } - _acceptor = acceptor; + _acceptor = std::move(acceptor); _threadCount = threadCount; - _threads = CreateThreads(); + _threads.reset(CreateThreads()); ASSERT(_threads); @@ -75,27 +80,23 @@ public: { _acceptor->Close(); - if (_threadCount != 0) - for (int32 i = 0; i < _threadCount; ++i) - _threads[i].Stop(); + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Stop(); Wait(); - delete _acceptor; _acceptor = nullptr; - delete[] _threads; _threads = nullptr; _threadCount = 0; } void Wait() { - if (_threadCount != 0) - for (int32 i = 0; i < _threadCount; ++i) - _threads[i].Wait(); + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Wait(); } - virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) + virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex) { try { @@ -123,22 +124,23 @@ public: return min; } - std::pair GetSocketForAccept() + std::pair GetSocketForAccept() { uint32 threadIndex = SelectThreadWithMinConnections(); return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex); } protected: - SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(0) + SocketMgr() : _threadCount(0) { } virtual NetworkThread* CreateThreads() const = 0; - AsyncAcceptor* _acceptor; - NetworkThread* _threads; + std::unique_ptr _acceptor; + std::unique_ptr[]> _threads; int32 _threadCount; }; +} -#endif // SocketMgr_h__ +#endif // TRINITYCORE_SOCKET_MGR_H diff --git a/src/server/shared/Networking/SslStream.h b/src/server/shared/Networking/SslStream.h new file mode 100644 index 00000000000..2cced44e5ff --- /dev/null +++ b/src/server/shared/Networking/SslStream.h @@ -0,0 +1,131 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#ifndef TRINITYCORE_SSL_STREAM_H +#define TRINITYCORE_SSL_STREAM_H + +#include "SocketConnectionInitializer.h" +#include +#include +#include + +namespace Trinity::Net +{ +template +struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer +{ + explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, + [socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error) + { + std::shared_ptr socket = static_pointer_cast(socketRef.lock()); + if (!socket) + return; + + if (error) + { + TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message()); + socket->CloseSocket(); + return; + } + + if (self->next) + self->next->Start(); + }); + } + +private: + SocketImpl* _socket; +}; + +template +class SslStream +{ +public: + explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + // adapting tcp::socket api + void close(boost::system::error_code& error) + { + _sslSocket.next_layer().close(error); + } + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + _sslSocket.shutdown(shutdownError); + _sslSocket.next_layer().shutdown(what, shutdownError); + } + + template + void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) + { + _sslSocket.async_read_some(buffers, std::forward(handler)); + } + + template + void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) + { + _sslSocket.async_write_some(buffers, std::forward(handler)); + } + + template + std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error) + { + return _sslSocket.write_some(buffers, error); + } + + template + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + { + _sslSocket.next_layer().async_wait(type, std::forward(handler)); + } + + template + void set_option(SettableSocketOption const& option, boost::system::error_code& error) + { + _sslSocket.next_layer().set_option(option, error); + } + + IoContextTcpSocket::endpoint_type remote_endpoint() const + { + return _sslSocket.next_layer().remote_endpoint(); + } + + // ssl api + template + void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) + { + _sslSocket.async_handshake(type, std::forward(handler)); + } + +protected: + boost::asio::ssl::stream _sslSocket; +}; +} + +#endif // TRINITYCORE_SSL_STREAM_H diff --git a/src/server/worldserver/Main.cpp b/src/server/worldserver/Main.cpp index b78a45fd3ae..2dee41f8aaa 100644 --- a/src/server/worldserver/Main.cpp +++ b/src/server/worldserver/Main.cpp @@ -15,10 +15,6 @@ * with this program. If not, see . */ -/// \addtogroup Trinityd Trinity Daemon -/// @{ -/// \file - #include "Common.h" #include "AppenderDB.h" #include "AsyncAcceptor.h" @@ -117,7 +113,7 @@ private: }; void SignalHandler(boost::system::error_code const& error, int signalNumber); -AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext); +std::unique_ptr StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext); bool StartDB(); void StopDB(); void WorldUpdateLoop(); @@ -336,9 +332,9 @@ int main(int argc, char** argv) }); // Start the Remote Access port (acceptor) if enabled - std::unique_ptr raAcceptor; + std::unique_ptr raAcceptor; if (sConfigMgr->GetBoolDefault("Ra.Enable", false)) - raAcceptor.reset(StartRaSocketAcceptor(*ioContext)); + raAcceptor = StartRaSocketAcceptor(*ioContext); // Start soap serving thread if enabled std::shared_ptr soapThread; @@ -585,20 +581,23 @@ void FreezeDetector::Handler(std::weak_ptr freezeDetectorRef, bo } } -AsyncAcceptor* StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext) +std::unique_ptr StartRaSocketAcceptor(Trinity::Asio::IoContext& ioContext) { uint16 raPort = uint16(sConfigMgr->GetIntDefault("Ra.Port", 3443)); std::string raListener = sConfigMgr->GetStringDefault("Ra.IP", "0.0.0.0"); - AsyncAcceptor* acceptor = new AsyncAcceptor(ioContext, raListener, raPort); + std::unique_ptr acceptor = std::make_unique(ioContext, raListener, raPort); if (!acceptor->Bind()) { TC_LOG_ERROR("server.worldserver", "Failed to bind RA socket acceptor"); - delete acceptor; return nullptr; } - acceptor->AsyncAccept(); + acceptor->AsyncAccept([](Trinity::Net::IoContextTcpSocket&& sock, uint32 /*threadIndex*/) + { + std::make_shared(std::move(sock))->Start(); + + }); return acceptor; } @@ -708,8 +707,6 @@ void ClearOnlineAccounts() CharacterDatabase.DirectExecute("UPDATE character_battleground_data SET instanceId = 0"); } -/// @} - variables_map GetConsoleArguments(int argc, char** argv, fs::path& configFile, fs::path& configDir, [[maybe_unused]] std::string& winServiceAction) { options_description all("Allowed options"); diff --git a/src/server/worldserver/RemoteAccess/RASession.cpp b/src/server/worldserver/RemoteAccess/RASession.cpp index f2b1833704e..cc58e044e7d 100644 --- a/src/server/worldserver/RemoteAccess/RASession.cpp +++ b/src/server/worldserver/RemoteAccess/RASession.cpp @@ -33,9 +33,11 @@ using boost::asio::ip::tcp; void RASession::Start() { + _socket.non_blocking(false); + // wait 1 second for active connections to send negotiation request for (int counter = 0; counter < 10 && _socket.available() == 0; counter++) - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(100ms); // Check if there are bytes available, if they are, then the client is requesting the negotiation if (_socket.available() > 0) diff --git a/src/server/worldserver/RemoteAccess/RASession.h b/src/server/worldserver/RemoteAccess/RASession.h index d499727ae9d..5e30be4a254 100644 --- a/src/server/worldserver/RemoteAccess/RASession.h +++ b/src/server/worldserver/RemoteAccess/RASession.h @@ -15,10 +15,11 @@ * with this program. If not, see . */ -#ifndef __RASESSION_H__ -#define __RASESSION_H__ +#ifndef TRINITYCORE_RA_SESSION_H +#define TRINITYCORE_RA_SESSION_H #include +#include "Socket.h" #include #include #include "Common.h" @@ -32,7 +33,7 @@ const size_t bufferSize = 4096; class RASession : public std::enable_shared_from_this { public: - RASession(tcp::socket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr) + RASession(Trinity::Net::IoContextTcpSocket&& socket) : _socket(std::move(socket)), _commandExecuting(nullptr) { } @@ -51,7 +52,7 @@ private: static void CommandPrint(void* callbackArg, std::string_view text); static void CommandFinished(void* callbackArg, bool); - tcp::socket _socket; + Trinity::Net::IoContextTcpSocket _socket; boost::asio::streambuf _readBuffer; boost::asio::streambuf _writeBuffer; std::promise* _commandExecuting;