/* * This file is part of the AzerothCore Project. See AUTHORS file for Copyright information * * This program is free software; you can redistribute it and/or modify it * under the terms of the GNU Affero General Public License as published by the * Free Software Foundation; either version 3 of the License, or (at your * option) any later version. * * This program is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for * more details. * * You should have received a copy of the GNU General Public License along * with this program. If not, see . */ #ifndef __SOCKET_H__ #define __SOCKET_H__ #include "Log.h" #include "MessageBuffer.h" #include #include #include #include #include #include #include using boost::asio::ip::tcp; #define READ_BLOCK_SIZE 4096 #ifdef BOOST_ASIO_HAS_IOCP #define AC_SOCKET_USE_IOCP #endif enum ProxyHeaderReadingState { PROXY_HEADER_READING_STATE_NOT_STARTED, PROXY_HEADER_READING_STATE_STARTED, PROXY_HEADER_READING_STATE_FINISHED, PROXY_HEADER_READING_STATE_FAILED, }; enum ProxyHeaderAddressFamilyAndProtocol { PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4 = 0x11, PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6 = 0x21, }; template class Socket : public std::enable_shared_from_this { public: explicit Socket(tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false), _proxyHeaderReadingState(PROXY_HEADER_READING_STATE_NOT_STARTED) { _readBuffer.Resize(READ_BLOCK_SIZE); } virtual ~Socket() { _closed = true; boost::system::error_code error; _socket.close(error); } virtual void Start() = 0; virtual bool Update() { if (_closed) { return false; } #ifndef AC_SOCKET_USE_IOCP if (_isWritingAsync || (_writeQueue.empty() && !_closing)) { return true; } for (; HandleQueue();) ; #endif return true; } [[nodiscard]] boost::asio::ip::address GetRemoteIpAddress() const { return _remoteAddress; } [[nodiscard]] uint16 GetRemotePort() const { return _remotePort; } void AsyncRead() { if (!IsOpen()) { return; } _readBuffer.Normalize(); _readBuffer.EnsureFreeSpace(); _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), std::bind(&Socket::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } void AsyncReadProxyHeader() { if (!IsOpen()) { return; } _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_STARTED; _readBuffer.Normalize(); _readBuffer.EnsureFreeSpace(); _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), std::bind(&Socket::ProxyReadHeaderHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t)) { if (!IsOpen()) { return; } _readBuffer.Normalize(); _readBuffer.EnsureFreeSpace(); _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); } void QueuePacket(MessageBuffer&& buffer) { _writeQueue.push(std::move(buffer)); #ifdef AC_SOCKET_USE_IOCP AsyncProcessQueue(); #endif } [[nodiscard]] ProxyHeaderReadingState GetProxyHeaderReadingState() const { return _proxyHeaderReadingState; } [[nodiscard]] bool IsOpen() const { return !_closed && !_closing; } void CloseSocket() { if (_closed.exchange(true)) return; boost::system::error_code shutdownError; _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError); if (shutdownError) LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(), shutdownError.value(), shutdownError.message()); OnClose(); } /// Marks the socket for closing after write buffer becomes empty void DelayedCloseSocket() { _closing = true; } MessageBuffer& GetReadBuffer() { return _readBuffer; } protected: virtual void OnClose() { } virtual void ReadHandler() = 0; bool AsyncProcessQueue() { if (_isWritingAsync) return false; _isWritingAsync = true; #ifdef AC_SOCKET_USE_IOCP MessageBuffer& buffer = _writeQueue.front(); _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), std::bind(&Socket::WriteHandler, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); #else _socket.async_write_some(boost::asio::null_buffers(), std::bind(&Socket::WriteHandlerWrapper, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2)); #endif return false; } void SetNoDelay(bool enable) { boost::system::error_code err; _socket.set_option(tcp::no_delay(enable), err); if (err) LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for {} - {} ({})", GetRemoteIpAddress().to_string(), err.value(), err.message()); } private: void ReadHandlerInternal(boost::system::error_code error, std::size_t transferredBytes) { if (error) { CloseSocket(); return; } _readBuffer.WriteCompleted(transferredBytes); ReadHandler(); } // ProxyReadHeaderHandler reads Proxy Protocol v2 header (v1 is not supported). // See https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt (2.2. Binary header format (version 2)) for more details. void ProxyReadHeaderHandler(boost::system::error_code error, std::size_t transferredBytes) { if (error) { CloseSocket(); return; } _readBuffer.WriteCompleted(transferredBytes); MessageBuffer& packet = GetReadBuffer(); const int minimumProxyProtocolV2Size = 28; if (packet.GetActiveSize() < minimumProxyProtocolV2Size) { AsyncReadProxyHeader(); return; } uint8* readPointer = packet.GetReadPointer(); const uint8 signatureSize = 12; const uint8 expectedSignature[signatureSize] = {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}; if (memcmp(packet.GetReadPointer(), expectedSignature, signatureSize) != 0) { _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED; LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string()); return; } const uint8 version = (readPointer[signatureSize] & 0xF0) >> 4; const uint8 command = (readPointer[signatureSize] & 0xF); if (version != 2) { _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED; LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: received bad PROXY Protocol v2 signature for {}", GetRemoteIpAddress().to_string()); return; } const uint8 addressFamily = readPointer[13]; const uint16 len = (readPointer[14] << 8) | readPointer[15]; if (static_cast(len+16) > packet.GetActiveSize()) { AsyncReadProxyHeader(); return; } // Connection created by a proxy itself (health checks?), ignore and do nothing. if (command == 0) { packet.ReadCompleted(len+16); _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED; return; } auto remainingLen = packet.GetActiveSize() - 16; readPointer += 16; // Skip strait to address. switch (addressFamily) { case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V4: { if (remainingLen < 12) { AsyncReadProxyHeader(); return; } boost::asio::ip::address_v4::bytes_type b; auto addressSize = sizeof(b); std::copy(readPointer, readPointer+addressSize, b.begin()); _remoteAddress = boost::asio::ip::address_v4(b); readPointer += 2 * addressSize; // Skip server address. _remotePort = (readPointer[0] << 8) | readPointer[1]; break; } case PROXY_HEADER_ADDRESS_FAMILY_AND_PROTOCOL_TCP_V6: { if (remainingLen < 36) { AsyncReadProxyHeader(); return; } boost::asio::ip::address_v6::bytes_type b; auto addressSize = sizeof(b); std::copy(readPointer, readPointer+addressSize, b.begin()); _remoteAddress = boost::asio::ip::address_v6(b); readPointer += 2 * addressSize; // Skip server address. _remotePort = (readPointer[0] << 8) | readPointer[1]; break; } default: _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FAILED; LOG_ERROR("network", "Socket::ProxyReadHeaderHandler: unsupported address family type {}", GetRemoteIpAddress().to_string()); return; } packet.ReadCompleted(len+16); _proxyHeaderReadingState = PROXY_HEADER_READING_STATE_FINISHED; } #ifdef AC_SOCKET_USE_IOCP void WriteHandler(boost::system::error_code error, std::size_t transferedBytes) { if (!error) { _isWritingAsync = false; _writeQueue.front().ReadCompleted(transferedBytes); if (!_writeQueue.front().GetActiveSize()) _writeQueue.pop(); if (!_writeQueue.empty()) AsyncProcessQueue(); else if (_closing) CloseSocket(); } else CloseSocket(); } #else void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/) { _isWritingAsync = false; HandleQueue(); } bool HandleQueue() { if (_writeQueue.empty()) return false; MessageBuffer& queuedMessage = _writeQueue.front(); std::size_t bytesToSend = queuedMessage.GetActiveSize(); boost::system::error_code error; std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error); if (error) { if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) { return AsyncProcessQueue(); } _writeQueue.pop(); if (_closing && _writeQueue.empty()) { CloseSocket(); } return false; } else if (bytesSent == 0) { _writeQueue.pop(); if (_closing && _writeQueue.empty()) { CloseSocket(); } return false; } else if (bytesSent < bytesToSend) // now n > 0 { queuedMessage.ReadCompleted(bytesSent); return AsyncProcessQueue(); } _writeQueue.pop(); if (_closing && _writeQueue.empty()) { CloseSocket(); } return !_writeQueue.empty(); } #endif tcp::socket _socket; boost::asio::ip::address _remoteAddress; uint16 _remotePort; MessageBuffer _readBuffer; std::queue _writeQueue; std::atomic _closed; std::atomic _closing; bool _isWritingAsync; ProxyHeaderReadingState _proxyHeaderReadingState; }; #endif // __SOCKET_H__