Core/Network: Socket refactors

* Devirtualize calls to Read and Update by marking concrete implementations as final
* Removed derived class template argument
* Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor
* Make socket initialization easier composable (before entering Read loop)
* Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream
This commit is contained in:
Shauren
2025-04-08 19:15:16 +02:00
parent 40d80f3476
commit e8b2be3527
32 changed files with 967 additions and 811 deletions

View File

@@ -15,12 +15,13 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __ASYNCACCEPT_H_
#define __ASYNCACCEPT_H_
#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H
#define TRINITYCORE_ASYNC_ACCEPTOR_H
#include "IoContext.h"
#include "IpAddress.h"
#include "Log.h"
#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ip/v6_only.hpp>
#include <atomic>
@@ -28,28 +29,28 @@
#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
namespace Trinity::Net
{
template <typename Callable>
concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>;
class AsyncAcceptor
{
public:
typedef void(*AcceptCallback)(boost::asio::ip::tcp::socket&& newSocket, uint32 threadIndex);
AsyncAcceptor(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
_acceptor(ioContext), _endpoint(Trinity::Net::make_address(bindIp), port),
AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
_acceptor(ioContext), _endpoint(make_address(bindIp), port),
_socket(ioContext), _closed(false), _socketFactory([this] { return DefeaultSocketFactory(); })
{
}
template<class T>
void AsyncAccept();
template<AcceptCallback acceptCallback>
void AsyncAcceptWithCallback()
template <AcceptCallback Callback>
void AsyncAccept(Callback&& acceptCallback)
{
auto [tmpSocket, tmpThreadIndex] = _socketFactory();
// TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures)
boost::asio::ip::tcp::socket* socket = tmpSocket;
IoContextTcpSocket* socket = tmpSocket;
uint32 threadIndex = tmpThreadIndex;
_acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
_acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable
{
if (!error)
{
@@ -66,7 +67,7 @@ public:
}
if (!_closed)
this->AsyncAcceptWithCallback<acceptCallback>();
this->AsyncAccept(std::move(acceptCallback));
});
}
@@ -120,40 +121,17 @@ public:
_acceptor.close(err);
}
void SetSocketFactory(std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> func) { _socketFactory = std::move(func); }
void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
private:
std::pair<boost::asio::ip::tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
boost::asio::ip::tcp::acceptor _acceptor;
boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
boost::asio::ip::tcp::endpoint _endpoint;
boost::asio::ip::tcp::socket _socket;
IoContextTcpSocket _socket;
std::atomic<bool> _closed;
std::function<std::pair<boost::asio::ip::tcp::socket*, uint32>()> _socketFactory;
std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
};
template<class T>
void AsyncAcceptor::AsyncAccept()
{
_acceptor.async_accept(_socket, [this](boost::system::error_code error)
{
if (!error)
{
try
{
// this-> is required here to fix an segmentation fault in gcc 4.7.2 - reason is lambdas in a templated class
std::make_shared<T>(std::move(this->_socket))->Start();
}
catch (boost::system::system_error const& err)
{
TC_LOG_INFO("network", "Failed to retrieve client's remote address {}", err.what());
}
}
// lets slap some more this-> on this so we can fix this bug with gcc 4.7.2 throwing internals in yo face
if (!_closed)
this->AsyncAccept<T>();
});
}
#endif /* __ASYNCACCEPT_H_ */
#endif // TRINITYCORE_ASYNC_ACCEPTOR_H

View File

@@ -0,0 +1,42 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "IpBanCheckConnectionInitializer.h"
#include "DatabaseEnv.h"
QueryCallback Trinity::Net::IpBanCheckHelpers::AsyncQuery(std::string_view ipAddress)
{
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
stmt->setString(0, ipAddress);
return LoginDatabase.AsyncQuery(stmt);
}
bool Trinity::Net::IpBanCheckHelpers::IsBanned(PreparedQueryResult const& result)
{
if (result)
{
do
{
Field* fields = result->Fetch();
if (fields[0].GetUInt64() != 0)
return true;
} while (result->NextRow());
}
return false;
}

View File

@@ -0,0 +1,64 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
#define TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H
#include "DatabaseEnvFwd.h"
#include "Log.h"
#include "QueryCallback.h"
#include "SocketConnectionInitializer.h"
namespace Trinity::Net
{
namespace IpBanCheckHelpers
{
TC_SHARED_API QueryCallback AsyncQuery(std::string_view ipAddress);
TC_SHARED_API bool IsBanned(PreparedQueryResult const& result);
}
template <typename SocketImpl>
struct IpBanCheckConnectionInitializer final : SocketConnectionInitializer
{
explicit IpBanCheckConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
void Start() override
{
_socket->QueueQuery(IpBanCheckHelpers::AsyncQuery(_socket->GetRemoteIpAddress().to_string()).WithPreparedCallback([socketRef = _socket->weak_from_this(), self = this->shared_from_this()](PreparedQueryResult const& result)
{
std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
if (!socket)
return;
if (IpBanCheckHelpers::IsBanned(result))
{
TC_LOG_ERROR("network", "IpBanCheckConnectionInitializer: IP {} is banned.", socket->GetRemoteIpAddress().to_string());
socket->DelayedCloseSocket();
return;
}
if (self->next)
self->next->Start();
}));
}
private:
SocketImpl* _socket;
};
}
#endif // TRINITYCORE_IP_BAN_CHECK_CONNECTION_INITIALIZER_H

View File

@@ -0,0 +1,51 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
#include <memory>
#include <span>
namespace Trinity::Net
{
struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer>
{
SocketConnectionInitializer() = default;
SocketConnectionInitializer(SocketConnectionInitializer const&) = delete;
SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default;
SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete;
SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default;
virtual ~SocketConnectionInitializer() = default;
virtual void Start() = 0;
std::shared_ptr<SocketConnectionInitializer> next;
static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers)
{
for (std::size_t i = initializers.size(); i > 1; --i)
initializers[i - 2]->next.swap(initializers[i - 1]);
return initializers[0];
}
};
}
#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H

View File

@@ -16,6 +16,7 @@
*/
#include "BaseHttpSocket.h"
#include <boost/asio/buffers_iterator.hpp>
#include <boost/beast/http/serializer.hpp>
namespace Trinity::Net::Http
@@ -112,4 +113,34 @@ MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response
return buffer;
}
void AbstractSocket::LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const
{
if (Logger const* logger = sLog->GetEnabledLogger("server.http", LOG_LEVEL_DEBUG))
{
std::string clientInfo = GetClientInfo();
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_DEBUG, "{} Request {} {} done, status {}", clientInfo,
ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int());
if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
{
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo,
CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo,
CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
}
}
}
std::string AbstractSocket::GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state)
{
std::string info = StringFormat("[{}:{}", address.to_string(), port);
if (state)
{
info.append(", Session Id: ");
info.append(boost::uuids::to_string(state->Id));
}
info += ']';
return info;
}
}

View File

@@ -25,13 +25,46 @@
#include "Optional.h"
#include "QueryCallback.h"
#include "Socket.h"
#include <boost/asio/buffers_iterator.hpp>
#include "SocketConnectionInitializer.h"
#include <boost/beast/core/basic_stream.hpp>
#include <boost/beast/http/parser.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace Trinity::Net::Http
{
using IoContextHttpSocket = boost::beast::basic_stream<boost::asio::ip::tcp, boost::asio::io_context::executor_type, boost::beast::unlimited_rate_policy>;
namespace Impl
{
class BoostBeastSocketWrapper : public IoContextHttpSocket
{
public:
using IoContextHttpSocket::basic_stream;
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
socket().shutdown(what, shutdownError);
}
void close(boost::system::error_code& /*error*/)
{
IoContextHttpSocket::close();
}
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
{
socket().async_wait(type, std::forward<WaitHandlerType>(handler));
}
IoContextTcpSocket::endpoint_type remote_endpoint() const
{
return socket().remote_endpoint();
}
};
}
using RequestParser = boost::beast::http::request_parser<RequestBody>;
class TC_SHARED_API AbstractSocket
@@ -51,22 +84,59 @@ public:
virtual void SendResponse(RequestContext& context) = 0;
void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const;
virtual void QueueQuery(QueryCallback&& queryCallback) = 0;
virtual std::string GetClientInfo() const = 0;
virtual Optional<boost::uuids::uuid> GetSessionId() const = 0;
static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state);
virtual SessionState* GetSessionState() const = 0;
Optional<boost::uuids::uuid> GetSessionId() const
{
if (SessionState* state = this->GetSessionState())
return state->Id;
return {};
}
virtual void Start() = 0;
virtual bool Update() = 0;
virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0;
virtual bool IsOpen() const = 0;
virtual void CloseSocket() = 0;
};
template<typename Derived, typename Stream>
class BaseSocket : public ::Socket<Derived, Stream>, public AbstractSocket
template <typename SocketImpl>
struct HttpConnectionInitializer final : SocketConnectionInitializer
{
using Base = ::Socket<Derived, Stream>;
explicit HttpConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
void Start() override
{
_socket->ResetHttpParser();
if (this->next)
this->next->Start();
}
private:
SocketImpl* _socket;
};
template<typename Stream>
class BaseSocket : public Trinity::Net::Socket<Stream>, public AbstractSocket
{
using Base = Trinity::Net::Socket<Stream>;
public:
template<typename... Args>
explicit BaseSocket(boost::asio::ip::tcp::socket&& socket, Args&&... args)
: Base(std::move(socket), std::forward<Args>(args)...) { }
using Base::Base;
BaseSocket(BaseSocket const& other) = delete;
BaseSocket(BaseSocket&& other) = delete;
@@ -75,11 +145,8 @@ public:
~BaseSocket() = default;
void ReadHandler() override
SocketReadCallbackResult ReadHandler() final
{
if (!this->IsOpen())
return;
MessageBuffer& packet = this->GetReadBuffer();
while (packet.GetActiveSize() > 0)
{
@@ -92,13 +159,13 @@ public:
if (!HandleMessage(_httpParser->get()))
{
this->CloseSocket();
break;
return SocketReadCallbackResult::Stop;
}
this->ResetHttpParser();
}
this->AsyncRead();
return SocketReadCallbackResult::KeepReading;
}
bool HandleMessage(Request& request)
@@ -118,19 +185,11 @@ public:
virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0;
void SendResponse(RequestContext& context) override
void SendResponse(RequestContext& context) final
{
MessageBuffer buffer = SerializeResponse(context.request, context.response);
TC_LOG_DEBUG("server.http", "{} Request {} {} done, status {}", this->GetClientInfo(), ToStdStringView(context.request.method_string()),
ToStdStringView(context.request.target()), context.response.result_int());
if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
{
sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Request: {}", this->GetClientInfo(),
CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Response: {}", this->GetClientInfo(),
CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
}
this->LogRequestAndResponse(context, buffer);
this->QueuePacket(std::move(buffer));
@@ -138,11 +197,13 @@ public:
this->DelayedCloseSocket();
}
void QueueQuery(QueryCallback&& queryCallback) override
void QueueQuery(QueryCallback&& queryCallback) final
{
this->_queryProcessor.AddCallback(std::move(queryCallback));
}
void Start() override { return this->Base::Start(); }
bool Update() override
{
if (!this->Base::Update())
@@ -152,27 +213,19 @@ public:
return true;
}
boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); }
bool IsOpen() const final { return this->Base::IsOpen(); }
void CloseSocket() final { return this->Base::CloseSocket(); }
std::string GetClientInfo() const override
{
std::string info;
info.reserve(500);
auto itr = StringFormatTo(std::back_inserter(info), "[{}:{}", this->GetRemoteIpAddress().to_string(), this->GetRemotePort());
if (_state)
itr = StringFormatTo(itr, ", Session Id: {}", boost::uuids::to_string(_state->Id));
StringFormatTo(itr, "]");
return info;
return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get());
}
Optional<boost::uuids::uuid> GetSessionId() const final
{
if (this->_state)
return this->_state->Id;
SessionState* GetSessionState() const override { return _state.get(); }
return {};
}
protected:
void ResetHttpParser()
{
this->_httpParser.reset();
@@ -180,6 +233,7 @@ protected:
this->_httpParser->eager(true);
}
protected:
virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0;
QueryCallbackProcessor _queryProcessor;

View File

@@ -160,7 +160,7 @@ protected:
class Thread : public NetworkThread<SessionImpl>
{
protected:
void SocketRemoved(std::shared_ptr<SessionImpl> session) override
void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
{
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
_service->MarkSessionInactive(*id);

View File

@@ -19,43 +19,16 @@
#define TRINITYCORE_HTTP_SOCKET_H
#include "BaseHttpSocket.h"
#include <boost/beast/core/tcp_stream.hpp>
#include <array>
namespace Trinity::Net::Http
{
namespace Impl
class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper>
{
class BoostBeastSocketWrapper : public boost::beast::tcp_stream
{
public:
using boost::beast::tcp_stream::tcp_stream;
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
socket().shutdown(what, shutdownError);
}
void close(boost::system::error_code& /*error*/)
{
boost::beast::tcp_stream::close();
}
boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
{
return socket().remote_endpoint();
}
};
}
template <typename Derived>
class Socket : public BaseSocket<Derived, Impl::BoostBeastSocketWrapper>
{
using SocketBase = BaseSocket<Derived, Impl::BoostBeastSocketWrapper>;
using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>;
public:
template<typename... Args>
explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&...)
: SocketBase(std::move(socket)) { }
using SocketBase::SocketBase;
Socket(Socket const& other) = delete;
Socket(Socket&& other) = delete;
@@ -66,9 +39,13 @@ public:
void Start() override
{
this->ResetHttpParser();
std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers =
{ {
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
this->AsyncRead();
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
};
}

View File

@@ -19,47 +19,21 @@
#define TRINITYCORE_HTTP_SSL_SOCKET_H
#include "BaseHttpSocket.h"
#include "SslSocket.h"
#include <boost/beast/core/stream_traits.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/ssl/ssl_stream.hpp>
#include "SslStream.h"
namespace Trinity::Net::Http
{
namespace Impl
class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>
{
class BoostBeastSslSocketWrapper : public ::SslSocket<boost::beast::ssl_stream<boost::beast::tcp_stream>>
{
public:
using SslSocket::SslSocket;
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
_sslSocket.shutdown(shutdownError);
boost::beast::get_lowest_layer(_sslSocket).socket().shutdown(what, shutdownError);
}
void close(boost::system::error_code& /*error*/)
{
boost::beast::get_lowest_layer(_sslSocket).close();
}
boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
{
return boost::beast::get_lowest_layer(_sslSocket).socket().remote_endpoint();
}
};
}
template <typename Derived>
class SslSocket : public BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>
{
using SocketBase = BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>;
using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>;
public:
explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext)
explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext)
: SocketBase(std::move(socket), sslContext) { }
explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext)
: SocketBase(context, sslContext) { }
SslSocket(SslSocket const& other) = delete;
SslSocket(SslSocket&& other) = delete;
SslSocket& operator=(SslSocket const& other) = delete;
@@ -69,27 +43,14 @@ public:
void Start() override
{
this->AsyncHandshake();
}
std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers =
{ {
std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this),
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
void AsyncHandshake()
{
this->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
[self = this->shared_from_this()](boost::system::error_code const& error) { self->HandshakeHandler(error); });
}
void HandshakeHandler(boost::system::error_code const& error)
{
if (error)
{
TC_LOG_ERROR("server.http.session.ssl", "{} SSL Handshake failed {}", this->GetClientInfo(), error.message());
this->CloseSocket();
return;
}
this->ResetHttpParser();
this->AsyncRead();
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
};
}

View File

@@ -15,20 +15,24 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef NetworkThread_h__
#define NetworkThread_h__
#ifndef TRINITYCORE_NETWORK_THREAD_H
#define TRINITYCORE_NETWORK_THREAD_H
#include "Define.h"
#include "Containers.h"
#include "DeadlineTimer.h"
#include "Define.h"
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
#include <mutex>
#include <thread>
namespace Trinity::Net
{
template<class SocketType>
class NetworkThread
{
@@ -38,14 +42,16 @@ public:
{
}
NetworkThread(NetworkThread const&) = delete;
NetworkThread(NetworkThread&&) = delete;
NetworkThread& operator=(NetworkThread const&) = delete;
NetworkThread& operator=(NetworkThread&&) = delete;
virtual ~NetworkThread()
{
Stop();
if (_thread)
{
Wait();
delete _thread;
}
}
void Stop()
@@ -59,7 +65,7 @@ public:
if (_thread)
return false;
_thread = new std::thread(&NetworkThread::Run, this);
_thread = std::make_unique<std::thread>(&NetworkThread::Run, this);
return true;
}
@@ -68,7 +74,6 @@ public:
ASSERT(_thread);
_thread->join();
delete _thread;
_thread = nullptr;
}
@@ -82,15 +87,14 @@ public:
std::lock_guard<std::mutex> lock(_newSocketsLock);
++_connections;
_newSockets.push_back(sock);
SocketAdded(sock);
SocketAdded(_newSockets.emplace_back(std::move(sock)));
}
boost::asio::ip::tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
protected:
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { }
void AddNewSockets()
{
@@ -99,7 +103,7 @@ protected:
if (_newSockets.empty())
return;
for (std::shared_ptr<SocketType> sock : _newSockets)
for (std::shared_ptr<SocketType>& sock : _newSockets)
{
if (!sock->IsOpen())
{
@@ -107,7 +111,7 @@ protected:
--_connections;
}
else
_sockets.push_back(sock);
_sockets.emplace_back(std::move(sock));
}
_newSockets.clear();
@@ -136,7 +140,7 @@ protected:
AddNewSockets();
_sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock)
Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock)
{
if (!sock->Update())
{
@@ -150,7 +154,7 @@ protected:
}
return false;
}), _sockets.end());
});
}
private:
@@ -159,7 +163,7 @@ private:
std::atomic<int32> _connections;
std::atomic<bool> _stopped;
std::thread* _thread;
std::unique_ptr<std::thread> _thread;
SocketContainer _sockets;
@@ -167,8 +171,9 @@ private:
SocketContainer _newSockets;
Trinity::Asio::IoContext _ioContext;
boost::asio::ip::tcp::socket _acceptSocket;
Trinity::Net::IoContextTcpSocket _acceptSocket;
Trinity::Asio::DeadlineTimer _updateTimer;
};
}
#endif // NetworkThread_h__
#endif // TRINITYCORE_NETWORK_THREAD_H

View File

@@ -15,11 +15,14 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef __SOCKET_H__
#define __SOCKET_H__
#ifndef TRINITYCORE_SOCKET_H
#define TRINITYCORE_SOCKET_H
#include "Concepts.h"
#include "Log.h"
#include "MessageBuffer.h"
#include "SocketConnectionInitializer.h"
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
@@ -31,12 +34,51 @@
#define TC_SOCKET_USE_IOCP
#endif
namespace Trinity::Net
{
using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
enum class SocketReadCallbackResult
{
KeepReading,
Stop
};
template <typename Callable>
concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
template <typename SocketType>
struct InvokeReadHandlerCallback
{
SocketReadCallbackResult operator()() const
{
return this->Socket->ReadHandler();
}
SocketType* Socket;
};
template <typename SocketType>
struct ReadConnectionInitializer final : SocketConnectionInitializer
{
explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { }
void Start() override
{
ReadCallback.Socket->AsyncRead(std::move(ReadCallback));
if (this->next)
this->next->Start();
}
InvokeReadHandlerCallback<SocketType> ReadCallback;
};
/**
@class Socket
Base async socket implementation
@tparam T derived class type (CRTP)
@tparam Stream stream type used for operations on socket
Stream must implement the following methods:
@@ -53,21 +95,27 @@
template<typename ConstBufferSequence>
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error);
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler);
template<typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& error);
tcp::socket::endpoint_type remote_endpoint() const;
*/
template<class T, class Stream = boost::asio::ip::tcp::socket>
class Socket : public std::enable_shared_from_this<T>
template<class Stream = IoContextTcpSocket>
class Socket : public std::enable_shared_from_this<Socket<Stream>>
{
public:
template<typename... Args>
explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
_remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()),
_closed(false), _closing(false), _isWritingAsync(false)
explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
_remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
{
}
template<typename... Args>
explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed)
{
_readBuffer.Resize(READ_BLOCK_SIZE);
}
Socket(Socket const& other) = delete;
@@ -77,20 +125,20 @@ public:
virtual ~Socket()
{
_closed = true;
_openState = OpenState_Closed;
boost::system::error_code error;
_socket.close(error);
}
virtual void Start() = 0;
virtual void Start() { }
virtual bool Update()
{
if (_closed)
if (_openState == OpenState_Closed)
return false;
#ifndef TC_SOCKET_USE_IOCP
if (_isWritingAsync || (_writeQueue.empty() && !_closing))
if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open))
return true;
for (; HandleQueue();)
@@ -100,7 +148,7 @@ public:
return true;
}
boost::asio::ip::address GetRemoteIpAddress() const
boost::asio::ip::address const& GetRemoteIpAddress() const
{
return _remoteAddress;
}
@@ -110,7 +158,8 @@ public:
return _remotePort;
}
void AsyncRead()
template <SocketReadCallback Callback>
void AsyncRead(Callback&& callback)
{
if (!IsOpen())
return;
@@ -118,23 +167,11 @@ public:
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
[self = this->shared_from_this()](boost::system::error_code const& error, size_t transferredBytes)
[self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
{
self->ReadHandlerInternal(error, transferredBytes);
});
}
void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code const&, std::size_t))
{
if (!IsOpen())
return;
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
[self = this->shared_from_this(), callback](boost::system::error_code const& error, size_t transferredBytes)
{
(self.get()->*callback)(error, transferredBytes);
if (self->ReadHandlerInternal(error, transferredBytes))
if (callback() == SocketReadCallbackResult::KeepReading)
self->AsyncRead(std::forward<Callback>(callback));
});
}
@@ -147,11 +184,11 @@ public:
#endif
}
bool IsOpen() const { return !_closed && !_closing; }
bool IsOpen() const { return _openState == OpenState_Open; }
void CloseSocket()
{
if (_closed.exchange(true))
if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
return;
boost::system::error_code shutdownError;
@@ -160,13 +197,13 @@ public:
TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(),
shutdownError.value(), shutdownError.message());
OnClose();
this->OnClose();
}
/// Marks the socket for closing after write buffer becomes empty
void DelayedCloseSocket()
{
if (_closing.exchange(true))
if (_openState.fetch_or(OpenState_Closing) != 0)
return;
if (_writeQueue.empty())
@@ -175,10 +212,15 @@ public:
MessageBuffer& GetReadBuffer() { return _readBuffer; }
Stream& underlying_stream()
{
return _socket;
}
protected:
virtual void OnClose() { }
virtual void ReadHandler() = 0;
virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; }
bool AsyncProcessQueue()
{
@@ -195,10 +237,10 @@ protected:
self->WriteHandler(error, transferedBytes);
});
#else
_socket.async_write_some(boost::asio::null_buffers(),
[self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes)
_socket.async_wait(boost::asio::socket_base::wait_type::wait_write,
[self = this->shared_from_this()](boost::system::error_code const& error)
{
self->WriteHandlerWrapper(error, transferedBytes);
self->WriteHandlerWrapper(error);
});
#endif
@@ -214,22 +256,17 @@ protected:
GetRemoteIpAddress().to_string(), err.value(), err.message());
}
Stream& underlying_stream()
{
return _socket;
}
private:
void ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
{
if (error)
{
CloseSocket();
return;
return false;
}
_readBuffer.WriteCompleted(transferredBytes);
ReadHandler();
return IsOpen();
}
#ifdef TC_SOCKET_USE_IOCP
@@ -245,7 +282,7 @@ private:
if (!_writeQueue.empty())
AsyncProcessQueue();
else if (_closing)
else if (_openState == OpenState_Closing)
CloseSocket();
}
else
@@ -254,7 +291,7 @@ private:
#else
void WriteHandlerWrapper(boost::system::error_code const& /*error*/, std::size_t /*transferedBytes*/)
void WriteHandlerWrapper(boost::system::error_code const& /*error*/)
{
_isWritingAsync = false;
HandleQueue();
@@ -278,14 +315,14 @@ private:
return AsyncProcessQueue();
_writeQueue.pop();
if (_closing && _writeQueue.empty())
if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return false;
}
else if (bytesSent == 0)
{
_writeQueue.pop();
if (_closing && _writeQueue.empty())
if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return false;
}
@@ -296,7 +333,7 @@ private:
}
_writeQueue.pop();
if (_closing && _writeQueue.empty())
if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return !_writeQueue.empty();
}
@@ -306,15 +343,20 @@ private:
Stream _socket;
boost::asio::ip::address _remoteAddress;
uint16 _remotePort;
uint16 _remotePort = 0;
MessageBuffer _readBuffer;
MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE);
std::queue<MessageBuffer> _writeQueue;
std::atomic<bool> _closed;
std::atomic<bool> _closing;
// Socket open state "enum" (not enum to enable integral std::atomic api)
static constexpr uint8 OpenState_Open = 0x0;
static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data
static constexpr uint8 OpenState_Closed = 0x2;
bool _isWritingAsync;
std::atomic<uint8> _openState;
bool _isWritingAsync = false;
};
}
#endif // __SOCKET_H__
#endif // TRINITYCORE_SOCKET_H

View File

@@ -15,32 +15,40 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef SocketMgr_h__
#define SocketMgr_h__
#ifndef TRINITYCORE_SOCKET_MGR_H
#define TRINITYCORE_SOCKET_MGR_H
#include "AsyncAcceptor.h"
#include "Errors.h"
#include "NetworkThread.h"
#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <memory>
namespace Trinity::Net
{
template<class SocketType>
class SocketMgr
{
public:
SocketMgr(SocketMgr const&) = delete;
SocketMgr(SocketMgr&&) = delete;
SocketMgr& operator=(SocketMgr const&) = delete;
SocketMgr& operator=(SocketMgr&&) = delete;
virtual ~SocketMgr()
{
ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
}
virtual bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
ASSERT(threadCount > 0);
AsyncAcceptor* acceptor = nullptr;
std::unique_ptr<AsyncAcceptor> acceptor = nullptr;
try
{
acceptor = new AsyncAcceptor(ioContext, bindIp, port);
acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port);
}
catch (boost::system::system_error const& err)
{
@@ -51,13 +59,12 @@ public:
if (!acceptor->Bind())
{
TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
delete acceptor;
return false;
}
_acceptor = acceptor;
_acceptor = std::move(acceptor);
_threadCount = threadCount;
_threads = CreateThreads();
_threads.reset(CreateThreads());
ASSERT(_threads);
@@ -73,27 +80,23 @@ public:
{
_acceptor->Close();
if (_threadCount != 0)
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Stop();
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Stop();
Wait();
delete _acceptor;
_acceptor = nullptr;
delete[] _threads;
_threads = nullptr;
_threadCount = 0;
}
void Wait()
{
if (_threadCount != 0)
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Wait();
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Wait();
}
virtual void OnSocketOpen(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex)
virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
{
try
{
@@ -121,22 +124,23 @@ public:
return min;
}
std::pair<boost::asio::ip::tcp::socket*, uint32> GetSocketForAccept()
std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
{
uint32 threadIndex = SelectThreadWithMinConnections();
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
}
protected:
SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(0)
SocketMgr() : _threadCount(0)
{
}
virtual NetworkThread<SocketType>* CreateThreads() const = 0;
AsyncAcceptor* _acceptor;
NetworkThread<SocketType>* _threads;
std::unique_ptr<AsyncAcceptor> _acceptor;
std::unique_ptr<NetworkThread<SocketType>[]> _threads;
int32 _threadCount;
};
}
#endif // SocketMgr_h__
#endif // TRINITYCORE_SOCKET_MGR_H

View File

@@ -1,88 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef SslSocket_h__
#define SslSocket_h__
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ssl/stream.hpp>
#include <boost/system/error_code.hpp>
namespace boostssl = boost::asio::ssl;
template<class Stream = boostssl::stream<boost::asio::ip::tcp::socket>>
class SslSocket
{
public:
explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
{
_sslSocket.set_verify_mode(boostssl::verify_none);
}
// adapting tcp::socket api
void close(boost::system::error_code& error)
{
_sslSocket.lowest_layer().close(error);
}
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
_sslSocket.shutdown(shutdownError);
_sslSocket.lowest_layer().shutdown(what, shutdownError);
}
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
{
_sslSocket.async_read_some(buffers, std::move(handler));
}
template<typename ConstBufferSequence, typename WriteHandlerType>
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
{
_sslSocket.async_write_some(buffers, std::move(handler));
}
template<typename ConstBufferSequence>
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error)
{
return _sslSocket.write_some(buffers, error);
}
template<typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& error)
{
_sslSocket.lowest_layer().set_option(option, error);
}
boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
{
return _sslSocket.lowest_layer().remote_endpoint();
}
// ssl api
template<typename HandshakeHandlerType>
void async_handshake(boostssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
{
_sslSocket.async_handshake(type, std::move(handler));
}
protected:
Stream _sslSocket;
};
#endif // SslSocket_h__

View File

@@ -0,0 +1,131 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SSL_STREAM_H
#define TRINITYCORE_SSL_STREAM_H
#include "SocketConnectionInitializer.h"
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ssl/stream.hpp>
#include <boost/system/error_code.hpp>
namespace Trinity::Net
{
template <typename SocketImpl>
struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer
{
explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
void Start() override
{
_socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
[socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error)
{
std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
if (!socket)
return;
if (error)
{
TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message());
socket->CloseSocket();
return;
}
if (self->next)
self->next->Start();
});
}
private:
SocketImpl* _socket;
};
template<class WrappedStream = IoContextTcpSocket>
class SslStream
{
public:
explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
{
_sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
}
explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext)
{
_sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
}
// adapting tcp::socket api
void close(boost::system::error_code& error)
{
_sslSocket.next_layer().close(error);
}
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
_sslSocket.shutdown(shutdownError);
_sslSocket.next_layer().shutdown(what, shutdownError);
}
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
{
_sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler));
}
template<typename ConstBufferSequence, typename WriteHandlerType>
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
{
_sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler));
}
template<typename ConstBufferSequence>
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error)
{
return _sslSocket.write_some(buffers, error);
}
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
{
_sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler));
}
template<typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& error)
{
_sslSocket.next_layer().set_option(option, error);
}
IoContextTcpSocket::endpoint_type remote_endpoint() const
{
return _sslSocket.next_layer().remote_endpoint();
}
// ssl api
template<typename HandshakeHandlerType>
void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
{
_sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler));
}
protected:
boost::asio::ssl::stream<WrappedStream> _sslSocket;
};
}
#endif // TRINITYCORE_SSL_STREAM_H