From bb1cc8a48f8f3472ebdfdbaa6a97ec2e1c5190c8 Mon Sep 17 00:00:00 2001 From: Shauren Date: Mon, 22 Sep 2025 13:17:50 +0200 Subject: Core/Network: Implement connect operations in Socket class --- src/common/network/Socket.h | 122 +++++++++++++++++++++++++++++++++++++++-- src/common/network/SslStream.h | 6 ++ 2 files changed, 124 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h index aa42c7299e8..c2b2c63a291 100644 --- a/src/common/network/Socket.h +++ b/src/common/network/Socket.h @@ -23,6 +23,7 @@ #include "Log.h" #include "MessageBuffer.h" #include "SocketConnectionInitializer.h" +#include #include #include #include @@ -39,6 +40,12 @@ namespace Trinity::Net { using IoContextTcpSocket = boost::asio::basic_stream_socket; +namespace Impl::Operations +{ +template +struct Connect; +} + enum class SocketReadCallbackResult { KeepReading, @@ -52,9 +59,6 @@ inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer) return boost::asio::buffer(readBuffer.GetWritePointer(), readBuffer.GetRemainingSpace()); } -template -concept SocketReadCallback = Trinity::invocable_r; - template struct InvokeReadHandlerCallback { @@ -98,6 +102,9 @@ struct ReadConnectionInitializer final : SocketConnectionInitializer void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError); + template + void async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler); + template void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler); @@ -144,6 +151,22 @@ public: virtual void Start() { } + template + decltype(auto) Connect(boost::asio::ip::tcp::endpoint const& endpoint, Callback&& callback) + { + _openState = OpenState_Open; + return boost::asio::async_compose>( + Impl::Operations::Connect(this->shared_from_this(), endpoint), callback, this->underlying_stream()); + } + + template + decltype(auto) Connect(std::vector const& endpoints, Callback&& callback) + { + _openState = OpenState_Open; + return boost::asio::async_compose>( + Impl::Operations::Connect(this->shared_from_this(), endpoints), callback, this->underlying_stream()); + } + virtual bool Update() { if (_openState == OpenState_Closed) @@ -170,7 +193,13 @@ public: return _remotePort; } - template + void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint) + { + _remoteAddress = endpoint.address(); + _remotePort = endpoint.port(); + } + + template Callback> void AsyncRead(Callback&& callback) { if (!IsOpen()) @@ -366,6 +395,91 @@ private: bool _isWritingAsync = false; }; + +namespace Impl::Operations +{ +struct ConnectState +{ + explicit ConnectState(std::shared_ptr const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint) + : SocketRef(socketRef), Endpoints(1, endpoint), Index(-1) { } + + explicit ConnectState(std::shared_ptr const& socketRef, std::vector const& endpoints) + : SocketRef(socketRef), Endpoints(endpoints), Index(-1) { } + + std::weak_ptr SocketRef; + std::vector Endpoints; + std::ptrdiff_t Index; +}; + +template +struct Connect +{ + explicit Connect(std::shared_ptr const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint) + : State(std::make_shared(std::move(socketRef), endpoint)) { } + + explicit Connect(std::shared_ptr const& socketRef, std::vector const& endpoints) + : State(std::make_shared(std::move(socketRef), endpoints)) { } + + std::shared_ptr State; + + template + void operator()(Handler& handler, boost::system::error_code error = {}) + { + std::shared_ptr socket = static_pointer_cast(State->SocketRef.lock()); + if (!socket) + { + error = boost::asio::error::operation_aborted; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } + + bool isFirst = State->Index < 0; + + if (std::max(State->Index, std::ptrdiff_t(0)) >= std::ssize(State->Endpoints)) + { + Connect::HandleError(socket.get(), "failed to connect to any of specified endpoints"); + error = boost::asio::error::not_found; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } + + if (!isFirst && !socket->underlying_stream().is_open()) + { + Connect::HandleError(socket.get(), "socket closed"); + error = boost::asio::error::operation_aborted; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } + + if (!error && !isFirst) + { + socket->SetRemoteEndpoint(State->Endpoints[State->Index]); + handler.complete(error, State->Endpoints[State->Index]); + } + else + { +#if BOOST_VERSION >= 107700 + if (handler.cancelled() != boost::asio::cancellation_type::none) + { + Connect::HandleError(socket.get(), "connect cancelled"); + error = boost::asio::error::operation_aborted; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } +#endif + + socket->underlying_stream().close(error); + socket->underlying_stream().async_connect(State->Endpoints[++State->Index], std::move(handler)); + } + } + + static void HandleError(Socket* self, std::string_view message) + { + TC_LOG_DEBUG("network", "Socket::Connect: {}", message); + self->CloseSocket(); + } +}; +} } #endif // TRINITYCORE_SOCKET_H diff --git a/src/common/network/SslStream.h b/src/common/network/SslStream.h index f1aad7022ac..6bf949dcb47 100644 --- a/src/common/network/SslStream.h +++ b/src/common/network/SslStream.h @@ -93,6 +93,12 @@ public: _sslSocket.next_layer().shutdown(what, shutdownError); } + template + decltype(auto) async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler) + { + return _sslSocket.next_layer().async_connect(endpoint, std::forward(handler)); + } + template decltype(auto) async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) { -- cgit v1.2.3