diff options
author | Shauren <shauren.trinity@gmail.com> | 2025-04-08 19:15:16 +0200 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2025-04-08 19:15:16 +0200 |
commit | e8b2be3527c7683e8bfca70ed7706fc20da566fd (patch) | |
tree | 54d5099554c8628cad719e6f1a49d387c7eced4f /src/server/shared | |
parent | 40d80f3476ade4898be24659408e82aa4234b099 (diff) |
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final
* Removed derived class template argument
* Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor
* Make socket initialization easier composable (before entering Read loop)
* Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
Diffstat (limited to 'src/server/shared')
14 files changed, 612 insertions, 360 deletions
diff --git a/src/server/shared/Networking/AsyncAcceptor.h b/src/server/shared/Networking/AsyncAcceptor.h index 95e25e02166..dd0857c2b38 100644 --- a/src/server/shared/Networking/AsyncAcceptor.h +++ b/src/server/shared/Networking/AsyncAcceptor.h @@ -15,12 +15,13 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef __ASYNCACCEPT_H_ -#define __ASYNCACCEPT_H_ +#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H +#define TRINITYCORE_ASYNC_ACCEPTOR_H #include "IoContext.h" #include "IpAddress.h" #include "Log.h" +#include "Socket.h" #include <boost/asio/ip/tcp.hpp> #include <boost/asio/ip/v6_only.hpp> #include <atomic> @@ -28,28 +29,28 @@ #define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections +namespace Trinity::Net +{ +template <typename Callable> +concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>; + class AsyncAcceptor { public: - typedef void(*AcceptCallback)(boost::asio::ip::tcp::socket&& newSocket, uint32 threadIndex); - - AsyncAcceptor(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : - _acceptor(ioContext), _endpoint(Trinity::Net::make_address(bindIp), port), + AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : + _acceptor(ioContext), _endpoint(make_address(bindIp), port), _socket(ioContext), _closed(false), _socketFactory([this] { return DefeaultSocketFactory(); }) { } - template<class T> - void AsyncAccept(); - - template<AcceptCallback acceptCallback> - void AsyncAcceptWithCallback() + template <AcceptCallback Callback> + void AsyncAccept(Callback&& acceptCallback) { auto [tmpSocket, tmpThreadIndex] = _socketFactory(); // TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures) - boost::asio::ip::tcp::socket* socket = tmpSocket; + IoContextTcpSocket* socket = tmpSocket; uint32 threadIndex = tmpThreadIndex; - _acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error) + _acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable { if (!error) { @@ -66,7 +67,7 @@ public: } if (!_closed) - this->AsyncAcceptWithCallback<acceptCallback>(); + this->AsyncAccept(std::move(acceptCallback)); }); } @@ -120,40 +121,17 @@ public: _acceptor.close(err); } - void SetSocketFactory(std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> func) { _socketFactory = std::move(func); } + void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); } private: - std::pair<boost::asio::ip::tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } + std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } - boost::asio::ip::tcp::acceptor _acceptor; + boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor; boost::asio::ip::tcp::endpoint _endpoint; - boost::asio::ip::tcp::socket _socket; + IoContextTcpSocket _socket; std::atomic<bool> _closed; - std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> _socketFactory; + std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory; }; - -template<class T> -void AsyncAcceptor::AsyncAccept() -{ - _acceptor.async_accept(_socket, [this](boost::system::error_code error) - { - if (!error) - { - try - { - // this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class - std::make_shared<T>(std::move(this->_socket))->Start(); - } - catch (boost::system::system_error const& err) - { - TC_LOG_INFO("network", "Failed to retrieve client's remote address {}", err.what()); - } - } - - // lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face - if (!_closed) - this->AsyncAccept<T>(); - }); } -#endif /* __ASYNCACCEPT_H_ */ +#endif // TRINITYCORE_ASYNC_ACCEPTOR_H diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp new file mode 100644 index 00000000000..5996c40faee --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.cpp @@ -0,0 +1,42 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "IpBanCheckConnectionInitializer.h" +#include "DatabaseEnv.h" + +QueryCallback Trinity::Net::IpBanCheckHelpers::AsyncQuery(std::string_view ipAddress) +{ + LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); + stmt->setString(0, ipAddress); + return LoginDatabase.AsyncQuery(stmt); +} + +bool Trinity::Net::IpBanCheckHelpers::IsBanned(PreparedQueryResult const& result) +{ + if (result) + { + do + { + Field* fields = result->Fetch(); + if (fields[0].GetUInt64() != 0) + return true; + + } while (result->NextRow()); + } + + return false; +} diff --git a/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h new file mode 100644 index 00000000000..ff8210a7f69 --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/IpBanCheckConnectionInitializer.h @@ -0,0 +1,64 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H +#define TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H + +#include "DatabaseEnvFwd.h" +#include "Log.h" +#include "QueryCallback.h" +#include "SocketConnectionInitializer.h" + +namespace Trinity::Net +{ +namespace IpBanCheckHelpers +{ +TC_SHARED_API QueryCallback AsyncQuery(std::string_view ipAddress); +TC_SHARED_API bool IsBanned(PreparedQueryResult const& result); +} + +template <typename SocketImpl> +struct IpBanCheckConnectionInitializer final : SocketConnectionInitializer +{ + explicit IpBanCheckConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->QueueQuery(IpBanCheckHelpers::AsyncQuery(_socket->GetRemoteIpAddress().to_string()).WithPreparedCallback([socketRef = _socket->weak_from_this(), self = this->shared_from_this()](PreparedQueryResult const& result) + { + std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock()); + if (!socket) + return; + + if (IpBanCheckHelpers::IsBanned(result)) + { + TC_LOG_ERROR("network", "IpBanCheckConnectionInitializer: IP {} is banned.", socket->GetRemoteIpAddress().to_string()); + socket->DelayedCloseSocket(); + return; + } + + if (self->next) + self->next->Start(); + })); + } + +private: + SocketImpl* _socket; +}; +} + +#endif // TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H diff --git a/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h new file mode 100644 index 00000000000..d3f0bb16dbf --- /dev/null +++ b/src/server/shared/Networking/ConnectionInitializers/SocketConnectionInitializer.h @@ -0,0 +1,51 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H +#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H + +#include <memory> +#include <span> + +namespace Trinity::Net +{ +struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer> +{ + SocketConnectionInitializer() = default; + + SocketConnectionInitializer(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default; + SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default; + + virtual ~SocketConnectionInitializer() = default; + + virtual void Start() = 0; + + std::shared_ptr<SocketConnectionInitializer> next; + + static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers) + { + for (std::size_t i = initializers.size(); i > 1; --i) + initializers[i - 2]->next.swap(initializers[i - 1]); + + return initializers[0]; + } +}; +} + +#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.cpp b/src/server/shared/Networking/Http/BaseHttpSocket.cpp index ca92c442aa2..33053399c14 100644 --- a/src/server/shared/Networking/Http/BaseHttpSocket.cpp +++ b/src/server/shared/Networking/Http/BaseHttpSocket.cpp @@ -16,6 +16,7 @@ */ #include "BaseHttpSocket.h" +#include <boost/asio/buffers_iterator.hpp> #include <boost/beast/http/serializer.hpp> namespace Trinity::Net::Http @@ -112,4 +113,34 @@ MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response return buffer; } + +void AbstractSocket::LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const +{ + if (Logger const* logger = sLog->GetEnabledLogger("server.http", LOG_LEVEL_DEBUG)) + { + std::string clientInfo = GetClientInfo(); + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_DEBUG, "{} Request {} {} done, status {}", clientInfo, + ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int()); + if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE)) + { + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo, + CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>"); + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo, + CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>"); + } + } +} + +std::string AbstractSocket::GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state) +{ + std::string info = StringFormat("[{}:{}", address.to_string(), port); + if (state) + { + info.append(", Session Id: "); + info.append(boost::uuids::to_string(state->Id)); + } + + info += ']'; + return info; +} } diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.h b/src/server/shared/Networking/Http/BaseHttpSocket.h index c02b13e6463..4b7c3bd9dd1 100644 --- a/src/server/shared/Networking/Http/BaseHttpSocket.h +++ b/src/server/shared/Networking/Http/BaseHttpSocket.h @@ -25,13 +25,46 @@ #include "Optional.h" #include "QueryCallback.h" #include "Socket.h" -#include <boost/asio/buffers_iterator.hpp> +#include "SocketConnectionInitializer.h" +#include <boost/beast/core/basic_stream.hpp> #include <boost/beast/http/parser.hpp> #include <boost/beast/http/string_body.hpp> #include <boost/uuid/uuid_io.hpp> namespace Trinity::Net::Http { +using IoContextHttpSocket = boost::beast::basic_stream<boost::asio::ip::tcp, boost::asio::io_context::executor_type, boost::beast::unlimited_rate_policy>; + +namespace Impl +{ +class BoostBeastSocketWrapper : public IoContextHttpSocket +{ +public: + using IoContextHttpSocket::basic_stream; + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + socket().shutdown(what, shutdownError); + } + + void close(boost::system::error_code& /*error*/) + { + IoContextHttpSocket::close(); + } + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + { + socket().async_wait(type, std::forward<WaitHandlerType>(handler)); + } + + IoContextTcpSocket::endpoint_type remote_endpoint() const + { + return socket().remote_endpoint(); + } +}; +} + using RequestParser = boost::beast::http::request_parser<RequestBody>; class TC_SHARED_API AbstractSocket @@ -51,22 +84,59 @@ public: virtual void SendResponse(RequestContext& context) = 0; + void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const; + virtual void QueueQuery(QueryCallback&& queryCallback) = 0; virtual std::string GetClientInfo() const = 0; - virtual Optional<boost::uuids::uuid> GetSessionId() const = 0; + static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state); + + virtual SessionState* GetSessionState() const = 0; + + Optional<boost::uuids::uuid> GetSessionId() const + { + if (SessionState* state = this->GetSessionState()) + return state->Id; + + return {}; + } + + virtual void Start() = 0; + + virtual bool Update() = 0; + + virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0; + + virtual bool IsOpen() const = 0; + + virtual void CloseSocket() = 0; }; -template<typename Derived, typename Stream> -class BaseSocket : public ::Socket<Derived, Stream>, public AbstractSocket +template <typename SocketImpl> +struct HttpConnectionInitializer final : SocketConnectionInitializer { - using Base = ::Socket<Derived, Stream>; + explicit HttpConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->ResetHttpParser(); + + if (this->next) + this->next->Start(); + } + +private: + SocketImpl* _socket; +}; + +template<typename Stream> +class BaseSocket : public Trinity::Net::Socket<Stream>, public AbstractSocket +{ + using Base = Trinity::Net::Socket<Stream>; public: - template<typename... Args> - explicit BaseSocket(boost::asio::ip::tcp::socket&& socket, Args&&... args) - : Base(std::move(socket), std::forward<Args>(args)...) { } + using Base::Base; BaseSocket(BaseSocket const& other) = delete; BaseSocket(BaseSocket&& other) = delete; @@ -75,11 +145,8 @@ public: ~BaseSocket() = default; - void ReadHandler() override + SocketReadCallbackResult ReadHandler() final { - if (!this->IsOpen()) - return; - MessageBuffer& packet = this->GetReadBuffer(); while (packet.GetActiveSize() > 0) { @@ -92,13 +159,13 @@ public: if (!HandleMessage(_httpParser->get())) { this->CloseSocket(); - break; + return SocketReadCallbackResult::Stop; } this->ResetHttpParser(); } - this->AsyncRead(); + return SocketReadCallbackResult::KeepReading; } bool HandleMessage(Request& request) @@ -118,19 +185,11 @@ public: virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0; - void SendResponse(RequestContext& context) override + void SendResponse(RequestContext& context) final { MessageBuffer buffer = SerializeResponse(context.request, context.response); - TC_LOG_DEBUG("server.http", "{} Request {} {} done, status {}", this->GetClientInfo(), ToStdStringView(context.request.method_string()), - ToStdStringView(context.request.target()), context.response.result_int()); - if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE)) - { - sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Request: {}", this->GetClientInfo(), - CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>"); - sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Response: {}", this->GetClientInfo(), - CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>"); - } + this->LogRequestAndResponse(context, buffer); this->QueuePacket(std::move(buffer)); @@ -138,11 +197,13 @@ public: this->DelayedCloseSocket(); } - void QueueQuery(QueryCallback&& queryCallback) override + void QueueQuery(QueryCallback&& queryCallback) final { this->_queryProcessor.AddCallback(std::move(queryCallback)); } + void Start() override { return this->Base::Start(); } + bool Update() override { if (!this->Base::Update()) @@ -152,27 +213,19 @@ public: return true; } + boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); } + + bool IsOpen() const final { return this->Base::IsOpen(); } + + void CloseSocket() final { return this->Base::CloseSocket(); } + std::string GetClientInfo() const override { - std::string info; - info.reserve(500); - auto itr = StringFormatTo(std::back_inserter(info), "[{}:{}", this->GetRemoteIpAddress().to_string(), this->GetRemotePort()); - if (_state) - itr = StringFormatTo(itr, ", Session Id: {}", boost::uuids::to_string(_state->Id)); - - StringFormatTo(itr, "]"); - return info; + return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get()); } - Optional<boost::uuids::uuid> GetSessionId() const final - { - if (this->_state) - return this->_state->Id; - - return {}; - } + SessionState* GetSessionState() const override { return _state.get(); } -protected: void ResetHttpParser() { this->_httpParser.reset(); @@ -180,6 +233,7 @@ protected: this->_httpParser->eager(true); } +protected: virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0; QueryCallbackProcessor _queryProcessor; diff --git a/src/server/shared/Networking/Http/HttpService.h b/src/server/shared/Networking/Http/HttpService.h index e68f8d2795f..2a377c734da 100644 --- a/src/server/shared/Networking/Http/HttpService.h +++ b/src/server/shared/Networking/Http/HttpService.h @@ -160,7 +160,7 @@ protected: class Thread : public NetworkThread<SessionImpl> { protected: - void SocketRemoved(std::shared_ptr<SessionImpl> session) override + void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override { if (Optional<boost::uuids::uuid> id = session->GetSessionId()) _service->MarkSessionInactive(*id); diff --git a/src/server/shared/Networking/Http/HttpSocket.h b/src/server/shared/Networking/Http/HttpSocket.h index 9a333a8e779..2cfc3ba8ed8 100644 --- a/src/server/shared/Networking/Http/HttpSocket.h +++ b/src/server/shared/Networking/Http/HttpSocket.h @@ -19,43 +19,16 @@ #define TRINITYCORE_HTTP_SOCKET_H #include "BaseHttpSocket.h" -#include <boost/beast/core/tcp_stream.hpp> +#include <array> namespace Trinity::Net::Http { -namespace Impl +class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper> { -class BoostBeastSocketWrapper : public boost::beast::tcp_stream -{ -public: - using boost::beast::tcp_stream::tcp_stream; - - void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) - { - socket().shutdown(what, shutdownError); - } - - void close(boost::system::error_code& /*error*/) - { - boost::beast::tcp_stream::close(); - } - - boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const - { - return socket().remote_endpoint(); - } -}; -} - -template <typename Derived> -class Socket : public BaseSocket<Derived, Impl::BoostBeastSocketWrapper> -{ - using SocketBase = BaseSocket<Derived, Impl::BoostBeastSocketWrapper>; + using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>; public: - template<typename... Args> - explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&...) - : SocketBase(std::move(socket)) { } + using SocketBase::SocketBase; Socket(Socket const& other) = delete; Socket(Socket&& other) = delete; @@ -66,9 +39,13 @@ public: void Start() override { - this->ResetHttpParser(); + std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers = + { { + std::make_shared<HttpConnectionInitializer<SocketBase>>(this), + std::make_shared<ReadConnectionInitializer<SocketBase>>(this), + } }; - this->AsyncRead(); + SocketConnectionInitializer::SetupChain(initializers)->Start(); } }; } diff --git a/src/server/shared/Networking/Http/HttpSslSocket.h b/src/server/shared/Networking/Http/HttpSslSocket.h index cdb70645e05..c789cbfefaf 100644 --- a/src/server/shared/Networking/Http/HttpSslSocket.h +++ b/src/server/shared/Networking/Http/HttpSslSocket.h @@ -19,47 +19,21 @@ #define TRINITYCORE_HTTP_SSL_SOCKET_H #include "BaseHttpSocket.h" -#include "SslSocket.h" -#include <boost/beast/core/stream_traits.hpp> -#include <boost/beast/core/tcp_stream.hpp> -#include <boost/beast/ssl/ssl_stream.hpp> +#include "SslStream.h" namespace Trinity::Net::Http { -namespace Impl +class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>> { -class BoostBeastSslSocketWrapper : public ::SslSocket<boost::beast::ssl_stream<boost::beast::tcp_stream>> -{ -public: - using SslSocket::SslSocket; - - void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) - { - _sslSocket.shutdown(shutdownError); - boost::beast::get_lowest_layer(_sslSocket).socket().shutdown(what, shutdownError); - } - - void close(boost::system::error_code& /*error*/) - { - boost::beast::get_lowest_layer(_sslSocket).close(); - } - - boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const - { - return boost::beast::get_lowest_layer(_sslSocket).socket().remote_endpoint(); - } -}; -} - -template <typename Derived> -class SslSocket : public BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper> -{ - using SocketBase = BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>; + using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>; public: - explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) + explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : SocketBase(std::move(socket), sslContext) { } + explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) + : SocketBase(context, sslContext) { } + SslSocket(SslSocket const& other) = delete; SslSocket(SslSocket&& other) = delete; SslSocket& operator=(SslSocket const& other) = delete; @@ -69,27 +43,14 @@ public: void Start() override { - this->AsyncHandshake(); - } - - void AsyncHandshake() - { - this->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, - [self = this->shared_from_this()](boost::system::error_code const& error) { self->HandshakeHandler(error); }); - } - - void HandshakeHandler(boost::system::error_code const& error) - { - if (error) - { - TC_LOG_ERROR("server.http.session.ssl", "{} SSL Handshake failed {}", this->GetClientInfo(), error.message()); - this->CloseSocket(); - return; - } - - this->ResetHttpParser(); - - this->AsyncRead(); + std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers = + { { + std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this), + std::make_shared<HttpConnectionInitializer<SocketBase>>(this), + std::make_shared<ReadConnectionInitializer<SocketBase>>(this), + } }; + + SocketConnectionInitializer::SetupChain(initializers)->Start(); } }; } diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h index fc7d9647fc1..d16da442149 100644 --- a/src/server/shared/Networking/NetworkThread.h +++ b/src/server/shared/Networking/NetworkThread.h @@ -15,20 +15,24 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef NetworkThread_h__ -#define NetworkThread_h__ +#ifndef TRINITYCORE_NETWORK_THREAD_H +#define TRINITYCORE_NETWORK_THREAD_H -#include "Define.h" +#include "Containers.h" #include "DeadlineTimer.h" +#include "Define.h" #include "Errors.h" #include "IoContext.h" #include "Log.h" +#include "Socket.h" #include <boost/asio/ip/tcp.hpp> #include <atomic> #include <memory> #include <mutex> #include <thread> +namespace Trinity::Net +{ template<class SocketType> class NetworkThread { @@ -38,14 +42,16 @@ public: { } + NetworkThread(NetworkThread const&) = delete; + NetworkThread(NetworkThread&&) = delete; + NetworkThread& operator=(NetworkThread const&) = delete; + NetworkThread& operator=(NetworkThread&&) = delete; + virtual ~NetworkThread() { Stop(); if (_thread) - { Wait(); - delete _thread; - } } void Stop() @@ -59,7 +65,7 @@ public: if (_thread) return false; - _thread = new std::thread(&NetworkThread::Run, this); + _thread = std::make_unique<std::thread>(&NetworkThread::Run, this); return true; } @@ -68,7 +74,6 @@ public: ASSERT(_thread); _thread->join(); - delete _thread; _thread = nullptr; } @@ -82,15 +87,14 @@ public: std::lock_guard<std::mutex> lock(_newSocketsLock); ++_connections; - _newSockets.push_back(sock); - SocketAdded(sock); + SocketAdded(_newSockets.emplace_back(std::move(sock))); } - boost::asio::ip::tcp::socket* GetSocketForAccept() { return &_acceptSocket; } + Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; } protected: - virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { } - virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { } + virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { } + virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { } void AddNewSockets() { @@ -99,7 +103,7 @@ protected: if (_newSockets.empty()) return; - for (std::shared_ptr<SocketType> sock : _newSockets) + for (std::shared_ptr<SocketType>& sock : _newSockets) { if (!sock->IsOpen()) { @@ -107,7 +111,7 @@ protected: --_connections; } else - _sockets.push_back(sock); + _sockets.emplace_back(std::move(sock)); } _newSockets.clear(); @@ -136,7 +140,7 @@ protected: AddNewSockets(); - _sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock) + Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock) { if (!sock->Update()) { @@ -150,7 +154,7 @@ protected: } return false; - }), _sockets.end()); + }); } private: @@ -159,7 +163,7 @@ private: std::atomic<int32> _connections; std::atomic<bool> _stopped; - std::thread* _thread; + std::unique_ptr<std::thread> _thread; SocketContainer _sockets; @@ -167,8 +171,9 @@ private: SocketContainer _newSockets; Trinity::Asio::IoContext _ioContext; - boost::asio::ip::tcp::socket _acceptSocket; + Trinity::Net::IoContextTcpSocket _acceptSocket; Trinity::Asio::DeadlineTimer _updateTimer; }; +} -#endif // NetworkThread_h__ +#endif // TRINITYCORE_NETWORK_THREAD_H diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index 40f5820da92..565cc175318 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -15,11 +15,14 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef __SOCKET_H__ -#define __SOCKET_H__ +#ifndef TRINITYCORE_SOCKET_H +#define TRINITYCORE_SOCKET_H +#include "Concepts.h" #include "Log.h" #include "MessageBuffer.h" +#include "SocketConnectionInitializer.h" +#include <boost/asio/io_context.hpp> #include <boost/asio/ip/tcp.hpp> #include <atomic> #include <memory> @@ -31,12 +34,51 @@ #define TC_SOCKET_USE_IOCP #endif +namespace Trinity::Net +{ +using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>; + +enum class SocketReadCallbackResult +{ + KeepReading, + Stop +}; + +template <typename Callable> +concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>; + +template <typename SocketType> +struct InvokeReadHandlerCallback +{ + SocketReadCallbackResult operator()() const + { + return this->Socket->ReadHandler(); + } + + SocketType* Socket; +}; + +template <typename SocketType> +struct ReadConnectionInitializer final : SocketConnectionInitializer +{ + explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { } + + void Start() override + { + ReadCallback.Socket->AsyncRead(std::move(ReadCallback)); + + if (this->next) + this->next->Start(); + } + + InvokeReadHandlerCallback<SocketType> ReadCallback; +}; + /** @class Socket Base async socket implementation - @tparam T derived class type (CRTP) @tparam Stream stream type used for operations on socket Stream must implement the following methods: @@ -53,21 +95,27 @@ template<typename ConstBufferSequence> std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error); + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler); + template<typename SettableSocketOption> void set_option(SettableSocketOption const& option, boost::system::error_code& error); tcp::socket::endpoint_type remote_endpoint() const; */ -template<class T, class Stream = boost::asio::ip::tcp::socket> -class Socket : public std::enable_shared_from_this<T> +template<class Stream = IoContextTcpSocket> +class Socket : public std::enable_shared_from_this<Socket<Stream>> { public: template<typename... Args> - explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...), - _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), - _closed(false), _closing(false), _isWritingAsync(false) + explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...), + _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open) + { + } + + template<typename... Args> + explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed) { - _readBuffer.Resize(READ_BLOCK_SIZE); } Socket(Socket const& other) = delete; @@ -77,20 +125,20 @@ public: virtual ~Socket() { - _closed = true; + _openState = OpenState_Closed; boost::system::error_code error; _socket.close(error); } - virtual void Start() = 0; + virtual void Start() { } virtual bool Update() { - if (_closed) + if (_openState == OpenState_Closed) return false; #ifndef TC_SOCKET_USE_IOCP - if (_isWritingAsync || (_writeQueue.empty() && !_closing)) + if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open)) return true; for (; HandleQueue();) @@ -100,7 +148,7 @@ public: return true; } - boost::asio::ip::address GetRemoteIpAddress() const + boost::asio::ip::address const& GetRemoteIpAddress() const { return _remoteAddress; } @@ -110,7 +158,8 @@ public: return _remotePort; } - void AsyncRead() + template <SocketReadCallback Callback> + void AsyncRead(Callback&& callback) { if (!IsOpen()) return; @@ -118,23 +167,11 @@ public: _readBuffer.Normalize(); _readBuffer.EnsureFreeSpace(); _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), - [self = this->shared_from_this()](boost::system::error_code const& error, size_t transferredBytes) + [self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable { - self->ReadHandlerInternal(error, transferredBytes); - }); - } - - void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code const&, std::size_t)) - { - if (!IsOpen()) - return; - - _readBuffer.Normalize(); - _readBuffer.EnsureFreeSpace(); - _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), - [self = this->shared_from_this(), callback](boost::system::error_code const& error, size_t transferredBytes) - { - (self.get()->*callback)(error, transferredBytes); + if (self->ReadHandlerInternal(error, transferredBytes)) + if (callback() == SocketReadCallbackResult::KeepReading) + self->AsyncRead(std::forward<Callback>(callback)); }); } @@ -147,11 +184,11 @@ public: #endif } - bool IsOpen() const { return !_closed && !_closing; } + bool IsOpen() const { return _openState == OpenState_Open; } void CloseSocket() { - if (_closed.exchange(true)) + if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0) return; boost::system::error_code shutdownError; @@ -160,13 +197,13 @@ public: TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(), shutdownError.value(), shutdownError.message()); - OnClose(); + this->OnClose(); } /// Marks the socket for closing after write buffer becomes empty void DelayedCloseSocket() { - if (_closing.exchange(true)) + if (_openState.fetch_or(OpenState_Closing) != 0) return; if (_writeQueue.empty()) @@ -175,10 +212,15 @@ public: MessageBuffer& GetReadBuffer() { return _readBuffer; } + Stream& underlying_stream() + { + return _socket; + } + protected: virtual void OnClose() { } - virtual void ReadHandler() = 0; + virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; } bool AsyncProcessQueue() { @@ -195,10 +237,10 @@ protected: self->WriteHandler(error, transferedBytes); }); #else - _socket.async_write_some(boost::asio::null_buffers(), - [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes) + _socket.async_wait(boost::asio::socket_base::wait_type::wait_write, + [self = this->shared_from_this()](boost::system::error_code const& error) { - self->WriteHandlerWrapper(error, transferedBytes); + self->WriteHandlerWrapper(error); }); #endif @@ -214,22 +256,17 @@ protected: GetRemoteIpAddress().to_string(), err.value(), err.message()); } - Stream& underlying_stream() - { - return _socket; - } - private: - void ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes) + bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes) { if (error) { CloseSocket(); - return; + return false; } _readBuffer.WriteCompleted(transferredBytes); - ReadHandler(); + return IsOpen(); } #ifdef TC_SOCKET_USE_IOCP @@ -245,7 +282,7 @@ private: if (!_writeQueue.empty()) AsyncProcessQueue(); - else if (_closing) + else if (_openState == OpenState_Closing) CloseSocket(); } else @@ -254,7 +291,7 @@ private: #else - void WriteHandlerWrapper(boost::system::error_code const& /*error*/, std::size_t /*transferedBytes*/) + void WriteHandlerWrapper(boost::system::error_code const& /*error*/) { _isWritingAsync = false; HandleQueue(); @@ -278,14 +315,14 @@ private: return AsyncProcessQueue(); _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return false; } else if (bytesSent == 0) { _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return false; } @@ -296,7 +333,7 @@ private: } _writeQueue.pop(); - if (_closing && _writeQueue.empty()) + if (_openState == OpenState_Closing && _writeQueue.empty()) CloseSocket(); return !_writeQueue.empty(); } @@ -306,15 +343,20 @@ private: Stream _socket; boost::asio::ip::address _remoteAddress; - uint16 _remotePort; + uint16 _remotePort = 0; - MessageBuffer _readBuffer; + MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE); std::queue<MessageBuffer> _writeQueue; - std::atomic<bool> _closed; - std::atomic<bool> _closing; + // Socket open state "enum" (not enum to enable integral std::atomic api) + static constexpr uint8 OpenState_Open = 0x0; + static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data + static constexpr uint8 OpenState_Closed = 0x2; + + std::atomic<uint8> _openState; - bool _isWritingAsync; + bool _isWritingAsync = false; }; +} -#endif // __SOCKET_H__ +#endif // TRINITYCORE_SOCKET_H diff --git a/src/server/shared/Networking/SocketMgr.h b/src/server/shared/Networking/SocketMgr.h index 0b2d03e0944..07252355308 100644 --- a/src/server/shared/Networking/SocketMgr.h +++ b/src/server/shared/Networking/SocketMgr.h @@ -15,32 +15,40 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef SocketMgr_h__ -#define SocketMgr_h__ +#ifndef TRINITYCORE_SOCKET_MGR_H +#define TRINITYCORE_SOCKET_MGR_H #include "AsyncAcceptor.h" #include "Errors.h" #include "NetworkThread.h" +#include "Socket.h" #include <boost/asio/ip/tcp.hpp> #include <memory> +namespace Trinity::Net +{ template<class SocketType> class SocketMgr { public: + SocketMgr(SocketMgr const&) = delete; + SocketMgr(SocketMgr&&) = delete; + SocketMgr& operator=(SocketMgr const&) = delete; + SocketMgr& operator=(SocketMgr&&) = delete; + virtual ~SocketMgr() { ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction"); } - virtual bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) + virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) { ASSERT(threadCount > 0); - AsyncAcceptor* acceptor = nullptr; + std::unique_ptr<AsyncAcceptor> acceptor = nullptr; try { - acceptor = new AsyncAcceptor(ioContext, bindIp, port); + acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port); } catch (boost::system::system_error const& err) { @@ -51,13 +59,12 @@ public: if (!acceptor->Bind()) { TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor"); - delete acceptor; return false; } - _acceptor = acceptor; + _acceptor = std::move(acceptor); _threadCount = threadCount; - _threads = CreateThreads(); + _threads.reset(CreateThreads()); ASSERT(_threads); @@ -73,27 +80,23 @@ public: { _acceptor->Close(); - if (_threadCount != 0) - for (int32 i = 0; i < _threadCount; ++i) - _threads[i].Stop(); + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Stop(); Wait(); - delete _acceptor; _acceptor = nullptr; - delete[] _threads; _threads = nullptr; _threadCount = 0; } void Wait() { - if (_threadCount != 0) - for (int32 i = 0; i < _threadCount; ++i) - _threads[i].Wait(); + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Wait(); } - virtual void OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) + virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex) { try { @@ -121,22 +124,23 @@ public: return min; } - std::pair<boost::asio::ip::tcp::socket*, uint32> GetSocketForAccept() + std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept() { uint32 threadIndex = SelectThreadWithMinConnections(); return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex); } protected: - SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(0) + SocketMgr() : _threadCount(0) { } virtual NetworkThread<SocketType>* CreateThreads() const = 0; - AsyncAcceptor* _acceptor; - NetworkThread<SocketType>* _threads; + std::unique_ptr<AsyncAcceptor> _acceptor; + std::unique_ptr<NetworkThread<SocketType>[]> _threads; int32 _threadCount; }; +} -#endif // SocketMgr_h__ +#endif // TRINITYCORE_SOCKET_MGR_H diff --git a/src/server/shared/Networking/SslSocket.h b/src/server/shared/Networking/SslSocket.h deleted file mode 100644 index c19c8612edf..00000000000 --- a/src/server/shared/Networking/SslSocket.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information - * - * This program is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation; either version 2 of the License, or (at your - * option) any later version. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef SslSocket_h__ -#define SslSocket_h__ - -#include <boost/asio/ip/tcp.hpp> -#include <boost/asio/ssl/stream.hpp> -#include <boost/system/error_code.hpp> - -namespace boostssl = boost::asio::ssl; - -template<class Stream = boostssl::stream<boost::asio::ip::tcp::socket>> -class SslSocket -{ -public: - explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) - { - _sslSocket.set_verify_mode(boostssl::verify_none); - } - - // adapting tcp::socket api - void close(boost::system::error_code& error) - { - _sslSocket.lowest_layer().close(error); - } - - void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) - { - _sslSocket.shutdown(shutdownError); - _sslSocket.lowest_layer().shutdown(what, shutdownError); - } - - template<typename MutableBufferSequence, typename ReadHandlerType> - void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) - { - _sslSocket.async_read_some(buffers, std::move(handler)); - } - - template<typename ConstBufferSequence, typename WriteHandlerType> - void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) - { - _sslSocket.async_write_some(buffers, std::move(handler)); - } - - template<typename ConstBufferSequence> - std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error) - { - return _sslSocket.write_some(buffers, error); - } - - template<typename SettableSocketOption> - void set_option(SettableSocketOption const& option, boost::system::error_code& error) - { - _sslSocket.lowest_layer().set_option(option, error); - } - - boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const - { - return _sslSocket.lowest_layer().remote_endpoint(); - } - - // ssl api - template<typename HandshakeHandlerType> - void async_handshake(boostssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) - { - _sslSocket.async_handshake(type, std::move(handler)); - } - -protected: - Stream _sslSocket; -}; - -#endif // SslSocket_h__ diff --git a/src/server/shared/Networking/SslStream.h b/src/server/shared/Networking/SslStream.h new file mode 100644 index 00000000000..2cced44e5ff --- /dev/null +++ b/src/server/shared/Networking/SslStream.h @@ -0,0 +1,131 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SSL_STREAM_H +#define TRINITYCORE_SSL_STREAM_H + +#include "SocketConnectionInitializer.h" +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/ssl/stream.hpp> +#include <boost/system/error_code.hpp> + +namespace Trinity::Net +{ +template <typename SocketImpl> +struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer +{ + explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, + [socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error) + { + std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock()); + if (!socket) + return; + + if (error) + { + TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message()); + socket->CloseSocket(); + return; + } + + if (self->next) + self->next->Start(); + }); + } + +private: + SocketImpl* _socket; +}; + +template<class WrappedStream = IoContextTcpSocket> +class SslStream +{ +public: + explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + // adapting tcp::socket api + void close(boost::system::error_code& error) + { + _sslSocket.next_layer().close(error); + } + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + _sslSocket.shutdown(shutdownError); + _sslSocket.next_layer().shutdown(what, shutdownError); + } + + template<typename MutableBufferSequence, typename ReadHandlerType> + void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) + { + _sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler)); + } + + template<typename ConstBufferSequence, typename WriteHandlerType> + void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) + { + _sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler)); + } + + template<typename ConstBufferSequence> + std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error) + { + return _sslSocket.write_some(buffers, error); + } + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + { + _sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler)); + } + + template<typename SettableSocketOption> + void set_option(SettableSocketOption const& option, boost::system::error_code& error) + { + _sslSocket.next_layer().set_option(option, error); + } + + IoContextTcpSocket::endpoint_type remote_endpoint() const + { + return _sslSocket.next_layer().remote_endpoint(); + } + + // ssl api + template<typename HandshakeHandlerType> + void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) + { + _sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler)); + } + +protected: + boost::asio::ssl::stream<WrappedStream> _sslSocket; +}; +} + +#endif // TRINITYCORE_SSL_STREAM_H |