aboutsummaryrefslogtreecommitdiff
path: root/src/common/network/Socket.h
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2025-09-22 13:17:50 +0200
committerShauren <shauren.trinity@gmail.com>2025-09-22 13:17:50 +0200
commitbb1cc8a48f8f3472ebdfdbaa6a97ec2e1c5190c8 (patch)
tree632b4d1606380f30075000c20ca17d0c848085e8 /src/common/network/Socket.h
parent203ad17560057b3bf70a1632f3d715e429512701 (diff)
Core/Network: Implement connect operations in Socket class
Diffstat (limited to 'src/common/network/Socket.h')
-rw-r--r--src/common/network/Socket.h122
1 files changed, 118 insertions, 4 deletions
diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h
index aa42c7299e8..c2b2c63a291 100644
--- a/src/common/network/Socket.h
+++ b/src/common/network/Socket.h
@@ -23,6 +23,7 @@
#include "Log.h"
#include "MessageBuffer.h"
#include "SocketConnectionInitializer.h"
+#include <boost/asio/compose.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
@@ -39,6 +40,12 @@ namespace Trinity::Net
{
using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
+namespace Impl::Operations
+{
+template <typename Socket>
+struct Connect;
+}
+
enum class SocketReadCallbackResult
{
KeepReading,
@@ -52,9 +59,6 @@ inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer)
return boost::asio::buffer(readBuffer.GetWritePointer(), readBuffer.GetRemainingSpace());
}
-template <typename Callable>
-concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
-
template <typename SocketType>
struct InvokeReadHandlerCallback
{
@@ -98,6 +102,9 @@ struct ReadConnectionInitializer final : SocketConnectionInitializer
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
+ template<typename ConnectHandlerType>
+ void async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler);
+
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler);
@@ -144,6 +151,22 @@ public:
virtual void Start() { }
+ template <BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, boost::asio::ip::tcp::endpoint)) Callback>
+ decltype(auto) Connect(boost::asio::ip::tcp::endpoint const& endpoint, Callback&& callback)
+ {
+ _openState = OpenState_Open;
+ return boost::asio::async_compose<Callback, void(boost::system::error_code, boost::asio::ip::tcp::endpoint), Impl::Operations::Connect<Socket>>(
+ Impl::Operations::Connect<Socket>(this->shared_from_this(), endpoint), callback, this->underlying_stream());
+ }
+
+ template <BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, boost::asio::ip::tcp::endpoint)) Callback>
+ decltype(auto) Connect(std::vector<boost::asio::ip::tcp::endpoint> const& endpoints, Callback&& callback)
+ {
+ _openState = OpenState_Open;
+ return boost::asio::async_compose<Callback, void(boost::system::error_code, boost::asio::ip::tcp::endpoint), Impl::Operations::Connect<Socket>>(
+ Impl::Operations::Connect<Socket>(this->shared_from_this(), endpoints), callback, this->underlying_stream());
+ }
+
virtual bool Update()
{
if (_openState == OpenState_Closed)
@@ -170,7 +193,13 @@ public:
return _remotePort;
}
- template <SocketReadCallback Callback>
+ void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint)
+ {
+ _remoteAddress = endpoint.address();
+ _remotePort = endpoint.port();
+ }
+
+ template <invocable_r<SocketReadCallbackResult> Callback>
void AsyncRead(Callback&& callback)
{
if (!IsOpen())
@@ -366,6 +395,91 @@ private:
bool _isWritingAsync = false;
};
+
+namespace Impl::Operations
+{
+struct ConnectState
+{
+ explicit ConnectState(std::shared_ptr<void> const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
+ : SocketRef(socketRef), Endpoints(1, endpoint), Index(-1) { }
+
+ explicit ConnectState(std::shared_ptr<void> const& socketRef, std::vector<boost::asio::ip::tcp::endpoint> const& endpoints)
+ : SocketRef(socketRef), Endpoints(endpoints), Index(-1) { }
+
+ std::weak_ptr<void> SocketRef;
+ std::vector<boost::asio::ip::tcp::endpoint> Endpoints;
+ std::ptrdiff_t Index;
+};
+
+template <typename Socket>
+struct Connect
+{
+ explicit Connect(std::shared_ptr<Socket> const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
+ : State(std::make_shared<ConnectState>(std::move(socketRef), endpoint)) { }
+
+ explicit Connect(std::shared_ptr<Socket> const& socketRef, std::vector<boost::asio::ip::tcp::endpoint> const& endpoints)
+ : State(std::make_shared<ConnectState>(std::move(socketRef), endpoints)) { }
+
+ std::shared_ptr<ConnectState> State;
+
+ template <typename Handler>
+ void operator()(Handler& handler, boost::system::error_code error = {})
+ {
+ std::shared_ptr<Socket> socket = static_pointer_cast<Socket>(State->SocketRef.lock());
+ if (!socket)
+ {
+ error = boost::asio::error::operation_aborted;
+ handler.complete(error, boost::asio::ip::tcp::endpoint());
+ return;
+ }
+
+ bool isFirst = State->Index < 0;
+
+ if (std::max(State->Index, std::ptrdiff_t(0)) >= std::ssize(State->Endpoints))
+ {
+ Connect::HandleError(socket.get(), "failed to connect to any of specified endpoints");
+ error = boost::asio::error::not_found;
+ handler.complete(error, boost::asio::ip::tcp::endpoint());
+ return;
+ }
+
+ if (!isFirst && !socket->underlying_stream().is_open())
+ {
+ Connect::HandleError(socket.get(), "socket closed");
+ error = boost::asio::error::operation_aborted;
+ handler.complete(error, boost::asio::ip::tcp::endpoint());
+ return;
+ }
+
+ if (!error && !isFirst)
+ {
+ socket->SetRemoteEndpoint(State->Endpoints[State->Index]);
+ handler.complete(error, State->Endpoints[State->Index]);
+ }
+ else
+ {
+#if BOOST_VERSION >= 107700
+ if (handler.cancelled() != boost::asio::cancellation_type::none)
+ {
+ Connect::HandleError(socket.get(), "connect cancelled");
+ error = boost::asio::error::operation_aborted;
+ handler.complete(error, boost::asio::ip::tcp::endpoint());
+ return;
+ }
+#endif
+
+ socket->underlying_stream().close(error);
+ socket->underlying_stream().async_connect(State->Endpoints[++State->Index], std::move(handler));
+ }
+ }
+
+ static void HandleError(Socket* self, std::string_view message)
+ {
+ TC_LOG_DEBUG("network", "Socket::Connect: {}", message);
+ self->CloseSocket();
+ }
+};
+}
}
#endif // TRINITYCORE_SOCKET_H