/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SOCKET_H
#define TRINITYCORE_SOCKET_H
#include "Concepts.h"
#include "IpAddress.h"
#include "Log.h"
#include "MessageBuffer.h"
#include "SocketConnectionInitializer.h"
#include
#include
#include
#include
#include
#include
#include
#ifdef BOOST_ASIO_HAS_IOCP
#define TC_SOCKET_USE_IOCP
#endif
namespace Trinity::Net
{
/// Boost.Asio stream socket type accepted by the Socket constructor that adopts an existing connection.
using IoContextTcpSocket = boost::asio::basic_stream_socket;
namespace Impl::Operations
{
// Forward declaration only: Socket::Connect refers to this operation, which is defined after the Socket class.
template
struct Connect;
}
/// Result of a read callback; Socket::AsyncRead starts another read only when the callback returns KeepReading.
enum class SocketReadCallbackResult
{
KeepReading, ///< Start another asynchronous read with the same callback
Stop ///< Do not start another read
};
/// Calls Normalize() and EnsureFreeSpace() on @p readBuffer, then wraps its write pointer and
/// remaining space in an asio mutable buffer suitable for a read operation.
/// @param readBuffer buffer that will receive the incoming bytes
/// @return buffer view starting at the write pointer and spanning the remaining space
inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer)
{
    readBuffer.Normalize();
    readBuffer.EnsureFreeSpace();

    // Query the destination only after the two calls above have run
    auto const writePosition = readBuffer.GetWritePointer();
    auto const writableBytes = readBuffer.GetRemainingSpace();
    return boost::asio::buffer(writePosition, writableBytes);
}
/// Callable that forwards to ReadHandler() of the object pointed to by Socket.
/// It has no constructors; ReadConnectionInitializer fills it with a designated initializer.
template
struct InvokeReadHandlerCallback
{
/// Calls Socket->ReadHandler() and returns its result unchanged.
SocketReadCallbackResult operator()() const
{
return this->Socket->ReadHandler();
}
SocketType* Socket; ///< Object whose ReadHandler() is invoked (stored as a raw pointer)
};
/// SocketConnectionInitializer step that starts asynchronous reading on a socket and then calls InvokeNext().
template
struct ReadConnectionInitializer final : SocketConnectionInitializer
{
/// Reads from @p socket and routes the read callback to that same object.
explicit ReadConnectionInitializer(AsyncReadObjectType* socket) : Socket(socket), ReadCallback({ .Socket = socket }) { }
/// Reads from @p socket but routes the read callback to @p callbackSocket.
explicit ReadConnectionInitializer(AsyncReadObjectType* socket, ReadHandlerObjectType* callbackSocket) : Socket(socket), ReadCallback({ .Socket = callbackSocket }) { }
/// Hands ReadCallback to Socket->AsyncRead() and then calls InvokeNext().
void Start() override
{
Socket->AsyncRead(std::move(ReadCallback));
this->InvokeNext();
}
AsyncReadObjectType* Socket; ///< Object AsyncRead() is called on
InvokeReadHandlerCallback ReadCallback; ///< Callback passed to AsyncRead()
};
/**
@class Socket
Base async socket implementation
@tparam Stream stream type used for operations on socket
Stream must implement the following methods:
boost::asio::io_context::executor_type get_executor();
bool is_open() const;
void close(boost::system::error_code& error);
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
template<typename ConnectHandlerType>
void async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler);
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler);
template<typename ConstBufferSequence, typename WriteHandlerType>
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler);
template<typename ConstBufferSequence>
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error);
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler);
template<typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& error);
tcp::socket::endpoint_type remote_endpoint() const;
*/
template
class Socket : public std::enable_shared_from_this>
{
public:
    /// Adopts an already connected TCP socket. The remote address and port are read from it and the
    /// socket starts in the open state. Extra arguments are forwarded to the Stream constructor.
    template
    explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward(args)...),
        _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
    {
    }

    /// Creates a socket on @p context that starts in the closed state; Connect() switches it to open.
    /// Extra arguments are forwarded to the Stream constructor.
    template
    explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward(args)...), _openState(OpenState_Closed)
    {
    }

    Socket(Socket const& other) = delete;
    Socket(Socket&& other) = delete;
    Socket& operator=(Socket const& other) = delete;
    Socket& operator=(Socket&& other) = delete;

    /// Marks the socket closed and closes the underlying stream; a close error is ignored.
    virtual ~Socket()
    {
        _openState = OpenState_Closed;
        boost::system::error_code error;
        _socket.close(error);
    }

    /// Hook for derived classes; does nothing by default.
    virtual void Start() { }

    /// Marks the socket open and starts a composed asynchronous connect to @p endpoint.
    /// @param callback completion token handed to boost::asio::async_compose
    template
    decltype(auto) Connect(boost::asio::ip::tcp::endpoint const& endpoint, Callback&& callback)
    {
        _openState = OpenState_Open;
        return boost::asio::async_compose>(
            Impl::Operations::Connect(this->shared_from_this(), endpoint), callback, this->underlying_stream());
    }

    /// Same as the single endpoint overload, but the operation walks through @p endpoints in order.
    template
    decltype(auto) Connect(std::vector const& endpoints, Callback&& callback)
    {
        _openState = OpenState_Open;
        return boost::asio::async_compose>(
            Impl::Operations::Connect(this->shared_from_this(), endpoints), callback, this->underlying_stream());
    }

    /// Periodic tick.
    /// Without IOCP it also sends queued packets, unless a wait for writability is already pending.
    /// @return false once the socket has been closed, true otherwise
    virtual bool Update()
    {
        // Test the Closed bit instead of comparing for equality: CloseSocket() ORs the bit in, so a socket that
        // went through DelayedCloseSocket() first holds OpenState_Closing | OpenState_Closed and must still
        // be reported as closed here
        if ((_openState.load() & OpenState_Closed) != 0)
            return false;

#ifndef TC_SOCKET_USE_IOCP
        if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open))
            return true;

        // Keep sending for as long as HandleQueue() reports that more can be written right away
        for (; HandleQueue();)
            ;
#endif
        return true;
    }

    boost::asio::ip::address const& GetRemoteIpAddress() const
    {
        return _remoteAddress;
    }

    uint16 GetRemotePort() const
    {
        return _remotePort;
    }

    /// Overwrites the stored remote address and port with those of @p endpoint.
    void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint)
    {
        _remoteAddress = endpoint.address();
        _remotePort = endpoint.port();
    }

    /// Starts one asynchronous read into the read buffer; does nothing when the socket is not open.
    /// When the read completes without error and the socket is still open, @p callback is invoked and,
    /// if it returns SocketReadCallbackResult::KeepReading, another read is started with the same callback.
    template Callback>
    void AsyncRead(Callback&& callback)
    {
        if (!IsOpen())
            return;

        // The handler holds a shared_ptr to this socket, so the socket outlives the pending read
        _socket.async_read_some(PrepareReadBuffer(_readBuffer),
            [self = this->shared_from_this(), callback = std::forward(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
            {
                if (self->ReadHandlerInternal(error, transferredBytes))
                    if (callback() == SocketReadCallbackResult::KeepReading)
                        self->AsyncRead(std::forward(callback));
            });
    }

    /// Appends @p buffer to the write queue. With IOCP an asynchronous write is started right away;
    /// otherwise the queue is sent from Update().
    void QueuePacket(MessageBuffer&& buffer)
    {
        _writeQueue.push(std::move(buffer));
#ifdef TC_SOCKET_USE_IOCP
        AsyncProcessQueue();
#endif
    }

    /// True only while the socket is neither closing nor closed.
    bool IsOpen() const { return _openState == OpenState_Open; }

    /// Marks the socket closed, shuts down the sending side of the stream and calls OnClose().
    /// Only the first call performs the shutdown; later calls return immediately.
    void CloseSocket()
    {
        // fetch_or returns the previous state: if the Closed bit was already set, a previous call already
        // performed the shutdown and OnClose() must not run a second time
        if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) != 0)
            return;

        boost::system::error_code shutdownError;
        _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
        if (shutdownError)
            TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress(),
                shutdownError.value(), shutdownError.message());

        this->OnClose();
    }

    /// Marks the socket for closing after write buffer becomes empty
    void DelayedCloseSocket()
    {
        // Proceed only when the previous state was plain Open (no Closing or Closed bit set yet)
        if (_openState.fetch_or(OpenState_Closing) != 0)
            return;

        // Nothing left to send: close right away instead of waiting for a write to finish
        if (_writeQueue.empty())
            CloseSocket();
    }

    MessageBuffer& GetReadBuffer() { return _readBuffer; }

    Stream& underlying_stream()
    {
        return _socket;
    }

protected:
    /// Called by CloseSocket() after the stream has been shut down; does nothing by default.
    virtual void OnClose() { }

    /// Default read handler; keeps reading without consuming anything.
    virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; }

    /// Starts the asynchronous part of sending unless one is already pending: with IOCP an async write of
    /// the front buffer, otherwise a wait until the stream becomes writable.
    /// @return always false, so HandleQueue() can return the result to stop its caller's send loop
    bool AsyncProcessQueue()
    {
        if (_isWritingAsync)
            return false;

        _isWritingAsync = true;
#ifdef TC_SOCKET_USE_IOCP
        MessageBuffer& buffer = _writeQueue.front();
        _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()),
            [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes)
            {
                self->WriteHandler(error, transferedBytes);
            });
#else
        _socket.async_wait(boost::asio::socket_base::wait_type::wait_write,
            [self = this->shared_from_this()](boost::system::error_code const& error)
            {
                self->WriteHandlerWrapper(error);
            });
#endif
        return false;
    }

    /// Sets or clears the TCP no_delay option on the stream; a failure is only logged.
    void SetNoDelay(bool enable)
    {
        boost::system::error_code err;
        _socket.set_option(boost::asio::ip::tcp::no_delay(enable), err);
        if (err)
            TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for {} - {} ({})",
                GetRemoteIpAddress(), err.value(), err.message());
    }

private:
    /// Completion step of a read: closes the socket on error, otherwise records the received bytes.
    /// @return true when the read callback should be invoked (no error and the socket is still open)
    bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
    {
        if (error)
        {
            CloseSocket();
            return false;
        }

        _readBuffer.WriteCompleted(transferredBytes);
        return IsOpen();
    }

    /// Removes the front buffer from the write queue and finishes a delayed close once the queue is empty.
    void QueuedBufferWriteDone()
    {
        _writeQueue.pop();
        if (_openState == OpenState_Closing && _writeQueue.empty())
            CloseSocket();
    }

#ifdef TC_SOCKET_USE_IOCP
    /// Completion handler of the IOCP async write: consumes the sent bytes, drops the front buffer when it
    /// is fully sent and starts the next write while data remains. Closes the socket on error.
    void WriteHandler(boost::system::error_code const& error, std::size_t transferedBytes)
    {
        if (!error)
        {
            _isWritingAsync = false;
            _writeQueue.front().ReadCompleted(transferedBytes);
            if (!_writeQueue.front().GetActiveSize())
                QueuedBufferWriteDone();

            if (!_writeQueue.empty())
                AsyncProcessQueue();
        }
        else
            CloseSocket();
    }
#else
    /// Completion handler of the wait started by AsyncProcessQueue(); the error code is ignored and
    /// one send attempt is made.
    void WriteHandlerWrapper(boost::system::error_code const& /*error*/)
    {
        _isWritingAsync = false;
        HandleQueue();
    }

    /// Tries to send the front buffer of the write queue with a single write_some call.
    /// @return true when a whole buffer was sent and more buffers are queued, i.e. the caller may call again
    bool HandleQueue()
    {
        if (_writeQueue.empty())
            return false;

        MessageBuffer& queuedMessage = _writeQueue.front();
        std::size_t bytesToSend = queuedMessage.GetActiveSize();

        boost::system::error_code error;
        std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);

        if (error)
        {
            // The stream cannot take data right now: wait until it is writable and retry then
            if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
                return AsyncProcessQueue();

            // Any other error: the current buffer is dropped
            QueuedBufferWriteDone();
            return false;
        }
        else if (bytesSent == 0)
        {
            QueuedBufferWriteDone();
            return false;
        }
        else if (bytesSent < bytesToSend) // partial write, bytesSent > 0 here
        {
            // Keep the unsent remainder at the front and continue once the stream is writable again
            queuedMessage.ReadCompleted(bytesSent);
            return AsyncProcessQueue();
        }

        QueuedBufferWriteDone();
        return !_writeQueue.empty();
    }
#endif

    Stream _socket;
    boost::asio::ip::address _remoteAddress;
    uint16 _remotePort = 0;
    MessageBuffer _readBuffer = MessageBuffer(0x1000);
    std::queue _writeQueue;

    // Socket open state "enum" (not enum to enable integral std::atomic api)
    // The values are combined with fetch_or, so Closing | Closed (0x3) is a valid "closed" state
    static constexpr uint8 OpenState_Open = 0x0;
    static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data
    static constexpr uint8 OpenState_Closed = 0x2;
    std::atomic _openState;

    bool _isWritingAsync = false; ///< True while an async write (IOCP) or a writability wait is pending
};
namespace Impl::Operations
{
/// State of one connect operation; Connect keeps it behind a shared_ptr.
struct ConnectState
{
/// Single endpoint form: stores @p endpoint as a one element list.
explicit ConnectState(std::shared_ptr const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
: SocketRef(socketRef), Endpoints(1, endpoint), Index(-1) { }
/// Multiple endpoint form: copies @p endpoints.
explicit ConnectState(std::shared_ptr const& socketRef, std::vector const& endpoints)
: SocketRef(socketRef), Endpoints(endpoints), Index(-1) { }
std::weak_ptr SocketRef; ///< Weak reference to the socket being connected
std::vector Endpoints; ///< Endpoints to try, in order
std::ptrdiff_t Index; ///< Index of the most recently attempted endpoint; -1 before the first attempt
};
/// Composed asynchronous connect operation, run through boost::asio::async_compose by Socket::Connect.
/// It tries the stored endpoints one after another and completes with the first one that connects.
template
struct Connect
{
    explicit Connect(std::shared_ptr const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
        : State(std::make_shared(socketRef, endpoint)) { }

    explicit Connect(std::shared_ptr const& socketRef, std::vector const& endpoints)
        : State(std::make_shared(socketRef, endpoints)) { }

    std::shared_ptr State;

    /// One step of the operation. The first invocation (State->Index < 0) starts the first attempt; every
    /// later invocation is the completion of the previous async_connect and receives its result in @p error.
    /// Completes @p handler with:
    ///  - operation_aborted when the socket no longer exists, was closed meanwhile, or the operation was cancelled
    ///  - not_found when every endpoint has been tried without success (or the endpoint list is empty)
    ///  - success and the connected endpoint otherwise
    template
    void operator()(Handler& handler, boost::system::error_code error = {})
    {
        std::shared_ptr socket = static_pointer_cast(State->SocketRef.lock());
        if (!socket)
        {
            error = boost::asio::error::operation_aborted;
            handler.complete(error, boost::asio::ip::tcp::endpoint());
            return;
        }

        bool isFirst = State->Index < 0;

        if (!isFirst && !socket->underlying_stream().is_open())
        {
            Connect::HandleError(socket.get(), "socket closed");
            error = boost::asio::error::operation_aborted;
            handler.complete(error, boost::asio::ip::tcp::endpoint());
            return;
        }

        if (!error && !isFirst)
        {
            // The previous attempt succeeded
            socket->SetRemoteEndpoint(State->Endpoints[State->Index]);
            handler.complete(error, State->Endpoints[State->Index]);
            return;
        }

        // Another attempt is needed. Check that there is an endpoint left BEFORE advancing the index:
        // after the last endpoint fails, Index + 1 equals the list size and indexing with it would read
        // past the end of the vector. This also covers an empty list on the first call (Index == -1).
        if (State->Index + 1 >= std::ssize(State->Endpoints))
        {
            Connect::HandleError(socket.get(), "failed to connect to any of specified endpoints");
            error = boost::asio::error::not_found;
            handler.complete(error, boost::asio::ip::tcp::endpoint());
            return;
        }

#if BOOST_VERSION >= 107700
        if (handler.cancelled() != boost::asio::cancellation_type::none)
        {
            Connect::HandleError(socket.get(), "connect cancelled");
            error = boost::asio::error::operation_aborted;
            handler.complete(error, boost::asio::ip::tcp::endpoint());
            return;
        }
#endif

        // Close the stream before the next attempt; the result of close() is not checked
        socket->underlying_stream().close(error);
        socket->underlying_stream().async_connect(State->Endpoints[++State->Index], std::move(handler));
    }

    /// Logs @p message and closes the socket.
    static void HandleError(Socket* self, std::string_view message)
    {
        TC_LOG_DEBUG("network", "Socket::Connect: {}", message);
        self->CloseSocket();
    }
};
}
}
#endif // TRINITYCORE_SOCKET_H