Core/Network: Implement connect operations in Socket class

This commit is contained in:
Shauren
2025-09-22 13:17:50 +02:00
parent 203ad17560
commit bb1cc8a48f
2 changed files with 124 additions and 4 deletions

View File

@@ -23,6 +23,7 @@
#include "Log.h"
#include "MessageBuffer.h"
#include "SocketConnectionInitializer.h"
#include <boost/asio/compose.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
@@ -39,6 +40,12 @@ namespace Trinity::Net
{
using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
namespace Impl::Operations
{
template <typename Socket>
struct Connect;
}
enum class SocketReadCallbackResult
{
KeepReading,
@@ -52,9 +59,6 @@ inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer)
return boost::asio::buffer(readBuffer.GetWritePointer(), readBuffer.GetRemainingSpace());
}
template <typename Callable>
concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
template <typename SocketType>
struct InvokeReadHandlerCallback
{
@@ -98,6 +102,9 @@ struct ReadConnectionInitializer final : SocketConnectionInitializer
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
template<typename ConnectHandlerType>
void async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler);
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler);
@@ -144,6 +151,22 @@ public:
virtual void Start() { }
template <BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, boost::asio::ip::tcp::endpoint)) Callback>
decltype(auto) Connect(boost::asio::ip::tcp::endpoint const& endpoint, Callback&& callback)
{
_openState = OpenState_Open;
return boost::asio::async_compose<Callback, void(boost::system::error_code, boost::asio::ip::tcp::endpoint), Impl::Operations::Connect<Socket>>(
Impl::Operations::Connect<Socket>(this->shared_from_this(), endpoint), callback, this->underlying_stream());
}
template <BOOST_ASIO_COMPLETION_TOKEN_FOR(void(boost::system::error_code, boost::asio::ip::tcp::endpoint)) Callback>
decltype(auto) Connect(std::vector<boost::asio::ip::tcp::endpoint> const& endpoints, Callback&& callback)
{
_openState = OpenState_Open;
return boost::asio::async_compose<Callback, void(boost::system::error_code, boost::asio::ip::tcp::endpoint), Impl::Operations::Connect<Socket>>(
Impl::Operations::Connect<Socket>(this->shared_from_this(), endpoints), callback, this->underlying_stream());
}
virtual bool Update()
{
if (_openState == OpenState_Closed)
@@ -170,7 +193,13 @@ public:
return _remotePort;
}
template <SocketReadCallback Callback>
void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint)
{
_remoteAddress = endpoint.address();
_remotePort = endpoint.port();
}
template <invocable_r<SocketReadCallbackResult> Callback>
void AsyncRead(Callback&& callback)
{
if (!IsOpen())
@@ -366,6 +395,91 @@ private:
bool _isWritingAsync = false;
};
namespace Impl::Operations
{
struct ConnectState
{
explicit ConnectState(std::shared_ptr<void> const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
: SocketRef(socketRef), Endpoints(1, endpoint), Index(-1) { }
explicit ConnectState(std::shared_ptr<void> const& socketRef, std::vector<boost::asio::ip::tcp::endpoint> const& endpoints)
: SocketRef(socketRef), Endpoints(endpoints), Index(-1) { }
std::weak_ptr<void> SocketRef;
std::vector<boost::asio::ip::tcp::endpoint> Endpoints;
std::ptrdiff_t Index;
};
template <typename Socket>
struct Connect
{
explicit Connect(std::shared_ptr<Socket> const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
: State(std::make_shared<ConnectState>(std::move(socketRef), endpoint)) { }
explicit Connect(std::shared_ptr<Socket> const& socketRef, std::vector<boost::asio::ip::tcp::endpoint> const& endpoints)
: State(std::make_shared<ConnectState>(std::move(socketRef), endpoints)) { }
std::shared_ptr<ConnectState> State;
template <typename Handler>
void operator()(Handler& handler, boost::system::error_code error = {})
{
std::shared_ptr<Socket> socket = static_pointer_cast<Socket>(State->SocketRef.lock());
if (!socket)
{
error = boost::asio::error::operation_aborted;
handler.complete(error, boost::asio::ip::tcp::endpoint());
return;
}
bool isFirst = State->Index < 0;
if (std::max(State->Index, std::ptrdiff_t(0)) >= std::ssize(State->Endpoints))
{
Connect::HandleError(socket.get(), "failed to connect to any of specified endpoints");
error = boost::asio::error::not_found;
handler.complete(error, boost::asio::ip::tcp::endpoint());
return;
}
if (!isFirst && !socket->underlying_stream().is_open())
{
Connect::HandleError(socket.get(), "socket closed");
error = boost::asio::error::operation_aborted;
handler.complete(error, boost::asio::ip::tcp::endpoint());
return;
}
if (!error && !isFirst)
{
socket->SetRemoteEndpoint(State->Endpoints[State->Index]);
handler.complete(error, State->Endpoints[State->Index]);
}
else
{
#if BOOST_VERSION >= 107700
if (handler.cancelled() != boost::asio::cancellation_type::none)
{
Connect::HandleError(socket.get(), "connect cancelled");
error = boost::asio::error::operation_aborted;
handler.complete(error, boost::asio::ip::tcp::endpoint());
return;
}
#endif
socket->underlying_stream().close(error);
socket->underlying_stream().async_connect(State->Endpoints[++State->Index], std::move(handler));
}
}
static void HandleError(Socket* self, std::string_view message)
{
TC_LOG_DEBUG("network", "Socket::Connect: {}", message);
self->CloseSocket();
}
};
}
}
#endif // TRINITYCORE_SOCKET_H

View File

@@ -93,6 +93,12 @@ public:
_sslSocket.next_layer().shutdown(what, shutdownError);
}
template<typename ConnectHandlerType>
decltype(auto) async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler)
{
return _sslSocket.next_layer().async_connect(endpoint, std::forward<ConnectHandlerType>(handler));
}
template<typename MutableBufferSequence, typename ReadHandlerType>
decltype(auto) async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
{