Core/Network: Minor include cleanup and add more required functions and typdefs to SslStream and BoostBeastSocketWrapper

This commit is contained in:
Shauren
2025-04-13 11:25:31 +02:00
parent 2f05dd6c07
commit c8ab1b58b1
11 changed files with 202 additions and 62 deletions

View File

@@ -18,6 +18,7 @@
#include "BaseHttpSocket.h"
#include <boost/asio/buffers_iterator.hpp>
#include <boost/beast/http/serializer.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace Trinity::Net::Http
{
@@ -37,38 +38,58 @@ bool AbstractSocket::ParseRequest(MessageBuffer& packet, RequestParser& parser)
return parser.is_done();
}
std::string AbstractSocket::SerializeRequest(Request const& request)
bool AbstractSocket::ParseResponse(MessageBuffer& packet, ResponseParser& parser)
{
if (!parser.is_done())
{
// need more data in the payload
boost::system::error_code ec = {};
std::size_t readDataSize = parser.put(boost::asio::const_buffer(packet.GetReadPointer(), packet.GetActiveSize()), ec);
packet.ReadCompleted(readDataSize);
}
return parser.is_done();
}
MessageBuffer AbstractSocket::SerializeRequest(Request const& request)
{
RequestSerializer serializer(request);
std::string buffer;
MessageBuffer buffer;
while (!serializer.is_done())
{
serializer.limit(buffer.GetRemainingSpace());
size_t totalBytes = 0;
boost::system::error_code ec = {};
serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code const&, ConstBufferSequence const& buffers)
serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code& currentError, ConstBufferSequence const& buffers)
{
size_t totalBytesInBuffers = boost::asio::buffer_size(buffers);
buffer.reserve(buffer.size() + totalBytes);
if (totalBytesInBuffers > buffer.GetRemainingSpace())
{
currentError = boost::beast::http::error::need_more;
return;
}
auto begin = boost::asio::buffers_begin(buffers);
auto end = boost::asio::buffers_end(buffers);
std::copy(begin, end, std::back_inserter(buffer));
std::copy(begin, end, buffer.GetWritePointer());
buffer.WriteCompleted(totalBytesInBuffers);
totalBytes += totalBytesInBuffers;
});
serializer.consume(totalBytes);
if (ec == boost::beast::http::error::need_more)
buffer.Resize(buffer.GetBufferSize() + 4096);
}
return buffer;
}
MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response& response)
MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response const& response)
{
response.prepare_payload();
ResponseSerializer serializer(response);
bool (*serializerIsDone)(ResponseSerializer&);
if (request.method() != boost::beast::http::verb::head)
@@ -123,10 +144,20 @@ void AbstractSocket::LogRequestAndResponse(RequestContext const& context, Messag
ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int());
if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
{
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo,
CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo,
CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
if (CanLogRequestContent(context))
{
MessageBuffer request = SerializeRequest(context.request);
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo,
std::string_view(reinterpret_cast<char const*>(request.GetBasePointer()), request.GetActiveSize()));
}
else
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: <REDACTED>", clientInfo);
if (CanLogResponseContent(context))
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo,
std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()));
else
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: <REDACTED>", clientInfo);
}
}
}

View File

@@ -18,7 +18,6 @@
#ifndef TRINITYCORE_BASE_HTTP_SOCKET_H
#define TRINITYCORE_BASE_HTTP_SOCKET_H
#include "AsyncCallbackProcessor.h"
#include "HttpCommon.h"
#include "HttpSessionState.h"
#include "Optional.h"
@@ -27,7 +26,6 @@
#include <boost/beast/core/basic_stream.hpp>
#include <boost/beast/http/parser.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace Trinity::Net::Http
{
@@ -40,9 +38,9 @@ class BoostBeastSocketWrapper : public IoContextHttpSocket
public:
using IoContextHttpSocket::basic_stream;
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
bool is_open() const
{
socket().shutdown(what, shutdownError);
return socket().is_open();
}
void close(boost::system::error_code& /*error*/)
@@ -50,12 +48,23 @@ public:
IoContextHttpSocket::close();
}
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
socket().shutdown(what, shutdownError);
}
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
{
socket().async_wait(type, std::forward<WaitHandlerType>(handler));
}
template <typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& ec)
{
socket().set_option(option, ec);
}
IoContextTcpSocket::endpoint_type remote_endpoint() const
{
return socket().remote_endpoint();
@@ -64,6 +73,7 @@ public:
}
using RequestParser = boost::beast::http::request_parser<RequestBody>;
using ResponseParser = boost::beast::http::response_parser<RequestBody>;
class TC_NETWORK_API AbstractSocket
{
@@ -76,9 +86,10 @@ public:
virtual ~AbstractSocket() = default;
static bool ParseRequest(MessageBuffer& packet, RequestParser& parser);
static bool ParseResponse(MessageBuffer& packet, ResponseParser& parser);
static std::string SerializeRequest(Request const& request);
static MessageBuffer SerializeResponse(Request const& request, Response& response);
static MessageBuffer SerializeRequest(Request const& request);
static MessageBuffer SerializeResponse(Request const& request, Response const& response);
virtual void SendResponse(RequestContext& context) = 0;
@@ -183,6 +194,8 @@ public:
void SendResponse(RequestContext& context) final
{
context.response.prepare_payload();
MessageBuffer buffer = SerializeResponse(context.request, context.response);
this->LogRequestAndResponse(context, buffer);

View File

@@ -21,6 +21,7 @@
#include "Timezone.h"
#include <boost/beast/version.hpp>
#include <boost/uuid/string_generator.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <fmt/chrono.h>
namespace Trinity::Net::Http

View File

@@ -0,0 +1,43 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "HttpSocket.h"
#include <array>
namespace Trinity::Net::Http
{
Socket::Socket(IoContextTcpSocket&& socket): SocketBase(std::move(socket))
{
}
Socket::Socket(boost::asio::io_context& context): SocketBase(context)
{
}
Socket::~Socket() = default;
void Socket::Start()
{
std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers =
{ {
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
}

View File

@@ -19,34 +19,26 @@
#define TRINITYCORE_HTTP_SOCKET_H
#include "BaseHttpSocket.h"
#include <array>
namespace Trinity::Net::Http
{
class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper>
class TC_NETWORK_API Socket : public BaseSocket<Impl::BoostBeastSocketWrapper>
{
using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>;
public:
using SocketBase::SocketBase;
explicit Socket(IoContextTcpSocket&& socket);
explicit Socket(boost::asio::io_context& context);
Socket(Socket const& other) = delete;
Socket(Socket&& other) = delete;
Socket& operator=(Socket const& other) = delete;
Socket& operator=(Socket&& other) = delete;
~Socket() = default;
~Socket();
void Start() override
{
std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers =
{ {
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
void Start() override;
};
}

View File

@@ -0,0 +1,44 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "HttpSslSocket.h"
#include <array>
namespace Trinity::Net::Http
{
SslSocket::SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : SocketBase(std::move(socket), sslContext)
{
}
SslSocket::SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : SocketBase(context, sslContext)
{
}
SslSocket::~SslSocket() = default;
void SslSocket::Start()
{
std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers =
{ {
std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this),
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
}

View File

@@ -23,35 +23,23 @@
namespace Trinity::Net::Http
{
class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>
class TC_NETWORK_API SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>
{
using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>;
public:
explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext)
: SocketBase(std::move(socket), sslContext) { }
explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext);
explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext)
: SocketBase(context, sslContext) { }
explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext);
SslSocket(SslSocket const& other) = delete;
SslSocket(SslSocket&& other) = delete;
SslSocket& operator=(SslSocket const& other) = delete;
SslSocket& operator=(SslSocket&& other) = delete;
~SslSocket() = default;
~SslSocket();
void Start() override
{
std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers =
{ {
std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this),
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
void Start() override;
};
}

View File

@@ -40,7 +40,7 @@ public:
std::vector<boost::asio::ip::tcp::endpoint> ResolveAll(std::string_view host, std::string_view service);
private:
boost::asio::ip::tcp::resolver _impl;
boost::asio::ip::basic_resolver<boost::asio::ip::tcp, Asio::IoContext::Executor> _impl;
};
}

View File

@@ -44,6 +44,13 @@ enum class SocketReadCallbackResult
Stop
};
inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer)
{
readBuffer.Normalize();
readBuffer.EnsureFreeSpace();
return boost::asio::buffer(readBuffer.GetWritePointer(), readBuffer.GetRemainingSpace());
}
template <typename Callable>
concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
@@ -82,6 +89,10 @@ struct ReadConnectionInitializer final : SocketConnectionInitializer
@tparam Stream stream type used for operations on socket
Stream must implement the following methods:
boost::asio::io_context::executor_type get_executor();
bool is_open() const;
void close(boost::system::error_code& error);
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
@@ -164,9 +175,7 @@ public:
if (!IsOpen())
return;
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
_socket.async_read_some(PrepareReadBuffer(_readBuffer),
[self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
{
if (self->ReadHandlerInternal(error, transferredBytes))

View File

@@ -59,6 +59,8 @@ template<class WrappedStream = IoContextTcpSocket>
class SslStream
{
public:
using executor_type = typename WrappedStream::executor_type;
explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
{
_sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
@@ -70,6 +72,16 @@ public:
}
// adapting tcp::socket api
boost::asio::io_context::executor_type get_executor()
{
return _sslSocket.get_executor();
}
bool is_open() const
{
return _sslSocket.next_layer().is_open();
}
void close(boost::system::error_code& error)
{
_sslSocket.next_layer().close(error);
@@ -82,15 +94,15 @@ public:
}
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
decltype(auto) async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
{
_sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler));
return _sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler));
}
template<typename ConstBufferSequence, typename WriteHandlerType>
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
decltype(auto) async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
{
_sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler));
return _sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler));
}
template<typename ConstBufferSequence>
@@ -100,9 +112,9 @@ public:
}
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
decltype(auto) async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
{
_sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler));
return _sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler));
}
template<typename SettableSocketOption>
@@ -118,9 +130,15 @@ public:
// ssl api
template<typename HandshakeHandlerType>
void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
decltype(auto) async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
{
_sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler));
return _sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler));
}
void set_server_name(std::string const& serverName, boost::system::error_code& error)
{
if (!SSL_set_tlsext_host_name(_sslSocket.native_handle(), serverName.c_str()))
error.assign(static_cast<int>(::ERR_get_error()), boost::asio::error::get_ssl_category());
}
protected:

View File

@@ -24,6 +24,7 @@
#include "SslContext.h"
#include "Util.h"
#include <boost/container/static_vector.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace
{