diff options
author | Shauren <shauren.trinity@gmail.com> | 2025-04-09 21:02:31 +0200 |
---|---|---|
committer | Shauren <shauren.trinity@gmail.com> | 2025-04-09 21:02:31 +0200 |
commit | 71b681bbf0f5189cd87a6cea66ef51667223f54a (patch) | |
tree | f5da2eb9d76010efcf5abd875edd39c812b62bd7 /src/common/network/Socket.h | |
parent | 6c374c56b2bd06923ae738b19ca6a4257e29d863 (diff) |
Core/Network: Move to separate project
Diffstat (limited to 'src/common/network/Socket.h')
-rw-r--r-- | src/common/network/Socket.h | 362 |
1 files changed, 362 insertions, 0 deletions
diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h new file mode 100644 index 00000000000..565cc175318 --- /dev/null +++ b/src/common/network/Socket.h @@ -0,0 +1,362 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SOCKET_H +#define TRINITYCORE_SOCKET_H + +#include "Concepts.h" +#include "Log.h" +#include "MessageBuffer.h" +#include "SocketConnectionInitializer.h" +#include <boost/asio/io_context.hpp> +#include <boost/asio/ip/tcp.hpp> +#include <atomic> +#include <memory> +#include <queue> +#include <type_traits> + +#define READ_BLOCK_SIZE 4096 +#ifdef BOOST_ASIO_HAS_IOCP +#define TC_SOCKET_USE_IOCP +#endif + +namespace Trinity::Net +{ +using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>; + +enum class SocketReadCallbackResult +{ + KeepReading, + Stop +}; + +template <typename Callable> +concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>; + +template <typename SocketType> +struct InvokeReadHandlerCallback +{ + SocketReadCallbackResult operator()() const + { + return this->Socket->ReadHandler(); + } + + SocketType* Socket; +}; + +template <typename SocketType> +struct ReadConnectionInitializer final : SocketConnectionInitializer +{ + explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { } + + void Start() override + { + ReadCallback.Socket->AsyncRead(std::move(ReadCallback)); + + if (this->next) + this->next->Start(); + } + + InvokeReadHandlerCallback<SocketType> ReadCallback; +}; + +/** + @class Socket + + Base async socket implementation + + @tparam Stream stream type used for operations on socket + Stream must implement the following methods: + + void close(boost::system::error_code& error); + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError); + + template<typename MutableBufferSequence, typename ReadHandlerType> + void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler); + + template<typename ConstBufferSequence, typename WriteHandlerType> + void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler); + + template<typename ConstBufferSequence> + std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error); + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler); + + template<typename SettableSocketOption> + void set_option(SettableSocketOption const& option, boost::system::error_code& error); + + tcp::socket::endpoint_type remote_endpoint() const; +*/ +template<class Stream = IoContextTcpSocket> +class Socket : public std::enable_shared_from_this<Socket<Stream>> +{ +public: + template<typename... Args> + explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...), + _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open) + { + } + + template<typename... Args> + explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed) + { + } + + Socket(Socket const& other) = delete; + Socket(Socket&& other) = delete; + Socket& operator=(Socket const& other) = delete; + Socket& operator=(Socket&& other) = delete; + + virtual ~Socket() + { + _openState = OpenState_Closed; + boost::system::error_code error; + _socket.close(error); + } + + virtual void Start() { } + + virtual bool Update() + { + if (_openState == OpenState_Closed) + return false; + +#ifndef TC_SOCKET_USE_IOCP + if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open)) + return true; + + for (; HandleQueue();) + ; +#endif + + return true; + } + + boost::asio::ip::address const& GetRemoteIpAddress() const + { + return _remoteAddress; + } + + uint16 GetRemotePort() const + { + return _remotePort; + } + + template <SocketReadCallback Callback> + void AsyncRead(Callback&& callback) + { + if (!IsOpen()) + return; + + _readBuffer.Normalize(); + _readBuffer.EnsureFreeSpace(); + _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), + [self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable + { + if (self->ReadHandlerInternal(error, transferredBytes)) + if (callback() == SocketReadCallbackResult::KeepReading) + self->AsyncRead(std::forward<Callback>(callback)); + }); + } + + void QueuePacket(MessageBuffer&& buffer) + { + _writeQueue.push(std::move(buffer)); + +#ifdef TC_SOCKET_USE_IOCP + AsyncProcessQueue(); +#endif + } + + bool IsOpen() const { return _openState == OpenState_Open; } + + void CloseSocket() + { + if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0) + return; + + boost::system::error_code shutdownError; + _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError); + if (shutdownError) + TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(), + shutdownError.value(), shutdownError.message()); + + this->OnClose(); + } + + /// Marks the socket for closing after write buffer becomes empty + void DelayedCloseSocket() + { + if (_openState.fetch_or(OpenState_Closing) != 0) + return; + + if (_writeQueue.empty()) + CloseSocket(); + } + + MessageBuffer& GetReadBuffer() { return _readBuffer; } + + Stream& underlying_stream() + { + return _socket; + } + +protected: + virtual void OnClose() { } + + virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; } + + bool AsyncProcessQueue() + { + if (_isWritingAsync) + return false; + + _isWritingAsync = true; + +#ifdef TC_SOCKET_USE_IOCP + MessageBuffer& buffer = _writeQueue.front(); + _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), + [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes) + { + self->WriteHandler(error, transferedBytes); + }); +#else + _socket.async_wait(boost::asio::socket_base::wait_type::wait_write, + [self = this->shared_from_this()](boost::system::error_code const& error) + { + self->WriteHandlerWrapper(error); + }); +#endif + + return false; + } + + void SetNoDelay(bool enable) + { + boost::system::error_code err; + _socket.set_option(boost::asio::ip::tcp::no_delay(enable), err); + if (err) + TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for {} - {} ({})", + GetRemoteIpAddress().to_string(), err.value(), err.message()); + } + +private: + bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes) + { + if (error) + { + CloseSocket(); + return false; + } + + _readBuffer.WriteCompleted(transferredBytes); + return IsOpen(); + } + +#ifdef TC_SOCKET_USE_IOCP + + void WriteHandler(boost::system::error_code const& error, std::size_t transferedBytes) + { + if (!error) + { + _isWritingAsync = false; + _writeQueue.front().ReadCompleted(transferedBytes); + if (!_writeQueue.front().GetActiveSize()) + _writeQueue.pop(); + + if (!_writeQueue.empty()) + AsyncProcessQueue(); + else if (_openState == OpenState_Closing) + CloseSocket(); + } + else + CloseSocket(); + } + +#else + + void WriteHandlerWrapper(boost::system::error_code const& /*error*/) + { + _isWritingAsync = false; + HandleQueue(); + } + + bool HandleQueue() + { + if (_writeQueue.empty()) + return false; + + MessageBuffer& queuedMessage = _writeQueue.front(); + + std::size_t bytesToSend = queuedMessage.GetActiveSize(); + + boost::system::error_code error; + std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error); + + if (error) + { + if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) + return AsyncProcessQueue(); + + _writeQueue.pop(); + if (_openState == OpenState_Closing && _writeQueue.empty()) + CloseSocket(); + return false; + } + else if (bytesSent == 0) + { + _writeQueue.pop(); + if (_openState == OpenState_Closing && _writeQueue.empty()) + CloseSocket(); + return false; + } + else if (bytesSent < bytesToSend) // now n > 0 + { + queuedMessage.ReadCompleted(bytesSent); + return AsyncProcessQueue(); + } + + _writeQueue.pop(); + if (_openState == OpenState_Closing && _writeQueue.empty()) + CloseSocket(); + return !_writeQueue.empty(); + } + +#endif + + Stream _socket; + + boost::asio::ip::address _remoteAddress; + uint16 _remotePort = 0; + + MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE); + std::queue<MessageBuffer> _writeQueue; + + // Socket open state "enum" (not enum to enable integral std::atomic api) + static constexpr uint8 OpenState_Open = 0x0; + static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data + static constexpr uint8 OpenState_Closed = 0x2; + + std::atomic<uint8> _openState; + + bool _isWritingAsync = false; +}; +} + +#endif // TRINITYCORE_SOCKET_H |