diff options
author | Shauren <shauren.trinity@gmail.com> | 2025-09-22 13:17:50 +0200 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2025-09-22 13:17:50 +0200 |
commit | bb1cc8a48f8f3472ebdfdbaa6a97ec2e1c5190c8 (patch) | |
tree | 632b4d1606380f30075000c20ca17d0c848085e8 /src | |
parent | 203ad17560057b3bf70a1632f3d715e429512701 (diff) |
Core/Network: Implement connect operations in Socket class
Diffstat (limited to 'src')
-rw-r--r-- | src/common/network/Socket.h | 122 | ||||
-rw-r--r-- | src/common/network/SslStream.h | 6 |
2 files changed, 124 insertions, 4 deletions
diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h index aa42c7299e8..c2b2c63a291 100644 --- a/src/common/network/Socket.h +++ b/src/common/network/Socket.h @@ -23,6 +23,7 @@ #include "Log.h" #include "MessageBuffer.h" #include "SocketConnectionInitializer.h" +#include <boost/asio/compose.hpp> #include <boost/asio/io_context.hpp> #include <boost/asio/ip/tcp.hpp> #include <atomic> @@ -39,6 +40,12 @@ namespace Trinity::Net { using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>; +namespace Impl::Operations +{ +template <typename Socket> +struct Connect; +} + enum class SocketReadCallbackResult { KeepReading, @@ -52,9 +59,6 @@ inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer) return boost::asio::buffer(readBuffer.GetWritePointer(), readBuffer.GetRemainingSpace()); } -template <typename Callable> -concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>; - template <typename SocketType> struct InvokeReadHandlerCallback { @@ -98,6 +102,9 @@ struct ReadConnectionInitializer final : SocketConnectionInitializer void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError); + template<typename ConnectHandlerType> + void async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler); + template<typename MutableBufferSequence, typename ReadHandlerType> void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler); @@ -144,6 +151,22 @@ public: virtual void Start() { } + template <BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, boost::asio::ip::tcp::endpoint)) Callback> + decltype(auto) Connect(boost::asio::ip::tcp::endpoint const& endpoint, Callback&& callback) + { + _openState = OpenState_Open; + return boost::asio::async_compose<Callback, void(boost::system::error_code, boost::asio::ip::tcp::endpoint), Impl::Operations::Connect<Socket>>( + Impl::Operations::Connect<Socket>(this->shared_from_this(), endpoint), callback, this->underlying_stream()); + } + + template <BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, boost::asio::ip::tcp::endpoint)) Callback> + decltype(auto) Connect(std::vector<boost::asio::ip::tcp::endpoint> const& endpoints, Callback&& callback) + { + _openState = OpenState_Open; + return boost::asio::async_compose<Callback, void(boost::system::error_code, boost::asio::ip::tcp::endpoint), Impl::Operations::Connect<Socket>>( + Impl::Operations::Connect<Socket>(this->shared_from_this(), endpoints), callback, this->underlying_stream()); + } + virtual bool Update() { if (_openState == OpenState_Closed) @@ -170,7 +193,13 @@ public: return _remotePort; } - template <SocketReadCallback Callback> + void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint) + { + _remoteAddress = endpoint.address(); + _remotePort = endpoint.port(); + } + + template <invocable_r<SocketReadCallbackResult> Callback> void AsyncRead(Callback&& callback) { if (!IsOpen()) @@ -366,6 +395,91 @@ private: bool _isWritingAsync = false; }; + +namespace Impl::Operations +{ +struct ConnectState +{ + explicit ConnectState(std::shared_ptr<void> const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint) + : SocketRef(socketRef), Endpoints(1, endpoint), Index(-1) { } + + explicit ConnectState(std::shared_ptr<void> const& socketRef, std::vector<boost::asio::ip::tcp::endpoint> const& endpoints) + : SocketRef(socketRef), Endpoints(endpoints), Index(-1) { } + + std::weak_ptr<void> SocketRef; + std::vector<boost::asio::ip::tcp::endpoint> Endpoints; + std::ptrdiff_t Index; +}; + +template <typename Socket> +struct Connect +{ + explicit Connect(std::shared_ptr<Socket> const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint) + : State(std::make_shared<ConnectState>(std::move(socketRef), endpoint)) { } + + explicit Connect(std::shared_ptr<Socket> const& socketRef, std::vector<boost::asio::ip::tcp::endpoint> const& endpoints) + : State(std::make_shared<ConnectState>(std::move(socketRef), endpoints)) { } + + std::shared_ptr<ConnectState> State; + + template <typename Handler> + void operator()(Handler& handler, boost::system::error_code error = {}) + { + std::shared_ptr<Socket> socket = static_pointer_cast<Socket>(State->SocketRef.lock()); + if (!socket) + { + error = boost::asio::error::operation_aborted; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } + + bool isFirst = State->Index < 0; + + if (std::max(State->Index, std::ptrdiff_t(0)) >= std::ssize(State->Endpoints)) + { + Connect::HandleError(socket.get(), "failed to connect to any of specified endpoints"); + error = boost::asio::error::not_found; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } + + if (!isFirst && !socket->underlying_stream().is_open()) + { + Connect::HandleError(socket.get(), "socket closed"); + error = boost::asio::error::operation_aborted; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } + + if (!error && !isFirst) + { + socket->SetRemoteEndpoint(State->Endpoints[State->Index]); + handler.complete(error, State->Endpoints[State->Index]); + } + else + { +#if BOOST_VERSION >= 107700 + if (handler.cancelled() != boost::asio::cancellation_type::none) + { + Connect::HandleError(socket.get(), "connect cancelled"); + error = boost::asio::error::operation_aborted; + handler.complete(error, boost::asio::ip::tcp::endpoint()); + return; + } +#endif + + socket->underlying_stream().close(error); + socket->underlying_stream().async_connect(State->Endpoints[++State->Index], std::move(handler)); + } + } + + static void HandleError(Socket* self, std::string_view message) + { + TC_LOG_DEBUG("network", "Socket::Connect: {}", message); + self->CloseSocket(); + } +}; +} } #endif // TRINITYCORE_SOCKET_H diff --git a/src/common/network/SslStream.h b/src/common/network/SslStream.h index f1aad7022ac..6bf949dcb47 100644 --- a/src/common/network/SslStream.h +++ b/src/common/network/SslStream.h @@ -93,6 +93,12 @@ public: _sslSocket.next_layer().shutdown(what, shutdownError); } + template<typename ConnectHandlerType> + decltype(auto) async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler) + { + return _sslSocket.next_layer().async_connect(endpoint, std::forward<ConnectHandlerType>(handler)); + } + template<typename MutableBufferSequence, typename ReadHandlerType> decltype(auto) async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) { |