Core/Network: Move to separate project

This commit is contained in:
Shauren
2025-04-09 21:02:31 +02:00
parent 6c374c56b2
commit 71b681bbf0
38 changed files with 245 additions and 156 deletions

View File

@@ -41,6 +41,7 @@ target_link_libraries(shared
trinity-core-interface
PUBLIC
database
network
rapidjson
proto
zlib)

View File

@@ -0,0 +1,150 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "IPLocation.h"
#include "BigNumber.h"
#include "Config.h"
#include "Errors.h"
#include "IpAddress.h"
#include "Log.h"
#include "Util.h"
#include <algorithm>
#include <fstream>
IpLocationStore::IpLocationStore() = default;
IpLocationStore::~IpLocationStore() = default;
void IpLocationStore::Load()
{
_ipLocationStore.clear();
TC_LOG_INFO("server.loading", "Loading IP Location Database...");
std::string databaseFilePath = sConfigMgr->GetStringDefault("IPLocationFile", "");
if (databaseFilePath.empty())
return;
// Check if file exists
std::ifstream databaseFile(databaseFilePath);
if (!databaseFile)
{
TC_LOG_ERROR("server.loading", "IPLocation: No ip database file exists ({}).", databaseFilePath);
return;
}
if (!databaseFile.is_open())
{
TC_LOG_ERROR("server.loading", "IPLocation: Ip database file ({}) can not be opened.", databaseFilePath);
return;
}
std::string ipFrom;
std::string ipTo;
std::string countryCode;
std::string countryName;
BigNumber bnParser;
BigNumber ipv4Max(0xFFFFFFFF);
BigNumber ipv6MappedMask(0xFFFF);
ipv6MappedMask <<= 32;
auto parseStringToIPv6 = [&](std::string const& str) -> Optional<std::array<uint8, 16>>
{
bnParser.SetDecStr(str);
if (!bnParser.SetDecStr(str))
return {};
// convert ipv4 to ipv6 v4 mapped value
if (bnParser <= ipv4Max)
bnParser += ipv6MappedMask;
return bnParser.ToByteArray<16>(false);
};
while (databaseFile.good())
{
// Read lines
if (!std::getline(databaseFile, ipFrom, ','))
break;
if (!std::getline(databaseFile, ipTo, ','))
break;
if (!std::getline(databaseFile, countryCode, ','))
break;
if (!std::getline(databaseFile, countryName, '\n'))
break;
// Remove new lines and return
std::erase_if(countryName, [](char c) { return c == '\r' || c == '\n'; });
// Remove quotation marks
std::erase(ipFrom, '"');
std::erase(ipTo, '"');
std::erase(countryCode, '"');
std::erase(countryName, '"');
if (countryCode == "-")
continue;
// Convert country code to lowercase
strToLower(countryCode);
Optional<std::array<uint8, 16>> from = parseStringToIPv6(ipFrom);
if (!from)
continue;
Optional<std::array<uint8, 16>> to = parseStringToIPv6(ipTo);
if (!to)
continue;
_ipLocationStore.emplace_back(*from, *to, std::move(countryCode), std::move(countryName));
}
std::ranges::sort(_ipLocationStore, {}, &IpLocationRecord::IpFrom);
ASSERT(std::ranges::is_sorted(_ipLocationStore, [](IpLocationRecord const& a, IpLocationRecord const& b) { return a.IpFrom < b.IpTo; }),
"Overlapping IP ranges detected in database file");
databaseFile.close();
TC_LOG_INFO("server.loading", ">> Loaded {} ip location entries.", _ipLocationStore.size());
}
IpLocationRecord const* IpLocationStore::GetLocationRecord(std::string const& ipAddress) const
{
boost::system::error_code error;
boost::asio::ip::address address = Trinity::Net::make_address(ipAddress, error);
if (error)
return nullptr;
std::array<uint8, 16> bytes = [&]() -> std::array<uint8, 16>
{
if (address.is_v6())
return address.to_v6().to_bytes();
if (address.is_v4())
return Trinity::Net::make_address_v6(Trinity::Net::v4_mapped, address.to_v4()).to_bytes();
return {};
}();
auto itr = std::ranges::upper_bound(_ipLocationStore, bytes, {}, &IpLocationRecord::IpTo);
if (itr == _ipLocationStore.end())
return nullptr;
if (bytes < itr->IpFrom)
return nullptr;
return &(*itr);
}
IpLocationStore* IpLocationStore::Instance()
{
static IpLocationStore instance;
return &instance;
}

View File

@@ -0,0 +1,58 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef IPLOCATION_H
#define IPLOCATION_H
#include "Define.h"
#include <array>
#include <string>
#include <vector>
struct IpLocationRecord
{
IpLocationRecord() : IpFrom(), IpTo() { }
IpLocationRecord(std::array<uint8, 16> ipFrom, std::array<uint8, 16> ipTo, std::string&& countryCode, std::string&& countryName)
: IpFrom(ipFrom), IpTo(ipTo), CountryCode(std::move(countryCode)), CountryName(std::move(countryName)) { }
std::array<uint8, 16> IpFrom;
std::array<uint8, 16> IpTo;
std::string CountryCode;
std::string CountryName;
};
class TC_SHARED_API IpLocationStore
{
public:
IpLocationStore();
IpLocationStore(IpLocationStore const&) = delete;
IpLocationStore(IpLocationStore&&) = delete;
IpLocationStore& operator=(IpLocationStore const&) = delete;
IpLocationStore& operator=(IpLocationStore&&) = delete;
~IpLocationStore();
static IpLocationStore* Instance();
void Load();
IpLocationRecord const* GetLocationRecord(std::string const& ipAddress) const;
private:
std::vector<IpLocationRecord> _ipLocationStore;
};
#define sIPLocation IpLocationStore::Instance()
#endif

View File

@@ -1,137 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H
#define TRINITYCORE_ASYNC_ACCEPTOR_H
#include "IoContext.h"
#include "IpAddress.h"
#include "Log.h"
#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ip/v6_only.hpp>
#include <atomic>
#include <functional>
#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
namespace Trinity::Net
{
template <typename Callable>
concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>;
class AsyncAcceptor
{
public:
AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
_acceptor(ioContext), _endpoint(make_address(bindIp), port),
_socket(ioContext), _closed(false), _socketFactory([this] { return DefeaultSocketFactory(); })
{
}
template <AcceptCallback Callback>
void AsyncAccept(Callback&& acceptCallback)
{
auto [tmpSocket, tmpThreadIndex] = _socketFactory();
// TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures)
IoContextTcpSocket* socket = tmpSocket;
uint32 threadIndex = tmpThreadIndex;
_acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable
{
if (!error)
{
try
{
socket->non_blocking(true);
acceptCallback(std::move(*socket), threadIndex);
}
catch (boost::system::system_error const& err)
{
TC_LOG_INFO("network", "Failed to initialize client's socket {}", err.what());
}
}
if (!_closed)
this->AsyncAccept(std::move(acceptCallback));
});
}
bool Bind()
{
boost::system::error_code errorCode;
_acceptor.open(_endpoint.protocol(), errorCode);
if (errorCode)
{
TC_LOG_INFO("network", "Failed to open acceptor {}", errorCode.message());
return false;
}
#if TRINITY_PLATFORM != TRINITY_PLATFORM_WINDOWS
_acceptor.set_option(boost::asio::ip::tcp::acceptor::reuse_address(true), errorCode);
if (errorCode)
{
TC_LOG_INFO("network", "Failed to set reuse_address option on acceptor {}", errorCode.message());
return false;
}
#endif
// v6_only is enabled on some *BSD distributions by default
// we want to allow both v4 and v6 connections to the same listener
if (_endpoint.protocol() == boost::asio::ip::tcp::v6())
_acceptor.set_option(boost::asio::ip::v6_only(false));
_acceptor.bind(_endpoint, errorCode);
if (errorCode)
{
TC_LOG_INFO("network", "Could not bind to {}:{} {}", _endpoint.address().to_string(), _endpoint.port(), errorCode.message());
return false;
}
_acceptor.listen(TRINITY_MAX_LISTEN_CONNECTIONS, errorCode);
if (errorCode)
{
TC_LOG_INFO("network", "Failed to start listening on {}:{} {}", _endpoint.address().to_string(), _endpoint.port(), errorCode.message());
return false;
}
return true;
}
void Close()
{
if (_closed.exchange(true))
return;
boost::system::error_code err;
_acceptor.close(err);
}
void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
private:
std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
boost::asio::ip::tcp::endpoint _endpoint;
IoContextTcpSocket _socket;
std::atomic<bool> _closed;
std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
};
}
#endif // TRINITYCORE_ASYNC_ACCEPTOR_H

View File

@@ -47,7 +47,7 @@ struct IpBanCheckConnectionInitializer final : SocketConnectionInitializer
if (IpBanCheckHelpers::IsBanned(result))
{
TC_LOG_ERROR("network", "IpBanCheckConnectionInitializer: IP {} is banned.", socket->GetRemoteIpAddress().to_string());
socket->DelayedCloseSocket();
socket->CloseSocket();
return;
}

View File

@@ -1,51 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
#include <memory>
#include <span>
namespace Trinity::Net
{
struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer>
{
SocketConnectionInitializer() = default;
SocketConnectionInitializer(SocketConnectionInitializer const&) = delete;
SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default;
SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete;
SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default;
virtual ~SocketConnectionInitializer() = default;
virtual void Start() = 0;
std::shared_ptr<SocketConnectionInitializer> next;
static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers)
{
for (std::size_t i = initializers.size(); i > 1; --i)
initializers[i - 2]->next.swap(initializers[i - 1]);
return initializers[0];
}
};
}
#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H

View File

@@ -1,146 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "BaseHttpSocket.h"
#include <boost/asio/buffers_iterator.hpp>
#include <boost/beast/http/serializer.hpp>
namespace Trinity::Net::Http
{
using RequestSerializer = boost::beast::http::request_serializer<ResponseBody>;
using ResponseSerializer = boost::beast::http::response_serializer<ResponseBody>;
bool AbstractSocket::ParseRequest(MessageBuffer& packet, RequestParser& parser)
{
if (!parser.is_done())
{
// need more data in the payload
boost::system::error_code ec = {};
std::size_t readDataSize = parser.put(boost::asio::const_buffer(packet.GetReadPointer(), packet.GetActiveSize()), ec);
packet.ReadCompleted(readDataSize);
}
return parser.is_done();
}
std::string AbstractSocket::SerializeRequest(Request const& request)
{
RequestSerializer serializer(request);
std::string buffer;
while (!serializer.is_done())
{
size_t totalBytes = 0;
boost::system::error_code ec = {};
serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code const&, ConstBufferSequence const& buffers)
{
size_t totalBytesInBuffers = boost::asio::buffer_size(buffers);
buffer.reserve(buffer.size() + totalBytes);
auto begin = boost::asio::buffers_begin(buffers);
auto end = boost::asio::buffers_end(buffers);
std::copy(begin, end, std::back_inserter(buffer));
totalBytes += totalBytesInBuffers;
});
serializer.consume(totalBytes);
}
return buffer;
}
MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response& response)
{
response.prepare_payload();
ResponseSerializer serializer(response);
bool (*serializerIsDone)(ResponseSerializer&);
if (request.method() != boost::beast::http::verb::head)
{
serializerIsDone = [](ResponseSerializer& s) { return s.is_done(); };
}
else
{
serializerIsDone = [](ResponseSerializer& s) { return s.is_header_done(); };
serializer.split(true);
}
MessageBuffer buffer;
while (!serializerIsDone(serializer))
{
serializer.limit(buffer.GetRemainingSpace());
size_t totalBytes = 0;
boost::system::error_code ec = {};
serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code& currentError, ConstBufferSequence const& buffers)
{
size_t totalBytesInBuffers = boost::asio::buffer_size(buffers);
if (totalBytesInBuffers > buffer.GetRemainingSpace())
{
currentError = boost::beast::http::error::need_more;
return;
}
auto begin = boost::asio::buffers_begin(buffers);
auto end = boost::asio::buffers_end(buffers);
std::copy(begin, end, buffer.GetWritePointer());
buffer.WriteCompleted(totalBytesInBuffers);
totalBytes += totalBytesInBuffers;
});
serializer.consume(totalBytes);
if (ec == boost::beast::http::error::need_more)
buffer.Resize(buffer.GetBufferSize() + 4096);
}
return buffer;
}
void AbstractSocket::LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const
{
if (Logger const* logger = sLog->GetEnabledLogger("server.http", LOG_LEVEL_DEBUG))
{
std::string clientInfo = GetClientInfo();
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_DEBUG, "{} Request {} {} done, status {}", clientInfo,
ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int());
if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
{
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo,
CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo,
CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
}
}
}
std::string AbstractSocket::GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state)
{
std::string info = StringFormat("[{}:{}", address.to_string(), port);
if (state)
{
info.append(", Session Id: ");
info.append(boost::uuids::to_string(state->Id));
}
info += ']';
return info;
}
}

View File

@@ -1,245 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_BASE_HTTP_SOCKET_H
#define TRINITYCORE_BASE_HTTP_SOCKET_H
#include "AsyncCallbackProcessor.h"
#include "DatabaseEnvFwd.h"
#include "HttpCommon.h"
#include "HttpSessionState.h"
#include "Optional.h"
#include "QueryCallback.h"
#include "Socket.h"
#include "SocketConnectionInitializer.h"
#include <boost/beast/core/basic_stream.hpp>
#include <boost/beast/http/parser.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace Trinity::Net::Http
{
using IoContextHttpSocket = boost::beast::basic_stream<boost::asio::ip::tcp, boost::asio::io_context::executor_type, boost::beast::unlimited_rate_policy>;
namespace Impl
{
class BoostBeastSocketWrapper : public IoContextHttpSocket
{
public:
using IoContextHttpSocket::basic_stream;
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
socket().shutdown(what, shutdownError);
}
void close(boost::system::error_code& /*error*/)
{
IoContextHttpSocket::close();
}
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
{
socket().async_wait(type, std::forward<WaitHandlerType>(handler));
}
IoContextTcpSocket::endpoint_type remote_endpoint() const
{
return socket().remote_endpoint();
}
};
}
using RequestParser = boost::beast::http::request_parser<RequestBody>;
class TC_SHARED_API AbstractSocket
{
public:
AbstractSocket() = default;
AbstractSocket(AbstractSocket const& other) = default;
AbstractSocket(AbstractSocket&& other) = default;
AbstractSocket& operator=(AbstractSocket const& other) = default;
AbstractSocket& operator=(AbstractSocket&& other) = default;
virtual ~AbstractSocket() = default;
static bool ParseRequest(MessageBuffer& packet, RequestParser& parser);
static std::string SerializeRequest(Request const& request);
static MessageBuffer SerializeResponse(Request const& request, Response& response);
virtual void SendResponse(RequestContext& context) = 0;
void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const;
virtual void QueueQuery(QueryCallback&& queryCallback) = 0;
virtual std::string GetClientInfo() const = 0;
static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state);
virtual SessionState* GetSessionState() const = 0;
Optional<boost::uuids::uuid> GetSessionId() const
{
if (SessionState* state = this->GetSessionState())
return state->Id;
return {};
}
virtual void Start() = 0;
virtual bool Update() = 0;
virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0;
virtual bool IsOpen() const = 0;
virtual void CloseSocket() = 0;
};
template <typename SocketImpl>
struct HttpConnectionInitializer final : SocketConnectionInitializer
{
explicit HttpConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
void Start() override
{
_socket->ResetHttpParser();
if (this->next)
this->next->Start();
}
private:
SocketImpl* _socket;
};
template<typename Stream>
class BaseSocket : public Trinity::Net::Socket<Stream>, public AbstractSocket
{
using Base = Trinity::Net::Socket<Stream>;
public:
using Base::Base;
BaseSocket(BaseSocket const& other) = delete;
BaseSocket(BaseSocket&& other) = delete;
BaseSocket& operator=(BaseSocket const& other) = delete;
BaseSocket& operator=(BaseSocket&& other) = delete;
~BaseSocket() = default;
SocketReadCallbackResult ReadHandler() final
{
MessageBuffer& packet = this->GetReadBuffer();
while (packet.GetActiveSize() > 0)
{
if (!ParseRequest(packet, *_httpParser))
{
// Couldn't receive the whole data this time.
break;
}
if (!HandleMessage(_httpParser->get()))
{
this->CloseSocket();
return SocketReadCallbackResult::Stop;
}
this->ResetHttpParser();
}
return SocketReadCallbackResult::KeepReading;
}
bool HandleMessage(Request& request)
{
RequestContext context { .request = std::move(request) };
if (!_state)
_state = this->ObtainSessionState(context);
RequestHandlerResult status = this->RequestHandler(context);
if (status != RequestHandlerResult::Async)
this->SendResponse(context);
return status != RequestHandlerResult::Error;
}
virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0;
void SendResponse(RequestContext& context) final
{
MessageBuffer buffer = SerializeResponse(context.request, context.response);
this->LogRequestAndResponse(context, buffer);
this->QueuePacket(std::move(buffer));
if (!context.response.keep_alive())
this->DelayedCloseSocket();
}
void QueueQuery(QueryCallback&& queryCallback) final
{
this->_queryProcessor.AddCallback(std::move(queryCallback));
}
void Start() override { return this->Base::Start(); }
bool Update() override
{
if (!this->Base::Update())
return false;
this->_queryProcessor.ProcessReadyCallbacks();
return true;
}
boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); }
bool IsOpen() const final { return this->Base::IsOpen(); }
void CloseSocket() final { return this->Base::CloseSocket(); }
std::string GetClientInfo() const override
{
return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get());
}
SessionState* GetSessionState() const override { return _state.get(); }
void ResetHttpParser()
{
this->_httpParser.reset();
this->_httpParser.emplace();
this->_httpParser->eager(true);
}
protected:
virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0;
QueryCallbackProcessor _queryProcessor;
Optional<RequestParser> _httpParser;
std::shared_ptr<SessionState> _state;
};
}
#endif // TRINITYCORE_BASE_HTTP_SOCKET_H

View File

@@ -1,55 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_COMMON_H
#define TRINITYCORE_HTTP_COMMON_H
#include "Define.h"
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
namespace Trinity::Net::Http
{
using RequestBody = boost::beast::http::string_body;
using ResponseBody = boost::beast::http::string_body;
using Request = boost::beast::http::request<RequestBody>;
using Response = boost::beast::http::response<ResponseBody>;
struct RequestContext
{
Request request;
Response response;
struct RequestHandler const* handler = nullptr;
};
TC_SHARED_API bool CanLogRequestContent(RequestContext const& context);
TC_SHARED_API bool CanLogResponseContent(RequestContext const& context);
inline std::string_view ToStdStringView(boost::beast::string_view bsw)
{
return { bsw.data(), bsw.size() };
}
enum class RequestHandlerResult
{
Handled,
Error,
Async,
};
}
#endif // TRINITYCORE_HTTP_COMMON_H

View File

@@ -1,267 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "HttpService.h"
#include "BaseHttpSocket.h"
#include "CryptoRandom.h"
#include "Timezone.h"
#include <boost/beast/version.hpp>
#include <boost/uuid/string_generator.hpp>
#include <fmt/chrono.h>
namespace Trinity::Net::Http
{
bool CanLogRequestContent(RequestContext const& context)
{
return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogRequestContent);
}
bool CanLogResponseContent(RequestContext const& context)
{
return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogResponseContent);
}
RequestHandlerResult DispatcherService::HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context)
{
TC_LOG_DEBUG(_logger, "{} Starting request {} {}", session->GetClientInfo(),
ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()));
std::string_view path = [&]
{
std::string_view path = ToStdStringView(context.request.target());
size_t queryIndex = path.find('?');
if (queryIndex != std::string_view::npos)
path = path.substr(0, queryIndex);
return path;
}();
context.handler = [&]() -> HttpMethodHandlerMap::mapped_type const*
{
switch (context.request.method())
{
case boost::beast::http::verb::get:
case boost::beast::http::verb::head:
{
auto itr = _getHandlers.find(path);
return itr != _getHandlers.end() ? &itr->second : nullptr;
}
case boost::beast::http::verb::post:
{
auto itr = _postHandlers.find(path);
return itr != _postHandlers.end() ? &itr->second : nullptr;
}
default:
break;
}
return nullptr;
}();
SystemTimePoint responseDate = SystemTimePoint::clock::now();
context.response.set(boost::beast::http::field::date, StringFormat("{:%a, %d %b %Y %T GMT}", responseDate - Timezone::GetSystemZoneOffsetAt(responseDate)));
context.response.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING);
context.response.keep_alive(context.request.keep_alive());
if (!context.handler)
return HandlePathNotFound(std::move(session), context);
return context.handler->Func(std::move(session), context);
}
RequestHandlerResult DispatcherService::HandleBadRequest(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
{
context.response.result(boost::beast::http::status::bad_request);
return RequestHandlerResult::Handled;
}
RequestHandlerResult DispatcherService::HandleUnauthorized(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
{
context.response.result(boost::beast::http::status::unauthorized);
return RequestHandlerResult::Handled;
}
RequestHandlerResult DispatcherService::HandlePathNotFound(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
{
context.response.result(boost::beast::http::status::not_found);
return RequestHandlerResult::Handled;
}
void DispatcherService::RegisterHandler(boost::beast::http::verb method, std::string_view path,
std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler,
RequestHandlerFlag flags)
{
HttpMethodHandlerMap& handlerMap = [&]() -> HttpMethodHandlerMap&
{
switch (method)
{
case boost::beast::http::verb::get:
return _getHandlers;
case boost::beast::http::verb::post:
return _postHandlers;
default:
{
std::string_view methodString = ToStdStringView(boost::beast::http::to_string(method));
ABORT_MSG("Tried to register a handler for unsupported HTTP method " STRING_VIEW_FMT, STRING_VIEW_FMT_ARG(methodString));
}
}
}();
handlerMap[std::string(path)] = { .Func = std::move(handler), .Flags = flags };
TC_LOG_INFO(_logger, "Registered new handler for {} {}", ToStdStringView(boost::beast::http::to_string(method)), path);
}
void SessionService::InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address)
{
state->RemoteAddress = address;
// Generate session id
{
std::unique_lock lock{ _sessionsMutex };
while (state->Id.is_nil() || _sessions.contains(state->Id))
std::copy_n(Trinity::Crypto::GetRandomBytes<16>().begin(), 16, state->Id.begin());
TC_LOG_DEBUG(_logger, "Client at {} created new session {}", address.to_string(), boost::uuids::to_string(state->Id));
_sessions[state->Id] = std::move(state);
}
}
void SessionService::Start(Asio::IoContext& ioContext)
{
_inactiveSessionsKillTimer = std::make_unique<Asio::DeadlineTimer>(ioContext);
_inactiveSessionsKillTimer->expires_after(1min);
_inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err)
{
if (err)
return;
KillInactiveSessions();
});
}
void SessionService::Stop()
{
_inactiveSessionsKillTimer = nullptr;
{
std::unique_lock lock{ _sessionsMutex };
_sessions.clear();
}
{
std::unique_lock lock{ _inactiveSessionsMutex };
_inactiveSessions.clear();
}
}
std::shared_ptr<SessionState> SessionService::FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address)
{
std::shared_ptr<SessionState> state;
{
std::shared_lock lock{ _sessionsMutex };
auto itr = _sessions.find(boost::uuids::string_generator()(id.begin(), id.end()));
if (itr == _sessions.end())
{
TC_LOG_DEBUG(_logger, "Client at {} attempted to use a session {} that was expired", address.to_string(), id);
return nullptr; // no session
}
state = itr->second;
}
if (state->RemoteAddress != address)
{
TC_LOG_ERROR(_logger, "Client at {} attempted to use a session {} that was last accessed from {}, denied access",
address.to_string(), id, state->RemoteAddress.to_string());
return nullptr;
}
{
std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex };
_inactiveSessions.erase(state->Id);
}
return state;
}
void SessionService::MarkSessionInactive(boost::uuids::uuid const& id)
{
bool wasActive = true;
{
std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex };
wasActive = _inactiveSessions.insert(id).second;
}
if (wasActive)
{
std::shared_lock lock{ _sessionsMutex };
auto itr = _sessions.find(id);
if (itr != _sessions.end())
{
itr->second->InactiveTimestamp = TimePoint::clock::now() + Minutes(5);
TC_LOG_TRACE(_logger, "Session {} marked as inactive", boost::uuids::to_string(id));
}
}
}
void SessionService::KillInactiveSessions()
{
std::set<boost::uuids::uuid> inactiveSessions;
{
std::unique_lock lock{ _inactiveSessionsMutex };
std::swap(_inactiveSessions, inactiveSessions);
}
{
TimePoint now = TimePoint::clock::now();
std::size_t inactiveSessionsCount = inactiveSessions.size();
std::unique_lock lock{ _sessionsMutex };
for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); )
{
auto sessionItr = _sessions.find(*itr);
if (sessionItr == _sessions.end() || sessionItr->second->InactiveTimestamp < now)
{
_sessions.erase(sessionItr);
itr = inactiveSessions.erase(itr);
}
else
++itr;
}
TC_LOG_DEBUG(_logger, "Killed {} inactive sessions", inactiveSessionsCount - inactiveSessions.size());
}
{
// restore sessions not killed to inactive queue
std::unique_lock lock{ _inactiveSessionsMutex };
for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); )
{
auto node = inactiveSessions.extract(itr++);
_inactiveSessions.insert(std::move(node));
}
}
_inactiveSessionsKillTimer->expires_after(1min);
_inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err)
{
if (err)
return;
KillInactiveSessions();
});
}
}

View File

@@ -1,188 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SERVICE_H
#define TRINITYCORE_HTTP_SERVICE_H
#include "AsioHacksFwd.h"
#include "Concepts.h"
#include "Define.h"
#include "EnumFlag.h"
#include "HttpCommon.h"
#include "HttpSessionState.h"
#include "Optional.h"
#include "SocketMgr.h"
#include <boost/uuid/uuid.hpp>
#include <functional>
#include <map>
#include <set>
#include <shared_mutex>
namespace Trinity::Net::Http
{
class AbstractSocket;
enum class RequestHandlerFlag
{
None = 0x0,
DoNotLogRequestContent = 0x1,
DoNotLogResponseContent = 0x2,
};
DEFINE_ENUM_FLAG(RequestHandlerFlag);
struct RequestHandler
{
std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> Func;
EnumFlag<RequestHandlerFlag> Flags = RequestHandlerFlag::None;
};
class TC_SHARED_API DispatcherService
{
public:
explicit DispatcherService(std::string_view loggerSuffix) : _logger("server.http.dispatcher.")
{
_logger.append(loggerSuffix);
}
RequestHandlerResult HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context);
static RequestHandlerResult HandleBadRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context);
static RequestHandlerResult HandleUnauthorized(std::shared_ptr<AbstractSocket> session, RequestContext& context);
static RequestHandlerResult HandlePathNotFound(std::shared_ptr<AbstractSocket> session, RequestContext& context);
protected:
void RegisterHandler(boost::beast::http::verb method, std::string_view path,
std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler,
RequestHandlerFlag flags = RequestHandlerFlag::None);
private:
using HttpMethodHandlerMap = std::map<std::string, RequestHandler, std::less<>>;
HttpMethodHandlerMap _getHandlers;
HttpMethodHandlerMap _postHandlers;
std::string _logger;
};
class TC_SHARED_API SessionService
{
public:
explicit SessionService(std::string_view loggerSuffix) : _logger("server.http.session.")
{
_logger.append(loggerSuffix);
}
void Start(Asio::IoContext& ioContext);
void Stop();
std::shared_ptr<SessionState> FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address);
void MarkSessionInactive(boost::uuids::uuid const& id);
protected:
void InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address);
void KillInactiveSessions();
private:
std::shared_mutex _sessionsMutex;
std::map<boost::uuids::uuid, std::shared_ptr<SessionState>> _sessions;
std::mutex _inactiveSessionsMutex;
std::set<boost::uuids::uuid> _inactiveSessions;
std::unique_ptr<Asio::DeadlineTimer> _inactiveSessionsKillTimer;
std::string _logger;
};
template<typename Callable, typename SessionImpl>
concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>;
template<typename SessionImpl>
class HttpService : public SocketMgr<SessionImpl>, public DispatcherService, public SessionService
{
public:
HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.")
{
_logger.append(loggerSuffix);
}
bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override
{
if (!SocketMgr<SessionImpl>::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
SessionService::Start(ioContext);
return true;
}
void StopNetwork() override
{
SessionService::Stop();
SocketMgr<SessionImpl>::StopNetwork();
}
// http handling
using DispatcherService::RegisterHandler;
template<HttpRequestHandler<SessionImpl> Callable>
void RegisterHandler(boost::beast::http::verb method, std::string_view path, Callable handler, RequestHandlerFlag flags = RequestHandlerFlag::None)
{
this->DispatcherService::RegisterHandler(method, path, [handler = std::move(handler)](std::shared_ptr<AbstractSocket> session, RequestContext& context) -> RequestHandlerResult
{
return handler(std::static_pointer_cast<SessionImpl>(std::move(session)), context);
}, flags);
}
// session tracking
virtual std::shared_ptr<SessionState> CreateNewSessionState(boost::asio::ip::address const& address)
{
std::shared_ptr<SessionState> state = std::make_shared<SessionState>();
InitAndStoreSessionState(state, address);
return state;
}
protected:
class Thread : public NetworkThread<SessionImpl>
{
protected:
void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
{
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
_service->MarkSessionInactive(*id);
}
private:
friend HttpService;
SessionService* _service;
};
NetworkThread<SessionImpl>* CreateThreads() const override
{
Thread* threads = new Thread[this->GetNetworkThreadCount()];
for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i)
threads[i]._service = const_cast<HttpService*>(this);
return threads;
}
Asio::IoContext* _ioContext;
std::string _logger;
};
}
#endif // TRINITYCORE_HTTP_SERVICE_H

View File

@@ -1,35 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SESSION_STATE_H
#define TRINITYCORE_HTTP_SESSION_STATE_H
#include "Duration.h"
#include <boost/asio/ip/address.hpp>
#include <boost/uuid/uuid.hpp>
namespace Trinity::Net::Http
{
struct SessionState
{
boost::uuids::uuid Id = { };
boost::asio::ip::address RemoteAddress;
TimePoint InactiveTimestamp = TimePoint::max();
};
}
#endif // TRINITYCORE_HTTP_SESSION_STATE_H

View File

@@ -1,53 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SOCKET_H
#define TRINITYCORE_HTTP_SOCKET_H
#include "BaseHttpSocket.h"
#include <array>
namespace Trinity::Net::Http
{
class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper>
{
using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>;
public:
using SocketBase::SocketBase;
Socket(Socket const& other) = delete;
Socket(Socket&& other) = delete;
Socket& operator=(Socket const& other) = delete;
Socket& operator=(Socket&& other) = delete;
~Socket() = default;
void Start() override
{
std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers =
{ {
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
};
}
#endif // TRINITYCORE_HTTP_SOCKET_H

View File

@@ -1,58 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SSL_SOCKET_H
#define TRINITYCORE_HTTP_SSL_SOCKET_H
#include "BaseHttpSocket.h"
#include "SslStream.h"
namespace Trinity::Net::Http
{
class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>
{
using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>;
public:
explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext)
: SocketBase(std::move(socket), sslContext) { }
explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext)
: SocketBase(context, sslContext) { }
SslSocket(SslSocket const& other) = delete;
SslSocket(SslSocket&& other) = delete;
SslSocket& operator=(SslSocket const& other) = delete;
SslSocket& operator=(SslSocket&& other) = delete;
~SslSocket() = default;
void Start() override
{
std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers =
{ {
std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this),
std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
} };
SocketConnectionInitializer::SetupChain(initializers)->Start();
}
};
}
#endif // TRINITYCORE_HTTP_SSL_SOCKET_H

View File

@@ -1,179 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_NETWORK_THREAD_H
#define TRINITYCORE_NETWORK_THREAD_H
#include "Containers.h"
#include "DeadlineTimer.h"
#include "Define.h"
#include "Errors.h"
#include "IoContext.h"
#include "Log.h"
#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
#include <mutex>
#include <thread>
namespace Trinity::Net
{
template<class SocketType>
class NetworkThread
{
public:
NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1),
_acceptSocket(_ioContext), _updateTimer(_ioContext)
{
}
NetworkThread(NetworkThread const&) = delete;
NetworkThread(NetworkThread&&) = delete;
NetworkThread& operator=(NetworkThread const&) = delete;
NetworkThread& operator=(NetworkThread&&) = delete;
virtual ~NetworkThread()
{
Stop();
if (_thread)
Wait();
}
void Stop()
{
_stopped = true;
_ioContext.stop();
}
bool Start()
{
if (_thread)
return false;
_thread = std::make_unique<std::thread>(&NetworkThread::Run, this);
return true;
}
void Wait()
{
ASSERT(_thread);
_thread->join();
_thread = nullptr;
}
int32 GetConnectionCount() const
{
return _connections;
}
void AddSocket(std::shared_ptr<SocketType> sock)
{
std::lock_guard<std::mutex> lock(_newSocketsLock);
++_connections;
SocketAdded(_newSockets.emplace_back(std::move(sock)));
}
Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
protected:
virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { }
void AddNewSockets()
{
std::lock_guard<std::mutex> lock(_newSocketsLock);
if (_newSockets.empty())
return;
for (std::shared_ptr<SocketType>& sock : _newSockets)
{
if (!sock->IsOpen())
{
SocketRemoved(sock);
--_connections;
}
else
_sockets.emplace_back(std::move(sock));
}
_newSockets.clear();
}
void Run()
{
TC_LOG_DEBUG("misc", "Network Thread Starting");
_updateTimer.expires_after(1ms);
_updateTimer.async_wait([this](boost::system::error_code const&) { Update(); });
_ioContext.run();
TC_LOG_DEBUG("misc", "Network Thread exits");
_newSockets.clear();
_sockets.clear();
}
void Update()
{
if (_stopped)
return;
_updateTimer.expires_after(1ms);
_updateTimer.async_wait([this](boost::system::error_code const&) { Update(); });
AddNewSockets();
Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock)
{
if (!sock->Update())
{
if (sock->IsOpen())
sock->CloseSocket();
this->SocketRemoved(sock);
--this->_connections;
return true;
}
return false;
});
}
private:
typedef std::vector<std::shared_ptr<SocketType>> SocketContainer;
std::atomic<int32> _connections;
std::atomic<bool> _stopped;
std::unique_ptr<std::thread> _thread;
SocketContainer _sockets;
std::mutex _newSocketsLock;
SocketContainer _newSockets;
Trinity::Asio::IoContext _ioContext;
Trinity::Net::IoContextTcpSocket _acceptSocket;
Trinity::Asio::DeadlineTimer _updateTimer;
};
}
#endif // TRINITYCORE_NETWORK_THREAD_H

View File

@@ -1,362 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SOCKET_H
#define TRINITYCORE_SOCKET_H
#include "Concepts.h"
#include "Log.h"
#include "MessageBuffer.h"
#include "SocketConnectionInitializer.h"
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
#include <queue>
#include <type_traits>
#define READ_BLOCK_SIZE 4096
#ifdef BOOST_ASIO_HAS_IOCP
#define TC_SOCKET_USE_IOCP
#endif
namespace Trinity::Net
{
using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
enum class SocketReadCallbackResult
{
KeepReading,
Stop
};
template <typename Callable>
concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
template <typename SocketType>
struct InvokeReadHandlerCallback
{
SocketReadCallbackResult operator()() const
{
return this->Socket->ReadHandler();
}
SocketType* Socket;
};
template <typename SocketType>
struct ReadConnectionInitializer final : SocketConnectionInitializer
{
explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { }
void Start() override
{
ReadCallback.Socket->AsyncRead(std::move(ReadCallback));
if (this->next)
this->next->Start();
}
InvokeReadHandlerCallback<SocketType> ReadCallback;
};
/**
@class Socket
Base async socket implementation
@tparam Stream stream type used for operations on socket
Stream must implement the following methods:
void close(boost::system::error_code& error);
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler);
template<typename ConstBufferSequence, typename WriteHandlerType>
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler);
template<typename ConstBufferSequence>
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error);
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler);
template<typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& error);
tcp::socket::endpoint_type remote_endpoint() const;
*/
template<class Stream = IoContextTcpSocket>
class Socket : public std::enable_shared_from_this<Socket<Stream>>
{
public:
template<typename... Args>
explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
_remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
{
}
template<typename... Args>
explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed)
{
}
Socket(Socket const& other) = delete;
Socket(Socket&& other) = delete;
Socket& operator=(Socket const& other) = delete;
Socket& operator=(Socket&& other) = delete;
virtual ~Socket()
{
_openState = OpenState_Closed;
boost::system::error_code error;
_socket.close(error);
}
virtual void Start() { }
virtual bool Update()
{
if (_openState == OpenState_Closed)
return false;
#ifndef TC_SOCKET_USE_IOCP
if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open))
return true;
for (; HandleQueue();)
;
#endif
return true;
}
boost::asio::ip::address const& GetRemoteIpAddress() const
{
return _remoteAddress;
}
uint16 GetRemotePort() const
{
return _remotePort;
}
template <SocketReadCallback Callback>
void AsyncRead(Callback&& callback)
{
if (!IsOpen())
return;
_readBuffer.Normalize();
_readBuffer.EnsureFreeSpace();
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
[self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
{
if (self->ReadHandlerInternal(error, transferredBytes))
if (callback() == SocketReadCallbackResult::KeepReading)
self->AsyncRead(std::forward<Callback>(callback));
});
}
void QueuePacket(MessageBuffer&& buffer)
{
_writeQueue.push(std::move(buffer));
#ifdef TC_SOCKET_USE_IOCP
AsyncProcessQueue();
#endif
}
bool IsOpen() const { return _openState == OpenState_Open; }
void CloseSocket()
{
if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
return;
boost::system::error_code shutdownError;
_socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
if (shutdownError)
TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(),
shutdownError.value(), shutdownError.message());
this->OnClose();
}
/// Marks the socket for closing after write buffer becomes empty
void DelayedCloseSocket()
{
if (_openState.fetch_or(OpenState_Closing) != 0)
return;
if (_writeQueue.empty())
CloseSocket();
}
MessageBuffer& GetReadBuffer() { return _readBuffer; }
Stream& underlying_stream()
{
return _socket;
}
protected:
virtual void OnClose() { }
virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; }
bool AsyncProcessQueue()
{
if (_isWritingAsync)
return false;
_isWritingAsync = true;
#ifdef TC_SOCKET_USE_IOCP
MessageBuffer& buffer = _writeQueue.front();
_socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()),
[self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes)
{
self->WriteHandler(error, transferedBytes);
});
#else
_socket.async_wait(boost::asio::socket_base::wait_type::wait_write,
[self = this->shared_from_this()](boost::system::error_code const& error)
{
self->WriteHandlerWrapper(error);
});
#endif
return false;
}
void SetNoDelay(bool enable)
{
boost::system::error_code err;
_socket.set_option(boost::asio::ip::tcp::no_delay(enable), err);
if (err)
TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for {} - {} ({})",
GetRemoteIpAddress().to_string(), err.value(), err.message());
}
private:
bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
{
if (error)
{
CloseSocket();
return false;
}
_readBuffer.WriteCompleted(transferredBytes);
return IsOpen();
}
#ifdef TC_SOCKET_USE_IOCP
void WriteHandler(boost::system::error_code const& error, std::size_t transferedBytes)
{
if (!error)
{
_isWritingAsync = false;
_writeQueue.front().ReadCompleted(transferedBytes);
if (!_writeQueue.front().GetActiveSize())
_writeQueue.pop();
if (!_writeQueue.empty())
AsyncProcessQueue();
else if (_openState == OpenState_Closing)
CloseSocket();
}
else
CloseSocket();
}
#else
void WriteHandlerWrapper(boost::system::error_code const& /*error*/)
{
_isWritingAsync = false;
HandleQueue();
}
bool HandleQueue()
{
if (_writeQueue.empty())
return false;
MessageBuffer& queuedMessage = _writeQueue.front();
std::size_t bytesToSend = queuedMessage.GetActiveSize();
boost::system::error_code error;
std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);
if (error)
{
if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
return AsyncProcessQueue();
_writeQueue.pop();
if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return false;
}
else if (bytesSent == 0)
{
_writeQueue.pop();
if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return false;
}
else if (bytesSent < bytesToSend) // now n > 0
{
queuedMessage.ReadCompleted(bytesSent);
return AsyncProcessQueue();
}
_writeQueue.pop();
if (_openState == OpenState_Closing && _writeQueue.empty())
CloseSocket();
return !_writeQueue.empty();
}
#endif
Stream _socket;
boost::asio::ip::address _remoteAddress;
uint16 _remotePort = 0;
MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE);
std::queue<MessageBuffer> _writeQueue;
// Socket open state "enum" (not enum to enable integral std::atomic api)
static constexpr uint8 OpenState_Open = 0x0;
static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data
static constexpr uint8 OpenState_Closed = 0x2;
std::atomic<uint8> _openState;
bool _isWritingAsync = false;
};
}
#endif // TRINITYCORE_SOCKET_H

View File

@@ -1,146 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SOCKET_MGR_H
#define TRINITYCORE_SOCKET_MGR_H
#include "AsyncAcceptor.h"
#include "Errors.h"
#include "NetworkThread.h"
#include "Socket.h"
#include <boost/asio/ip/tcp.hpp>
#include <memory>
namespace Trinity::Net
{
template<class SocketType>
class SocketMgr
{
public:
SocketMgr(SocketMgr const&) = delete;
SocketMgr(SocketMgr&&) = delete;
SocketMgr& operator=(SocketMgr const&) = delete;
SocketMgr& operator=(SocketMgr&&) = delete;
virtual ~SocketMgr()
{
ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
}
virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
ASSERT(threadCount > 0);
std::unique_ptr<AsyncAcceptor> acceptor = nullptr;
try
{
acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port);
}
catch (boost::system::system_error const& err)
{
TC_LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork ({}:{}): {}", bindIp, port, err.what());
return false;
}
if (!acceptor->Bind())
{
TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
return false;
}
_acceptor = std::move(acceptor);
_threadCount = threadCount;
_threads.reset(CreateThreads());
ASSERT(_threads);
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Start();
_acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
return true;
}
virtual void StopNetwork()
{
_acceptor->Close();
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Stop();
Wait();
_acceptor = nullptr;
_threads = nullptr;
_threadCount = 0;
}
void Wait()
{
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Wait();
}
virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
{
try
{
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
newSocket->Start();
_threads[threadIndex].AddSocket(newSocket);
}
catch (boost::system::system_error const& err)
{
TC_LOG_WARN("network", "Failed to retrieve client's remote address {}", err.what());
}
}
int32 GetNetworkThreadCount() const { return _threadCount; }
uint32 SelectThreadWithMinConnections() const
{
uint32 min = 0;
for (int32 i = 1; i < _threadCount; ++i)
if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
min = i;
return min;
}
std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
{
uint32 threadIndex = SelectThreadWithMinConnections();
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
}
protected:
SocketMgr() : _threadCount(0)
{
}
virtual NetworkThread<SocketType>* CreateThreads() const = 0;
std::unique_ptr<AsyncAcceptor> _acceptor;
std::unique_ptr<NetworkThread<SocketType>[]> _threads;
int32 _threadCount;
};
}
#endif // TRINITYCORE_SOCKET_MGR_H

View File

@@ -1,131 +0,0 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SSL_STREAM_H
#define TRINITYCORE_SSL_STREAM_H
#include "SocketConnectionInitializer.h"
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/ssl/stream.hpp>
#include <boost/system/error_code.hpp>
namespace Trinity::Net
{
template <typename SocketImpl>
struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer
{
explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
void Start() override
{
_socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
[socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error)
{
std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
if (!socket)
return;
if (error)
{
TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message());
socket->CloseSocket();
return;
}
if (self->next)
self->next->Start();
});
}
private:
SocketImpl* _socket;
};
template<class WrappedStream = IoContextTcpSocket>
class SslStream
{
public:
explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
{
_sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
}
explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext)
{
_sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
}
// adapting tcp::socket api
void close(boost::system::error_code& error)
{
_sslSocket.next_layer().close(error);
}
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
_sslSocket.shutdown(shutdownError);
_sslSocket.next_layer().shutdown(what, shutdownError);
}
template<typename MutableBufferSequence, typename ReadHandlerType>
void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
{
_sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler));
}
template<typename ConstBufferSequence, typename WriteHandlerType>
void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
{
_sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler));
}
template<typename ConstBufferSequence>
std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error)
{
return _sslSocket.write_some(buffers, error);
}
template<typename WaitHandlerType>
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
{
_sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler));
}
template<typename SettableSocketOption>
void set_option(SettableSocketOption const& option, boost::system::error_code& error)
{
_sslSocket.next_layer().set_option(option, error);
}
IoContextTcpSocket::endpoint_type remote_endpoint() const
{
return _sslSocket.next_layer().remote_endpoint();
}
// ssl api
template<typename HandshakeHandlerType>
void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
{
_sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler));
}
protected:
boost::asio::ssl::stream<WrappedStream> _sslSocket;
};
}
#endif // TRINITYCORE_SSL_STREAM_H

View File

@@ -68,7 +68,7 @@ void RealmList::Initialize(Trinity::Asio::IoContext& ioContext, uint32 updateInt
{
_updateInterval = updateInterval;
_updateTimer = std::make_unique<Trinity::Asio::DeadlineTimer>(ioContext);
_resolver = std::make_unique<Trinity::Asio::Resolver>(ioContext);
_resolver = std::make_unique<Trinity::Net::Resolver>(ioContext);
ClientBuild::LoadBuildInfo();
// Get the content of the realmlist table in the database

View File

@@ -86,7 +86,7 @@ private:
std::unordered_set<std::string> _subRegions;
uint32 _updateInterval;
std::unique_ptr<Trinity::Asio::DeadlineTimer> _updateTimer;
std::unique_ptr<Trinity::Asio::Resolver> _resolver;
std::unique_ptr<Trinity::Net::Resolver> _resolver;
Optional<Battlenet::RealmHandle> _currentRealmId;
};