From 3d81a7a1962985e17b649f58c9a31312f172b760 Mon Sep 17 00:00:00 2001 From: Shauren Date: Sun, 13 Apr 2025 11:25:31 +0200 Subject: Core/Network: Minor include cleanup and add more required functions and typdefs to SslStream and BoostBeastSocketWrapper (cherry picked from commit c8ab1b58b183db0cb856a667b2f410d7b7a57a44) --- src/common/network/Http/BaseHttpSocket.cpp | 57 +++++++++++++++++++------ src/common/network/Http/BaseHttpSocket.h | 25 ++++++++--- src/common/network/Http/HttpService.cpp | 1 + src/common/network/Http/HttpSocket.cpp | 43 +++++++++++++++++++ src/common/network/Http/HttpSocket.h | 20 +++------ src/common/network/Http/HttpSslSocket.cpp | 44 +++++++++++++++++++ src/common/network/Http/HttpSslSocket.h | 22 +++------- src/common/network/Resolver.h | 2 +- src/common/network/Socket.h | 15 +++++-- src/common/network/SslStream.h | 34 +++++++++++---- src/server/bnetserver/REST/LoginHttpSession.cpp | 1 + 11 files changed, 202 insertions(+), 62 deletions(-) create mode 100644 src/common/network/Http/HttpSocket.cpp create mode 100644 src/common/network/Http/HttpSslSocket.cpp diff --git a/src/common/network/Http/BaseHttpSocket.cpp b/src/common/network/Http/BaseHttpSocket.cpp index 33053399c14..f14c0c98a9e 100644 --- a/src/common/network/Http/BaseHttpSocket.cpp +++ b/src/common/network/Http/BaseHttpSocket.cpp @@ -18,6 +18,7 @@ #include "BaseHttpSocket.h" #include #include +#include namespace Trinity::Net::Http { @@ -37,38 +38,58 @@ bool AbstractSocket::ParseRequest(MessageBuffer& packet, RequestParser& parser) return parser.is_done(); } -std::string AbstractSocket::SerializeRequest(Request const& request) +bool AbstractSocket::ParseResponse(MessageBuffer& packet, ResponseParser& parser) +{ + if (!parser.is_done()) + { + // need more data in the payload + boost::system::error_code ec = {}; + std::size_t readDataSize = parser.put(boost::asio::const_buffer(packet.GetReadPointer(), packet.GetActiveSize()), ec); + packet.ReadCompleted(readDataSize); + } + + return parser.is_done(); +} + +MessageBuffer AbstractSocket::SerializeRequest(Request const& request) { RequestSerializer serializer(request); - std::string buffer; + MessageBuffer buffer; while (!serializer.is_done()) { + serializer.limit(buffer.GetRemainingSpace()); + size_t totalBytes = 0; boost::system::error_code ec = {}; - serializer.next(ec, [&](boost::system::error_code const&, ConstBufferSequence const& buffers) + serializer.next(ec, [&](boost::system::error_code& currentError, ConstBufferSequence const& buffers) { size_t totalBytesInBuffers = boost::asio::buffer_size(buffers); - - buffer.reserve(buffer.size() + totalBytes); + if (totalBytesInBuffers > buffer.GetRemainingSpace()) + { + currentError = boost::beast::http::error::need_more; + return; + } auto begin = boost::asio::buffers_begin(buffers); auto end = boost::asio::buffers_end(buffers); - std::copy(begin, end, std::back_inserter(buffer)); + std::copy(begin, end, buffer.GetWritePointer()); + buffer.WriteCompleted(totalBytesInBuffers); totalBytes += totalBytesInBuffers; }); serializer.consume(totalBytes); + + if (ec == boost::beast::http::error::need_more) + buffer.Resize(buffer.GetBufferSize() + 4096); } return buffer; } -MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response& response) +MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response const& response) { - response.prepare_payload(); - ResponseSerializer serializer(response); bool (*serializerIsDone)(ResponseSerializer&); if (request.method() != boost::beast::http::verb::head) @@ -123,10 +144,20 @@ void AbstractSocket::LogRequestAndResponse(RequestContext const& context, Messag ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int()); if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE)) { - sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo, - CanLogRequestContent(context) ? SerializeRequest(context.request) : ""); - sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo, - CanLogResponseContent(context) ? std::string_view(reinterpret_cast(buffer.GetBasePointer()), buffer.GetActiveSize()) : ""); + if (CanLogRequestContent(context)) + { + MessageBuffer request = SerializeRequest(context.request); + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo, + std::string_view(reinterpret_cast(request.GetBasePointer()), request.GetActiveSize())); + } + else + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: ", clientInfo); + + if (CanLogResponseContent(context)) + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo, + std::string_view(reinterpret_cast(buffer.GetBasePointer()), buffer.GetActiveSize())); + else + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: ", clientInfo); } } } diff --git a/src/common/network/Http/BaseHttpSocket.h b/src/common/network/Http/BaseHttpSocket.h index 2d58b6d5aa1..c4d9c675de1 100644 --- a/src/common/network/Http/BaseHttpSocket.h +++ b/src/common/network/Http/BaseHttpSocket.h @@ -18,7 +18,6 @@ #ifndef TRINITYCORE_BASE_HTTP_SOCKET_H #define TRINITYCORE_BASE_HTTP_SOCKET_H -#include "AsyncCallbackProcessor.h" #include "HttpCommon.h" #include "HttpSessionState.h" #include "Optional.h" @@ -27,7 +26,6 @@ #include #include #include -#include namespace Trinity::Net::Http { @@ -40,9 +38,9 @@ class BoostBeastSocketWrapper : public IoContextHttpSocket public: using IoContextHttpSocket::basic_stream; - void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + bool is_open() const { - socket().shutdown(what, shutdownError); + return socket().is_open(); } void close(boost::system::error_code& /*error*/) @@ -50,12 +48,23 @@ public: IoContextHttpSocket::close(); } + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + socket().shutdown(what, shutdownError); + } + template void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) { socket().async_wait(type, std::forward(handler)); } + template + void set_option(SettableSocketOption const& option, boost::system::error_code& ec) + { + socket().set_option(option, ec); + } + IoContextTcpSocket::endpoint_type remote_endpoint() const { return socket().remote_endpoint(); @@ -64,6 +73,7 @@ public: } using RequestParser = boost::beast::http::request_parser; +using ResponseParser = boost::beast::http::response_parser; class TC_NETWORK_API AbstractSocket { @@ -76,9 +86,10 @@ public: virtual ~AbstractSocket() = default; static bool ParseRequest(MessageBuffer& packet, RequestParser& parser); + static bool ParseResponse(MessageBuffer& packet, ResponseParser& parser); - static std::string SerializeRequest(Request const& request); - static MessageBuffer SerializeResponse(Request const& request, Response& response); + static MessageBuffer SerializeRequest(Request const& request); + static MessageBuffer SerializeResponse(Request const& request, Response const& response); virtual void SendResponse(RequestContext& context) = 0; @@ -183,6 +194,8 @@ public: void SendResponse(RequestContext& context) final { + context.response.prepare_payload(); + MessageBuffer buffer = SerializeResponse(context.request, context.response); this->LogRequestAndResponse(context, buffer); diff --git a/src/common/network/Http/HttpService.cpp b/src/common/network/Http/HttpService.cpp index b01e27e296a..6115fd6b6d2 100644 --- a/src/common/network/Http/HttpService.cpp +++ b/src/common/network/Http/HttpService.cpp @@ -21,6 +21,7 @@ #include "Timezone.h" #include #include +#include #include namespace Trinity::Net::Http diff --git a/src/common/network/Http/HttpSocket.cpp b/src/common/network/Http/HttpSocket.cpp new file mode 100644 index 00000000000..45fdf5cb54b --- /dev/null +++ b/src/common/network/Http/HttpSocket.cpp @@ -0,0 +1,43 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#include "HttpSocket.h" +#include + +namespace Trinity::Net::Http +{ +Socket::Socket(IoContextTcpSocket&& socket): SocketBase(std::move(socket)) +{ +} + +Socket::Socket(boost::asio::io_context& context): SocketBase(context) +{ +} + +Socket::~Socket() = default; + +void Socket::Start() +{ + std::array, 2> initializers = + { { + std::make_shared>(this), + std::make_shared>(this), + } }; + + SocketConnectionInitializer::SetupChain(initializers)->Start(); +} +} diff --git a/src/common/network/Http/HttpSocket.h b/src/common/network/Http/HttpSocket.h index 2cfc3ba8ed8..97715e0aa96 100644 --- a/src/common/network/Http/HttpSocket.h +++ b/src/common/network/Http/HttpSocket.h @@ -19,34 +19,26 @@ #define TRINITYCORE_HTTP_SOCKET_H #include "BaseHttpSocket.h" -#include namespace Trinity::Net::Http { -class Socket : public BaseSocket +class TC_NETWORK_API Socket : public BaseSocket { using SocketBase = BaseSocket; public: - using SocketBase::SocketBase; + explicit Socket(IoContextTcpSocket&& socket); + + explicit Socket(boost::asio::io_context& context); Socket(Socket const& other) = delete; Socket(Socket&& other) = delete; Socket& operator=(Socket const& other) = delete; Socket& operator=(Socket&& other) = delete; - ~Socket() = default; - - void Start() override - { - std::array, 2> initializers = - { { - std::make_shared>(this), - std::make_shared>(this), - } }; + ~Socket(); - SocketConnectionInitializer::SetupChain(initializers)->Start(); - } + void Start() override; }; } diff --git a/src/common/network/Http/HttpSslSocket.cpp b/src/common/network/Http/HttpSslSocket.cpp new file mode 100644 index 00000000000..5a1bb1002fc --- /dev/null +++ b/src/common/network/Http/HttpSslSocket.cpp @@ -0,0 +1,44 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see . + */ + +#include "HttpSslSocket.h" +#include + +namespace Trinity::Net::Http +{ +SslSocket::SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : SocketBase(std::move(socket), sslContext) +{ +} + +SslSocket::SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : SocketBase(context, sslContext) +{ +} + +SslSocket::~SslSocket() = default; + +void SslSocket::Start() +{ + std::array, 3> initializers = + { { + std::make_shared>(this), + std::make_shared>(this), + std::make_shared>(this), + } }; + + SocketConnectionInitializer::SetupChain(initializers)->Start(); +} +} diff --git a/src/common/network/Http/HttpSslSocket.h b/src/common/network/Http/HttpSslSocket.h index c789cbfefaf..d411e7c337b 100644 --- a/src/common/network/Http/HttpSslSocket.h +++ b/src/common/network/Http/HttpSslSocket.h @@ -23,35 +23,23 @@ namespace Trinity::Net::Http { -class SslSocket : public BaseSocket> +class TC_NETWORK_API SslSocket : public BaseSocket> { using SocketBase = BaseSocket>; public: - explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) - : SocketBase(std::move(socket), sslContext) { } + explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext); - explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) - : SocketBase(context, sslContext) { } + explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext); SslSocket(SslSocket const& other) = delete; SslSocket(SslSocket&& other) = delete; SslSocket& operator=(SslSocket const& other) = delete; SslSocket& operator=(SslSocket&& other) = delete; - ~SslSocket() = default; + ~SslSocket(); - void Start() override - { - std::array, 3> initializers = - { { - std::make_shared>(this), - std::make_shared>(this), - std::make_shared>(this), - } }; - - SocketConnectionInitializer::SetupChain(initializers)->Start(); - } + void Start() override; }; } diff --git a/src/common/network/Resolver.h b/src/common/network/Resolver.h index c7d24658aa5..28acd7f9d79 100644 --- a/src/common/network/Resolver.h +++ b/src/common/network/Resolver.h @@ -40,7 +40,7 @@ public: std::vector ResolveAll(std::string_view host, std::string_view service); private: - boost::asio::ip::tcp::resolver _impl; + boost::asio::ip::basic_resolver _impl; }; } diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h index 565cc175318..e91d3198b58 100644 --- a/src/common/network/Socket.h +++ b/src/common/network/Socket.h @@ -44,6 +44,13 @@ enum class SocketReadCallbackResult Stop }; +inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer) +{ + readBuffer.Normalize(); + readBuffer.EnsureFreeSpace(); + return boost::asio::buffer(readBuffer.GetWritePointer(), readBuffer.GetRemainingSpace()); +} + template concept SocketReadCallback = Trinity::invocable_r; @@ -82,6 +89,10 @@ struct ReadConnectionInitializer final : SocketConnectionInitializer @tparam Stream stream type used for operations on socket Stream must implement the following methods: + boost::asio::io_context::executor_type get_executor(); + + bool is_open() const; + void close(boost::system::error_code& error); void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError); @@ -164,9 +175,7 @@ public: if (!IsOpen()) return; - _readBuffer.Normalize(); - _readBuffer.EnsureFreeSpace(); - _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), + _socket.async_read_some(PrepareReadBuffer(_readBuffer), [self = this->shared_from_this(), callback = std::forward(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable { if (self->ReadHandlerInternal(error, transferredBytes)) diff --git a/src/common/network/SslStream.h b/src/common/network/SslStream.h index 2cced44e5ff..f1aad7022ac 100644 --- a/src/common/network/SslStream.h +++ b/src/common/network/SslStream.h @@ -59,6 +59,8 @@ template class SslStream { public: + using executor_type = typename WrappedStream::executor_type; + explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) { _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); @@ -70,6 +72,16 @@ public: } // adapting tcp::socket api + boost::asio::io_context::executor_type get_executor() + { + return _sslSocket.get_executor(); + } + + bool is_open() const + { + return _sslSocket.next_layer().is_open(); + } + void close(boost::system::error_code& error) { _sslSocket.next_layer().close(error); @@ -82,15 +94,15 @@ public: } template - void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) + decltype(auto) async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) { - _sslSocket.async_read_some(buffers, std::forward(handler)); + return _sslSocket.async_read_some(buffers, std::forward(handler)); } template - void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) + decltype(auto) async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) { - _sslSocket.async_write_some(buffers, std::forward(handler)); + return _sslSocket.async_write_some(buffers, std::forward(handler)); } template @@ -100,9 +112,9 @@ public: } template - void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + decltype(auto) async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) { - _sslSocket.next_layer().async_wait(type, std::forward(handler)); + return _sslSocket.next_layer().async_wait(type, std::forward(handler)); } template @@ -118,9 +130,15 @@ public: // ssl api template - void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) + decltype(auto) async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) + { + return _sslSocket.async_handshake(type, std::forward(handler)); + } + + void set_server_name(std::string const& serverName, boost::system::error_code& error) { - _sslSocket.async_handshake(type, std::forward(handler)); + if (!SSL_set_tlsext_host_name(_sslSocket.native_handle(), serverName.c_str())) + error.assign(static_cast(::ERR_get_error()), boost::asio::error::get_ssl_category()); } protected: diff --git a/src/server/bnetserver/REST/LoginHttpSession.cpp b/src/server/bnetserver/REST/LoginHttpSession.cpp index bd8afdbcf2f..d1c38d667d9 100644 --- a/src/server/bnetserver/REST/LoginHttpSession.cpp +++ b/src/server/bnetserver/REST/LoginHttpSession.cpp @@ -24,6 +24,7 @@ #include "SslContext.h" #include "Util.h" #include +#include namespace { -- cgit v1.2.3