diff options
author | Shauren <shauren.trinity@gmail.com> | 2025-04-09 21:02:31 +0200 |
---|---|---|
committer | Ovahlord <dreadkiller@gmx.de> | 2025-04-09 21:09:48 +0200 |
commit | 48c8c93ec4791002e473e4ea7af2bea9d693be0a (patch) | |
tree | 55c8c295698f9d2542ede02d0c237e582908749d /src/common | |
parent | 00482e96553ad578dc32591c1b207f769f1d4eb9 (diff) |
Core/Network: Move to separate project
(cherry picked from commit 71b681bbf0f5189cd87a6cea66ef51667223f54a)
Diffstat (limited to 'src/common')
28 files changed, 2229 insertions, 316 deletions
diff --git a/src/common/Asio/AsioHacksFwd.h b/src/common/Asio/AsioHacksFwd.h index a8f04b16d81..06f5e531c20 100644 --- a/src/common/Asio/AsioHacksFwd.h +++ b/src/common/Asio/AsioHacksFwd.h @@ -60,9 +60,13 @@ namespace Trinity { class DeadlineTimer; class IoContext; - class Resolver; class Strand; } + + namespace Net + { + class Resolver; + } } #endif // AsioHacksFwd_h__ diff --git a/src/common/Asio/Resolver.h b/src/common/Asio/Resolver.h deleted file mode 100644 index 84dedd21bfa..00000000000 --- a/src/common/Asio/Resolver.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information - * - * This program is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation; either version 2 of the License, or (at your - * option) any later version. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef Resolver_h__ -#define Resolver_h__ - -#include "IoContext.h" -#include "Optional.h" -#include <boost/asio/ip/tcp.hpp> -#include <algorithm> -#include <string_view> -#include <vector> - -namespace Trinity -{ - namespace Asio - { - /** - Hack to make it possible to forward declare resolver (one of its template arguments is a typedef to something super long and using nested classes) - */ - class Resolver - { - public: - explicit Resolver(IoContext& ioContext) : _impl(ioContext) { } - - Optional<boost::asio::ip::tcp::endpoint> Resolve(boost::asio::ip::tcp const& protocol, std::string_view host, std::string_view service) - { - boost::system::error_code ec; - boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching; - boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(protocol, host, service, flagsResolver, ec); - if (results.begin() == results.end() || ec) - return {}; - - return results.begin()->endpoint(); - } - - std::vector<boost::asio::ip::tcp::endpoint> ResolveAll(std::string_view host, std::string_view service) - { - boost::system::error_code ec; - boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching; - boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(host, service, flagsResolver, ec); - std::vector<boost::asio::ip::tcp::endpoint> result; - if (!ec) - std::ranges::transform(results, std::back_inserter(result), [](boost::asio::ip::tcp::resolver::results_type::value_type const& entry) { return entry.endpoint(); }); - - return result; - } - - private: - boost::asio::ip::tcp::resolver _impl; - }; - } -} - -#endif // Resolver_h__ diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index b202dc5f9a4..8b8ecc0f471 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -14,7 +14,8 @@ CollectSourceFiles( # Exclude ${CMAKE_CURRENT_SOURCE_DIR}/Debugging/Windows ${CMAKE_CURRENT_SOURCE_DIR}/Platform - ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders) + ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders + ${CMAKE_CURRENT_SOURCE_DIR}/network) if(WIN32) CollectSourceFiles( @@ -27,6 +28,8 @@ if(WIN32) WINDOWS_PLATFORM_SOURCES) list(APPEND PRIVATE_SOURCES ${WINDOWS_PLATFORM_SOURCES}) + unset(WINDOWS_DEBUGGING_SOURCES) + unset(WINDOWS_PLATFORM_SOURCES) endif() if(USE_COREPCH) @@ -43,7 +46,8 @@ CollectIncludeDirectories( ${CMAKE_CURRENT_SOURCE_DIR} PUBLIC_INCLUDES # Exclude - ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders) + ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders + ${CMAKE_CURRENT_SOURCE_DIR}/network) target_include_directories(common PUBLIC @@ -95,3 +99,9 @@ endif() if(USE_COREPCH) add_cxx_pch(common ${PRIVATE_PCH_HEADER}) endif() + +unset(PRIVATE_SOURCES) +unset(PRIVATE_PCH_HEADER) +unset(PUBLIC_INCLUDES) + +add_subdirectory(network) diff --git a/src/common/Define.h b/src/common/Define.h index e4a2333c66d..f918db84314 100644 --- a/src/common/Define.h +++ b/src/common/Define.h @@ -111,6 +111,12 @@ # define TC_DATABASE_API TC_API_IMPORT #endif +#ifdef TRINITY_API_EXPORT_NETWORK +# define TC_NETWORK_API TC_API_EXPORT +#else +# define TC_NETWORK_API TC_API_IMPORT +#endif + #ifdef TRINITY_API_EXPORT_SHARED # define TC_SHARED_API TC_API_EXPORT #else diff --git a/src/common/IPLocation/IPLocation.cpp b/src/common/IPLocation/IPLocation.cpp deleted file mode 100644 index 72383bd3554..00000000000 --- a/src/common/IPLocation/IPLocation.cpp +++ /dev/null @@ -1,150 +0,0 @@ -/* - * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information - * - * This program is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation; either version 2 of the License, or (at your - * option) any later version. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#include "IPLocation.h" -#include "BigNumber.h" -#include "Config.h" -#include "Errors.h" -#include "IpAddress.h" -#include "Log.h" -#include "Util.h" -#include <algorithm> -#include <fstream> - -IpLocationStore::IpLocationStore() = default; -IpLocationStore::~IpLocationStore() = default; - -void IpLocationStore::Load() -{ - _ipLocationStore.clear(); - TC_LOG_INFO("server.loading", "Loading IP Location Database..."); - - std::string databaseFilePath = sConfigMgr->GetStringDefault("IPLocationFile", ""); - if (databaseFilePath.empty()) - return; - - // Check if file exists - std::ifstream databaseFile(databaseFilePath); - if (!databaseFile) - { - TC_LOG_ERROR("server.loading", "IPLocation: No ip database file exists ({}).", databaseFilePath); - return; - } - - if (!databaseFile.is_open()) - { - TC_LOG_ERROR("server.loading", "IPLocation: Ip database file ({}) can not be opened.", databaseFilePath); - return; - } - - std::string ipFrom; - std::string ipTo; - std::string countryCode; - std::string countryName; - BigNumber bnParser; - BigNumber ipv4Max(0xFFFFFFFF); - BigNumber ipv6MappedMask(0xFFFF); - ipv6MappedMask <<= 32; - - auto parseStringToIPv6 = [&](std::string const& str) -> Optional<std::array<uint8, 16>> - { - bnParser.SetDecStr(str); - if (!bnParser.SetDecStr(str)) - return {}; - // convert ipv4 to ipv6 v4 mapped value - if (bnParser <= ipv4Max) - bnParser += ipv6MappedMask; - return bnParser.ToByteArray<16>(false); - }; - - while (databaseFile.good()) - { - // Read lines - if (!std::getline(databaseFile, ipFrom, ',')) - break; - if (!std::getline(databaseFile, ipTo, ',')) - break; - if (!std::getline(databaseFile, countryCode, ',')) - break; - if (!std::getline(databaseFile, countryName, '\n')) - break; - - // Remove new lines and return - std::erase_if(countryName, [](char c) { return c == '\r' || c == '\n'; }); - - // Remove quotation marks - std::erase(ipFrom, '"'); - std::erase(ipTo, '"'); - std::erase(countryCode, '"'); - std::erase(countryName, '"'); - - if (countryCode == "-") - continue; - - // Convert country code to lowercase - strToLower(countryCode); - - Optional<std::array<uint8, 16>> from = parseStringToIPv6(ipFrom); - if (!from) - continue; - - Optional<std::array<uint8, 16>> to = parseStringToIPv6(ipTo); - if (!to) - continue; - - _ipLocationStore.emplace_back(*from, *to, std::move(countryCode), std::move(countryName)); - } - - std::ranges::sort(_ipLocationStore, {}, &IpLocationRecord::IpFrom); - ASSERT(std::ranges::is_sorted(_ipLocationStore, [](IpLocationRecord const& a, IpLocationRecord const& b) { return a.IpFrom < b.IpTo; }), - "Overlapping IP ranges detected in database file"); - - databaseFile.close(); - - TC_LOG_INFO("server.loading", ">> Loaded {} ip location entries.", _ipLocationStore.size()); -} - -IpLocationRecord const* IpLocationStore::GetLocationRecord(std::string const& ipAddress) const -{ - boost::system::error_code error; - boost::asio::ip::address address = Trinity::Net::make_address(ipAddress, error); - if (error) - return nullptr; - - std::array<uint8, 16> bytes = [&]() -> std::array<uint8, 16> - { - if (address.is_v6()) - return address.to_v6().to_bytes(); - if (address.is_v4()) - return Trinity::Net::make_address_v6(Trinity::Net::v4_mapped, address.to_v4()).to_bytes(); - return {}; - }(); - auto itr = std::ranges::upper_bound(_ipLocationStore, bytes, {}, &IpLocationRecord::IpTo); - if (itr == _ipLocationStore.end()) - return nullptr; - - if (bytes < itr->IpFrom) - return nullptr; - - return &(*itr); -} - -IpLocationStore* IpLocationStore::Instance() -{ - static IpLocationStore instance; - return &instance; -} diff --git a/src/common/IPLocation/IPLocation.h b/src/common/IPLocation/IPLocation.h deleted file mode 100644 index 5471948586b..00000000000 --- a/src/common/IPLocation/IPLocation.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information - * - * This program is free software; you can redistribute it and/or modify it - * under the terms of the GNU General Public License as published by the - * Free Software Foundation; either version 2 of the License, or (at your - * option) any later version. - * - * This program is distributed in the hope that it will be useful, but WITHOUT - * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for - * more details. - * - * You should have received a copy of the GNU General Public License along - * with this program. If not, see <http://www.gnu.org/licenses/>. - */ - -#ifndef IPLOCATION_H -#define IPLOCATION_H - -#include "Define.h" -#include <array> -#include <string> -#include <vector> - -struct IpLocationRecord -{ - IpLocationRecord() : IpFrom(), IpTo() { } - IpLocationRecord(std::array<uint8, 16> ipFrom, std::array<uint8, 16> ipTo, std::string&& countryCode, std::string&& countryName) - : IpFrom(ipFrom), IpTo(ipTo), CountryCode(std::move(countryCode)), CountryName(std::move(countryName)) { } - - std::array<uint8, 16> IpFrom; - std::array<uint8, 16> IpTo; - std::string CountryCode; - std::string CountryName; -}; - -class TC_COMMON_API IpLocationStore -{ - public: - IpLocationStore(); - IpLocationStore(IpLocationStore const&) = delete; - IpLocationStore(IpLocationStore&&) = delete; - IpLocationStore& operator=(IpLocationStore const&) = delete; - IpLocationStore& operator=(IpLocationStore&&) = delete; - ~IpLocationStore(); - static IpLocationStore* Instance(); - - void Load(); - IpLocationRecord const* GetLocationRecord(std::string const& ipAddress) const; - - private: - std::vector<IpLocationRecord> _ipLocationStore; -}; - -#define sIPLocation IpLocationStore::Instance() - -#endif diff --git a/src/common/Utilities/Util.cpp b/src/common/Utilities/Util.cpp index 93bc3b853a5..c374bff9cf1 100644 --- a/src/common/Utilities/Util.cpp +++ b/src/common/Utilities/Util.cpp @@ -18,7 +18,6 @@ #include "Util.h" #include "Common.h" #include "Containers.h" -#include "IpAddress.h" #include "StringConvert.h" #include "StringFormat.h" #include <boost/core/demangle.hpp> @@ -29,6 +28,10 @@ #include <cstdarg> #include <ctime> +#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS +#include <Windows.h> +#endif + void Trinity::VerifyOsVersion() { #if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS @@ -270,17 +273,6 @@ std::string TimeToHumanReadable(time_t t) return std::string(buf); } -/// Check if the string is a valid ip address representation -bool IsIPAddress(char const* ipaddress) -{ - if (!ipaddress) - return false; - - boost::system::error_code error; - Trinity::Net::make_address(ipaddress, error); - return !error; -} - /// create PID file uint32 CreatePIDFile(std::string const& filename) { diff --git a/src/common/Utilities/Util.h b/src/common/Utilities/Util.h index c85fd858806..4907b4edbae 100644 --- a/src/common/Utilities/Util.h +++ b/src/common/Utilities/Util.h @@ -390,8 +390,6 @@ TC_COMMON_API bool WriteWinConsole(std::string_view str, bool error = false); TC_COMMON_API Optional<std::size_t> RemoveCRLF(std::string& str); -TC_COMMON_API bool IsIPAddress(char const* ipaddress); - TC_COMMON_API uint32 CreatePIDFile(std::string const& filename); TC_COMMON_API uint32 GetPID(); diff --git a/src/common/network/AsyncAcceptor.h b/src/common/network/AsyncAcceptor.h new file mode 100644 index 00000000000..dd0857c2b38 --- /dev/null +++ b/src/common/network/AsyncAcceptor.h @@ -0,0 +1,137 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H +#define TRINITYCORE_ASYNC_ACCEPTOR_H + +#include "IoContext.h" +#include "IpAddress.h" +#include "Log.h" +#include "Socket.h" +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/ip/v6_only.hpp> +#include <atomic> +#include <functional> + +#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections + +namespace Trinity::Net +{ +template <typename Callable> +concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>; + +class AsyncAcceptor +{ +public: + AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) : + _acceptor(ioContext), _endpoint(make_address(bindIp), port), + _socket(ioContext), _closed(false), _socketFactory([this] { return DefeaultSocketFactory(); }) + { + } + + template <AcceptCallback Callback> + void AsyncAccept(Callback&& acceptCallback) + { + auto [tmpSocket, tmpThreadIndex] = _socketFactory(); + // TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures) + IoContextTcpSocket* socket = tmpSocket; + uint32 threadIndex = tmpThreadIndex; + _acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable + { + if (!error) + { + try + { + socket->non_blocking(true); + + acceptCallback(std::move(*socket), threadIndex); + } + catch (boost::system::system_error const& err) + { + TC_LOG_INFO("network", "Failed to initialize client's socket {}", err.what()); + } + } + + if (!_closed) + this->AsyncAccept(std::move(acceptCallback)); + }); + } + + bool Bind() + { + boost::system::error_code errorCode; + _acceptor.open(_endpoint.protocol(), errorCode); + if (errorCode) + { + TC_LOG_INFO("network", "Failed to open acceptor {}", errorCode.message()); + return false; + } + +#if TRINITY_PLATFORM != TRINITY_PLATFORM_WINDOWS + _acceptor.set_option(boost::asio::ip::tcp::acceptor::reuse_address(true), errorCode); + if (errorCode) + { + TC_LOG_INFO("network", "Failed to set reuse_address option on acceptor {}", errorCode.message()); + return false; + } +#endif + + // v6_only is enabled on some *BSD distributions by default + // we want to allow both v4 and v6 connections to the same listener + if (_endpoint.protocol() == boost::asio::ip::tcp::v6()) + _acceptor.set_option(boost::asio::ip::v6_only(false)); + + _acceptor.bind(_endpoint, errorCode); + if (errorCode) + { + TC_LOG_INFO("network", "Could not bind to {}:{} {}", _endpoint.address().to_string(), _endpoint.port(), errorCode.message()); + return false; + } + + _acceptor.listen(TRINITY_MAX_LISTEN_CONNECTIONS, errorCode); + if (errorCode) + { + TC_LOG_INFO("network", "Failed to start listening on {}:{} {}", _endpoint.address().to_string(), _endpoint.port(), errorCode.message()); + return false; + } + + return true; + } + + void Close() + { + if (_closed.exchange(true)) + return; + + boost::system::error_code err; + _acceptor.close(err); + } + + void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); } + +private: + std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); } + + boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor; + boost::asio::ip::tcp::endpoint _endpoint; + IoContextTcpSocket _socket; + std::atomic<bool> _closed; + std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory; +}; +} + +#endif // TRINITYCORE_ASYNC_ACCEPTOR_H diff --git a/src/common/network/CMakeLists.txt b/src/common/network/CMakeLists.txt new file mode 100644 index 00000000000..70408faed51 --- /dev/null +++ b/src/common/network/CMakeLists.txt @@ -0,0 +1,53 @@ +# This file is part of the TrinityCore Project. See AUTHORS file for Copyright information +# +# This file is free software; as a special exception the author gives +# unlimited permission to copy and/or distribute it, with or without +# modifications, as long as this notice is preserved. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY, to the extent permitted by law; without even the +# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + +CollectSourceFiles( + ${CMAKE_CURRENT_SOURCE_DIR} + PRIVATE_SOURCES) + +GroupSources(${CMAKE_CURRENT_SOURCE_DIR}) + +add_library(network + ${PRIVATE_SOURCES}) + +CollectIncludeDirectories( + ${CMAKE_CURRENT_SOURCE_DIR} + PUBLIC_INCLUDES) + +target_include_directories(network + PUBLIC + ${PUBLIC_INCLUDES} + PRIVATE + ${CMAKE_CURRENT_BINARY_DIR}) + +target_link_libraries(network + PRIVATE + trinity-core-interface + PUBLIC + common) + +set_target_properties(network + PROPERTIES + COMPILE_WARNING_AS_ERROR ${WITH_WARNINGS_AS_ERRORS} + DEFINE_SYMBOL TRINITY_API_EXPORT_NETWORK + FOLDER "server" + OUTPUT_NAME trinity_network) + +if(BUILD_SHARED_LIBS) + if(UNIX) + install(TARGETS network + LIBRARY + DESTINATION lib) + elseif(WIN32) + install(TARGETS network + RUNTIME + DESTINATION "${CMAKE_INSTALL_PREFIX}") + endif() +endif() diff --git a/src/common/network/ConnectionInitializers/SocketConnectionInitializer.h b/src/common/network/ConnectionInitializers/SocketConnectionInitializer.h new file mode 100644 index 00000000000..d3f0bb16dbf --- /dev/null +++ b/src/common/network/ConnectionInitializers/SocketConnectionInitializer.h @@ -0,0 +1,51 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H +#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H + +#include <memory> +#include <span> + +namespace Trinity::Net +{ +struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer> +{ + SocketConnectionInitializer() = default; + + SocketConnectionInitializer(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default; + SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete; + SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default; + + virtual ~SocketConnectionInitializer() = default; + + virtual void Start() = 0; + + std::shared_ptr<SocketConnectionInitializer> next; + + static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers) + { + for (std::size_t i = initializers.size(); i > 1; --i) + initializers[i - 2]->next.swap(initializers[i - 1]); + + return initializers[0]; + } +}; +} + +#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H diff --git a/src/common/network/Http/BaseHttpSocket.cpp b/src/common/network/Http/BaseHttpSocket.cpp new file mode 100644 index 00000000000..33053399c14 --- /dev/null +++ b/src/common/network/Http/BaseHttpSocket.cpp @@ -0,0 +1,146 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "BaseHttpSocket.h" +#include <boost/asio/buffers_iterator.hpp> +#include <boost/beast/http/serializer.hpp> + +namespace Trinity::Net::Http +{ +using RequestSerializer = boost::beast::http::request_serializer<ResponseBody>; +using ResponseSerializer = boost::beast::http::response_serializer<ResponseBody>; + +bool AbstractSocket::ParseRequest(MessageBuffer& packet, RequestParser& parser) +{ + if (!parser.is_done()) + { + // need more data in the payload + boost::system::error_code ec = {}; + std::size_t readDataSize = parser.put(boost::asio::const_buffer(packet.GetReadPointer(), packet.GetActiveSize()), ec); + packet.ReadCompleted(readDataSize); + } + + return parser.is_done(); +} + +std::string AbstractSocket::SerializeRequest(Request const& request) +{ + RequestSerializer serializer(request); + + std::string buffer; + while (!serializer.is_done()) + { + size_t totalBytes = 0; + boost::system::error_code ec = {}; + serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code const&, ConstBufferSequence const& buffers) + { + size_t totalBytesInBuffers = boost::asio::buffer_size(buffers); + + buffer.reserve(buffer.size() + totalBytes); + + auto begin = boost::asio::buffers_begin(buffers); + auto end = boost::asio::buffers_end(buffers); + + std::copy(begin, end, std::back_inserter(buffer)); + totalBytes += totalBytesInBuffers; + }); + + serializer.consume(totalBytes); + } + + return buffer; +} + +MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response& response) +{ + response.prepare_payload(); + + ResponseSerializer serializer(response); + bool (*serializerIsDone)(ResponseSerializer&); + if (request.method() != boost::beast::http::verb::head) + { + serializerIsDone = [](ResponseSerializer& s) { return s.is_done(); }; + } + else + { + serializerIsDone = [](ResponseSerializer& s) { return s.is_header_done(); }; + serializer.split(true); + } + + MessageBuffer buffer; + while (!serializerIsDone(serializer)) + { + serializer.limit(buffer.GetRemainingSpace()); + + size_t totalBytes = 0; + boost::system::error_code ec = {}; + serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code& currentError, ConstBufferSequence const& buffers) + { + size_t totalBytesInBuffers = boost::asio::buffer_size(buffers); + if (totalBytesInBuffers > buffer.GetRemainingSpace()) + { + currentError = boost::beast::http::error::need_more; + return; + } + + auto begin = boost::asio::buffers_begin(buffers); + auto end = boost::asio::buffers_end(buffers); + + std::copy(begin, end, buffer.GetWritePointer()); + buffer.WriteCompleted(totalBytesInBuffers); + totalBytes += totalBytesInBuffers; + }); + + serializer.consume(totalBytes); + + if (ec == boost::beast::http::error::need_more) + buffer.Resize(buffer.GetBufferSize() + 4096); + } + + return buffer; +} + +void AbstractSocket::LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const +{ + if (Logger const* logger = sLog->GetEnabledLogger("server.http", LOG_LEVEL_DEBUG)) + { + std::string clientInfo = GetClientInfo(); + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_DEBUG, "{} Request {} {} done, status {}", clientInfo, + ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int()); + if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE)) + { + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo, + CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>"); + sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo, + CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>"); + } + } +} + +std::string AbstractSocket::GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state) +{ + std::string info = StringFormat("[{}:{}", address.to_string(), port); + if (state) + { + info.append(", Session Id: "); + info.append(boost::uuids::to_string(state->Id)); + } + + info += ']'; + return info; +} +} diff --git a/src/common/network/Http/BaseHttpSocket.h b/src/common/network/Http/BaseHttpSocket.h new file mode 100644 index 00000000000..2d58b6d5aa1 --- /dev/null +++ b/src/common/network/Http/BaseHttpSocket.h @@ -0,0 +1,228 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_BASE_HTTP_SOCKET_H +#define TRINITYCORE_BASE_HTTP_SOCKET_H + +#include "AsyncCallbackProcessor.h" +#include "HttpCommon.h" +#include "HttpSessionState.h" +#include "Optional.h" +#include "Socket.h" +#include "SocketConnectionInitializer.h" +#include <boost/beast/core/basic_stream.hpp> +#include <boost/beast/http/parser.hpp> +#include <boost/beast/http/string_body.hpp> +#include <boost/uuid/uuid_io.hpp> + +namespace Trinity::Net::Http +{ +using IoContextHttpSocket = boost::beast::basic_stream<boost::asio::ip::tcp, boost::asio::io_context::executor_type, boost::beast::unlimited_rate_policy>; + +namespace Impl +{ +class BoostBeastSocketWrapper : public IoContextHttpSocket +{ +public: + using IoContextHttpSocket::basic_stream; + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + socket().shutdown(what, shutdownError); + } + + void close(boost::system::error_code& /*error*/) + { + IoContextHttpSocket::close(); + } + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + { + socket().async_wait(type, std::forward<WaitHandlerType>(handler)); + } + + IoContextTcpSocket::endpoint_type remote_endpoint() const + { + return socket().remote_endpoint(); + } +}; +} + +using RequestParser = boost::beast::http::request_parser<RequestBody>; + +class TC_NETWORK_API AbstractSocket +{ +public: + AbstractSocket() = default; + AbstractSocket(AbstractSocket const& other) = default; + AbstractSocket(AbstractSocket&& other) = default; + AbstractSocket& operator=(AbstractSocket const& other) = default; + AbstractSocket& operator=(AbstractSocket&& other) = default; + virtual ~AbstractSocket() = default; + + static bool ParseRequest(MessageBuffer& packet, RequestParser& parser); + + static std::string SerializeRequest(Request const& request); + static MessageBuffer SerializeResponse(Request const& request, Response& response); + + virtual void SendResponse(RequestContext& context) = 0; + + void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const; + + virtual std::string GetClientInfo() const = 0; + + static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state); + + virtual SessionState* GetSessionState() const = 0; + + Optional<boost::uuids::uuid> GetSessionId() const + { + if (SessionState* state = this->GetSessionState()) + return state->Id; + + return {}; + } + + virtual void Start() = 0; + + virtual bool Update() = 0; + + virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0; + + virtual bool IsOpen() const = 0; + + virtual void CloseSocket() = 0; +}; + +template <typename SocketImpl> +struct HttpConnectionInitializer final : SocketConnectionInitializer +{ + explicit HttpConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->ResetHttpParser(); + + if (this->next) + this->next->Start(); + } + +private: + SocketImpl* _socket; +}; + +template<typename Stream> +class BaseSocket : public Trinity::Net::Socket<Stream>, public AbstractSocket +{ + using Base = Trinity::Net::Socket<Stream>; + +public: + using Base::Base; + + BaseSocket(BaseSocket const& other) = delete; + BaseSocket(BaseSocket&& other) = delete; + BaseSocket& operator=(BaseSocket const& other) = delete; + BaseSocket& operator=(BaseSocket&& other) = delete; + + ~BaseSocket() = default; + + SocketReadCallbackResult ReadHandler() final + { + MessageBuffer& packet = this->GetReadBuffer(); + while (packet.GetActiveSize() > 0) + { + if (!ParseRequest(packet, *_httpParser)) + { + // Couldn't receive the whole data this time. + break; + } + + if (!HandleMessage(_httpParser->get())) + { + this->CloseSocket(); + return SocketReadCallbackResult::Stop; + } + + this->ResetHttpParser(); + } + + return SocketReadCallbackResult::KeepReading; + } + + bool HandleMessage(Request& request) + { + RequestContext context { .request = std::move(request) }; + + if (!_state) + _state = this->ObtainSessionState(context); + + RequestHandlerResult status = this->RequestHandler(context); + + if (status != RequestHandlerResult::Async) + this->SendResponse(context); + + return status != RequestHandlerResult::Error; + } + + virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0; + + void SendResponse(RequestContext& context) final + { + MessageBuffer buffer = SerializeResponse(context.request, context.response); + + this->LogRequestAndResponse(context, buffer); + + this->QueuePacket(std::move(buffer)); + + if (!context.response.keep_alive()) + this->DelayedCloseSocket(); + } + + void Start() override { return this->Base::Start(); } + + bool Update() override { return this->Base::Update(); } + + boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); } + + bool IsOpen() const final { return this->Base::IsOpen(); } + + void CloseSocket() final { return this->Base::CloseSocket(); } + + std::string GetClientInfo() const override + { + return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get()); + } + + SessionState* GetSessionState() const override { return _state.get(); } + + void ResetHttpParser() + { + this->_httpParser.reset(); + this->_httpParser.emplace(); + this->_httpParser->eager(true); + } + +protected: + virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0; + + Optional<RequestParser> _httpParser; + std::shared_ptr<SessionState> _state; +}; +} + +#endif // TRINITYCORE_BASE_HTTP_SOCKET_H diff --git a/src/common/network/Http/HttpCommon.h b/src/common/network/Http/HttpCommon.h new file mode 100644 index 00000000000..274b59d7536 --- /dev/null +++ b/src/common/network/Http/HttpCommon.h @@ -0,0 +1,55 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_COMMON_H +#define TRINITYCORE_HTTP_COMMON_H + +#include "Define.h" +#include <boost/beast/http/message.hpp> +#include <boost/beast/http/string_body.hpp> + +namespace Trinity::Net::Http +{ +using RequestBody = boost::beast::http::string_body; +using ResponseBody = boost::beast::http::string_body; + +using Request = boost::beast::http::request<RequestBody>; +using Response = boost::beast::http::response<ResponseBody>; + +struct RequestContext +{ + Request request; + Response response; + struct RequestHandler const* handler = nullptr; +}; + +TC_NETWORK_API bool CanLogRequestContent(RequestContext const& context); +TC_NETWORK_API bool CanLogResponseContent(RequestContext const& context); + +inline std::string_view ToStdStringView(boost::beast::string_view bsw) +{ + return { bsw.data(), bsw.size() }; +} + +enum class RequestHandlerResult +{ + Handled, + Error, + Async, +}; +} +#endif // TRINITYCORE_HTTP_COMMON_H diff --git a/src/common/network/Http/HttpService.cpp b/src/common/network/Http/HttpService.cpp new file mode 100644 index 00000000000..b01e27e296a --- /dev/null +++ b/src/common/network/Http/HttpService.cpp @@ -0,0 +1,267 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "HttpService.h" +#include "BaseHttpSocket.h" +#include "CryptoRandom.h" +#include "Timezone.h" +#include <boost/beast/version.hpp> +#include <boost/uuid/string_generator.hpp> +#include <fmt/chrono.h> + +namespace Trinity::Net::Http +{ +bool CanLogRequestContent(RequestContext const& context) +{ + return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogRequestContent); +} + +bool CanLogResponseContent(RequestContext const& context) +{ + return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogResponseContent); +} + +RequestHandlerResult DispatcherService::HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context) +{ + TC_LOG_DEBUG(_logger, "{} Starting request {} {}", session->GetClientInfo(), + ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target())); + + std::string_view path = [&] + { + std::string_view path = ToStdStringView(context.request.target()); + size_t queryIndex = path.find('?'); + if (queryIndex != std::string_view::npos) + path = path.substr(0, queryIndex); + return path; + }(); + + context.handler = [&]() -> HttpMethodHandlerMap::mapped_type const* + { + switch (context.request.method()) + { + case boost::beast::http::verb::get: + case boost::beast::http::verb::head: + { + auto itr = _getHandlers.find(path); + return itr != _getHandlers.end() ? &itr->second : nullptr; + } + case boost::beast::http::verb::post: + { + auto itr = _postHandlers.find(path); + return itr != _postHandlers.end() ? &itr->second : nullptr; + } + default: + break; + } + return nullptr; + }(); + + SystemTimePoint responseDate = SystemTimePoint::clock::now(); + context.response.set(boost::beast::http::field::date, StringFormat("{:%a, %d %b %Y %T GMT}", responseDate - Timezone::GetSystemZoneOffsetAt(responseDate))); + context.response.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING); + context.response.keep_alive(context.request.keep_alive()); + + if (!context.handler) + return HandlePathNotFound(std::move(session), context); + + return context.handler->Func(std::move(session), context); +} + +RequestHandlerResult DispatcherService::HandleBadRequest(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context) +{ + context.response.result(boost::beast::http::status::bad_request); + return RequestHandlerResult::Handled; +} + +RequestHandlerResult DispatcherService::HandleUnauthorized(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context) +{ + context.response.result(boost::beast::http::status::unauthorized); + return RequestHandlerResult::Handled; +} + +RequestHandlerResult DispatcherService::HandlePathNotFound(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context) +{ + context.response.result(boost::beast::http::status::not_found); + return RequestHandlerResult::Handled; +} + +void DispatcherService::RegisterHandler(boost::beast::http::verb method, std::string_view path, + std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler, + RequestHandlerFlag flags) +{ + HttpMethodHandlerMap& handlerMap = [&]() -> HttpMethodHandlerMap& + { + switch (method) + { + case boost::beast::http::verb::get: + return _getHandlers; + case boost::beast::http::verb::post: + return _postHandlers; + default: + { + std::string_view methodString = ToStdStringView(boost::beast::http::to_string(method)); + ABORT_MSG("Tried to register a handler for unsupported HTTP method " STRING_VIEW_FMT, STRING_VIEW_FMT_ARG(methodString)); + } + } + }(); + + handlerMap[std::string(path)] = { .Func = std::move(handler), .Flags = flags }; + TC_LOG_INFO(_logger, "Registered new handler for {} {}", ToStdStringView(boost::beast::http::to_string(method)), path); +} + +void SessionService::InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address) +{ + state->RemoteAddress = address; + + // Generate session id + { + std::unique_lock lock{ _sessionsMutex }; + + while (state->Id.is_nil() || _sessions.contains(state->Id)) + std::copy_n(Trinity::Crypto::GetRandomBytes<16>().begin(), 16, state->Id.begin()); + + TC_LOG_DEBUG(_logger, "Client at {} created new session {}", address.to_string(), boost::uuids::to_string(state->Id)); + _sessions[state->Id] = std::move(state); + } +} + +void SessionService::Start(Asio::IoContext& ioContext) +{ + _inactiveSessionsKillTimer = std::make_unique<Asio::DeadlineTimer>(ioContext); + _inactiveSessionsKillTimer->expires_after(1min); + _inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err) + { + if (err) + return; + + KillInactiveSessions(); + }); +} + +void SessionService::Stop() +{ + _inactiveSessionsKillTimer = nullptr; + { + std::unique_lock lock{ _sessionsMutex }; + _sessions.clear(); + } + { + std::unique_lock lock{ _inactiveSessionsMutex }; + _inactiveSessions.clear(); + } +} + +std::shared_ptr<SessionState> SessionService::FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address) +{ + std::shared_ptr<SessionState> state; + + { + std::shared_lock lock{ _sessionsMutex }; + auto itr = _sessions.find(boost::uuids::string_generator()(id.begin(), id.end())); + if (itr == _sessions.end()) + { + TC_LOG_DEBUG(_logger, "Client at {} attempted to use a session {} that was expired", address.to_string(), id); + return nullptr; // no session + } + + state = itr->second; + } + + if (state->RemoteAddress != address) + { + TC_LOG_ERROR(_logger, "Client at {} attempted to use a session {} that was last accessed from {}, denied access", + address.to_string(), id, state->RemoteAddress.to_string()); + return nullptr; + } + + { + std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex }; + _inactiveSessions.erase(state->Id); + } + + return state; +} + +void SessionService::MarkSessionInactive(boost::uuids::uuid const& id) +{ + bool wasActive = true; + { + std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex }; + wasActive = _inactiveSessions.insert(id).second; + } + + if (wasActive) + { + std::shared_lock lock{ _sessionsMutex }; + auto itr = _sessions.find(id); + if (itr != _sessions.end()) + { + itr->second->InactiveTimestamp = TimePoint::clock::now() + Minutes(5); + TC_LOG_TRACE(_logger, "Session {} marked as inactive", boost::uuids::to_string(id)); + } + } +} + +void SessionService::KillInactiveSessions() +{ + std::set<boost::uuids::uuid> inactiveSessions; + + { + std::unique_lock lock{ _inactiveSessionsMutex }; + std::swap(_inactiveSessions, inactiveSessions); + } + + { + TimePoint now = TimePoint::clock::now(); + std::size_t inactiveSessionsCount = inactiveSessions.size(); + + std::unique_lock lock{ _sessionsMutex }; + for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); ) + { + auto sessionItr = _sessions.find(*itr); + if (sessionItr == _sessions.end() || sessionItr->second->InactiveTimestamp < now) + { + _sessions.erase(sessionItr); + itr = inactiveSessions.erase(itr); + } + else + ++itr; + } + + TC_LOG_DEBUG(_logger, "Killed {} inactive sessions", inactiveSessionsCount - inactiveSessions.size()); + } + + { + // restore sessions not killed to inactive queue + std::unique_lock lock{ _inactiveSessionsMutex }; + for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); ) + { + auto node = inactiveSessions.extract(itr++); + _inactiveSessions.insert(std::move(node)); + } + } + + _inactiveSessionsKillTimer->expires_after(1min); + _inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err) + { + if (err) + return; + + KillInactiveSessions(); + }); +} +} diff --git a/src/common/network/Http/HttpService.h b/src/common/network/Http/HttpService.h new file mode 100644 index 00000000000..1549893576f --- /dev/null +++ b/src/common/network/Http/HttpService.h @@ -0,0 +1,188 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SERVICE_H +#define TRINITYCORE_HTTP_SERVICE_H + +#include "AsioHacksFwd.h" +#include "Concepts.h" +#include "Define.h" +#include "EnumFlag.h" +#include "HttpCommon.h" +#include "HttpSessionState.h" +#include "Optional.h" +#include "SocketMgr.h" +#include <boost/uuid/uuid.hpp> +#include <functional> +#include <map> +#include <set> +#include <shared_mutex> + +namespace Trinity::Net::Http +{ +class AbstractSocket; + +enum class RequestHandlerFlag +{ + None = 0x0, + DoNotLogRequestContent = 0x1, + DoNotLogResponseContent = 0x2, +}; + +DEFINE_ENUM_FLAG(RequestHandlerFlag); + +struct RequestHandler +{ + std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> Func; + EnumFlag<RequestHandlerFlag> Flags = RequestHandlerFlag::None; +}; + +class TC_NETWORK_API DispatcherService +{ +public: + explicit DispatcherService(std::string_view loggerSuffix) : _logger("server.http.dispatcher.") + { + _logger.append(loggerSuffix); + } + + RequestHandlerResult HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context); + + static RequestHandlerResult HandleBadRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context); + static RequestHandlerResult HandleUnauthorized(std::shared_ptr<AbstractSocket> session, RequestContext& context); + static RequestHandlerResult HandlePathNotFound(std::shared_ptr<AbstractSocket> session, RequestContext& context); + +protected: + void RegisterHandler(boost::beast::http::verb method, std::string_view path, + std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler, + RequestHandlerFlag flags = RequestHandlerFlag::None); + +private: + using HttpMethodHandlerMap = std::map<std::string, RequestHandler, std::less<>>; + + HttpMethodHandlerMap _getHandlers; + HttpMethodHandlerMap _postHandlers; + + std::string _logger; +}; + +class TC_NETWORK_API SessionService +{ +public: + explicit SessionService(std::string_view loggerSuffix) : _logger("server.http.session.") + { + _logger.append(loggerSuffix); + } + + void Start(Asio::IoContext& ioContext); + void Stop(); + + std::shared_ptr<SessionState> FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address); + void MarkSessionInactive(boost::uuids::uuid const& id); + +protected: + void InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address); + + void KillInactiveSessions(); + +private: + std::shared_mutex _sessionsMutex; + std::map<boost::uuids::uuid, std::shared_ptr<SessionState>> _sessions; + + std::mutex _inactiveSessionsMutex; + std::set<boost::uuids::uuid> _inactiveSessions; + std::unique_ptr<Asio::DeadlineTimer> _inactiveSessionsKillTimer; + + std::string _logger; +}; + +template<typename Callable, typename SessionImpl> +concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>; + +template<typename SessionImpl> +class HttpService : public SocketMgr<SessionImpl>, public DispatcherService, public SessionService +{ +public: + HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.") + { + _logger.append(loggerSuffix); + } + + bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override + { + if (!SocketMgr<SessionImpl>::StartNetwork(ioContext, bindIp, port, threadCount)) + return false; + + SessionService::Start(ioContext); + return true; + } + + void StopNetwork() override + { + SessionService::Stop(); + SocketMgr<SessionImpl>::StopNetwork(); + } + + // http handling + using DispatcherService::RegisterHandler; + + template<HttpRequestHandler<SessionImpl> Callable> + void RegisterHandler(boost::beast::http::verb method, std::string_view path, Callable handler, RequestHandlerFlag flags = RequestHandlerFlag::None) + { + this->DispatcherService::RegisterHandler(method, path, [handler = std::move(handler)](std::shared_ptr<AbstractSocket> session, RequestContext& context) -> RequestHandlerResult + { + return handler(std::static_pointer_cast<SessionImpl>(std::move(session)), context); + }, flags); + } + + // session tracking + virtual std::shared_ptr<SessionState> CreateNewSessionState(boost::asio::ip::address const& address) + { + std::shared_ptr<SessionState> state = std::make_shared<SessionState>(); + InitAndStoreSessionState(state, address); + return state; + } + +protected: + class Thread : public NetworkThread<SessionImpl> + { + protected: + void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override + { + if (Optional<boost::uuids::uuid> id = session->GetSessionId()) + _service->MarkSessionInactive(*id); + } + + private: + friend HttpService; + + SessionService* _service; + }; + + NetworkThread<SessionImpl>* CreateThreads() const override + { + Thread* threads = new Thread[this->GetNetworkThreadCount()]; + for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i) + threads[i]._service = const_cast<HttpService*>(this); + return threads; + } + + Asio::IoContext* _ioContext; + std::string _logger; +}; +} + +#endif // TRINITYCORE_HTTP_SERVICE_H diff --git a/src/common/network/Http/HttpSessionState.h b/src/common/network/Http/HttpSessionState.h new file mode 100644 index 00000000000..3012a2efc65 --- /dev/null +++ b/src/common/network/Http/HttpSessionState.h @@ -0,0 +1,35 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SESSION_STATE_H +#define TRINITYCORE_HTTP_SESSION_STATE_H + +#include "Duration.h" +#include <boost/asio/ip/address.hpp> +#include <boost/uuid/uuid.hpp> + +namespace Trinity::Net::Http +{ +struct SessionState +{ + boost::uuids::uuid Id = { }; + boost::asio::ip::address RemoteAddress; + TimePoint InactiveTimestamp = TimePoint::max(); +}; +} + +#endif // TRINITYCORE_HTTP_SESSION_STATE_H diff --git a/src/common/network/Http/HttpSocket.h b/src/common/network/Http/HttpSocket.h new file mode 100644 index 00000000000..2cfc3ba8ed8 --- /dev/null +++ b/src/common/network/Http/HttpSocket.h @@ -0,0 +1,53 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SOCKET_H +#define TRINITYCORE_HTTP_SOCKET_H + +#include "BaseHttpSocket.h" +#include <array> + +namespace Trinity::Net::Http +{ +class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper> +{ + using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>; + +public: + using SocketBase::SocketBase; + + Socket(Socket const& other) = delete; + Socket(Socket&& other) = delete; + Socket& operator=(Socket const& other) = delete; + Socket& operator=(Socket&& other) = delete; + + ~Socket() = default; + + void Start() override + { + std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers = + { { + std::make_shared<HttpConnectionInitializer<SocketBase>>(this), + std::make_shared<ReadConnectionInitializer<SocketBase>>(this), + } }; + + SocketConnectionInitializer::SetupChain(initializers)->Start(); + } +}; +} + +#endif // TRINITYCORE_HTTP_SOCKET_H diff --git a/src/common/network/Http/HttpSslSocket.h b/src/common/network/Http/HttpSslSocket.h new file mode 100644 index 00000000000..c789cbfefaf --- /dev/null +++ b/src/common/network/Http/HttpSslSocket.h @@ -0,0 +1,58 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SSL_SOCKET_H +#define TRINITYCORE_HTTP_SSL_SOCKET_H + +#include "BaseHttpSocket.h" +#include "SslStream.h" + +namespace Trinity::Net::Http +{ +class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>> +{ + using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>; + +public: + explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) + : SocketBase(std::move(socket), sslContext) { } + + explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) + : SocketBase(context, sslContext) { } + + SslSocket(SslSocket const& other) = delete; + SslSocket(SslSocket&& other) = delete; + SslSocket& operator=(SslSocket const& other) = delete; + SslSocket& operator=(SslSocket&& other) = delete; + + ~SslSocket() = default; + + void Start() override + { + std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers = + { { + std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this), + std::make_shared<HttpConnectionInitializer<SocketBase>>(this), + std::make_shared<ReadConnectionInitializer<SocketBase>>(this), + } }; + + SocketConnectionInitializer::SetupChain(initializers)->Start(); + } +}; +} + +#endif // TRINITYCORE_HTTP_SSL_SOCKET_H diff --git a/src/common/Asio/IpAddress.h b/src/common/network/IpAddress.h index 7d85b0028ac..b856d7f6340 100644 --- a/src/common/Asio/IpAddress.h +++ b/src/common/network/IpAddress.h @@ -15,22 +15,19 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef IpAddress_h__ -#define IpAddress_h__ +#ifndef TRINITYCORE_IP_ADDRESS_H +#define TRINITYCORE_IP_ADDRESS_H #include "Define.h" #include <boost/asio/ip/address.hpp> -namespace Trinity +namespace Trinity::Net { - namespace Net - { - using boost::asio::ip::make_address; - using boost::asio::ip::make_address_v4; - using boost::asio::ip::make_address_v6; - using boost::asio::ip::v4_mapped_t::v4_mapped; - inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_uint(); } - } + using boost::asio::ip::make_address; + using boost::asio::ip::make_address_v4; + using boost::asio::ip::make_address_v6; + using boost::asio::ip::v4_mapped_t::v4_mapped; + inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_uint(); } } -#endif // IpAddress_h__ +#endif // TRINITYCORE_IP_ADDRESS_H diff --git a/src/common/Asio/IpNetwork.cpp b/src/common/network/IpNetwork.cpp index 85f176d21e4..b6235a5a947 100644 --- a/src/common/Asio/IpNetwork.cpp +++ b/src/common/network/IpNetwork.cpp @@ -33,7 +33,7 @@ bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress) { if (clientAddress.is_v4()) { - return std::any_of(LocalV4Networks.begin(), LocalV4Networks.end(), [clientAddressV4 = clientAddress.to_v4()](boost::asio::ip::network_v4 const& network) + return std::ranges::any_of(LocalV4Networks, [clientAddressV4 = clientAddress.to_v4()](boost::asio::ip::network_v4 const& network) { return IsInNetwork(network, clientAddressV4); }); @@ -41,7 +41,7 @@ bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress) if (clientAddress.is_v6()) { - return std::any_of(LocalV6Networks.begin(), LocalV6Networks.end(), [clientAddressV6 = clientAddress.to_v6()](boost::asio::ip::network_v6 const& network) + return std::ranges::any_of(LocalV6Networks, [clientAddressV6 = clientAddress.to_v6()](boost::asio::ip::network_v6 const& network) { return IsInNetwork(network, clientAddressV6); }); diff --git a/src/common/Asio/IpNetwork.h b/src/common/network/IpNetwork.h index 08db7d49f84..c05c3076598 100644 --- a/src/common/Asio/IpNetwork.h +++ b/src/common/network/IpNetwork.h @@ -15,8 +15,8 @@ * with this program. If not, see <http://www.gnu.org/licenses/>. */ -#ifndef IpNetwork_h__ -#define IpNetwork_h__ +#ifndef TRINITYCORE_IP_NETWORK_H +#define TRINITYCORE_IP_NETWORK_H #include "AsioHacksFwd.h" #include "Define.h" @@ -25,15 +25,15 @@ namespace Trinity::Net { -TC_COMMON_API bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress); +TC_NETWORK_API bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress); -TC_COMMON_API bool IsInNetwork(boost::asio::ip::network_v4 const& network, boost::asio::ip::address_v4 const& clientAddress); +TC_NETWORK_API bool IsInNetwork(boost::asio::ip::network_v4 const& network, boost::asio::ip::address_v4 const& clientAddress); -TC_COMMON_API bool IsInNetwork(boost::asio::ip::network_v6 const& network, boost::asio::ip::address_v6 const& clientAddress); +TC_NETWORK_API bool IsInNetwork(boost::asio::ip::network_v6 const& network, boost::asio::ip::address_v6 const& clientAddress); -TC_COMMON_API Optional<std::size_t> SelectAddressForClient(boost::asio::ip::address const& clientAddress, std::span<boost::asio::ip::address const> const& addresses); +TC_NETWORK_API Optional<std::size_t> SelectAddressForClient(boost::asio::ip::address const& clientAddress, std::span<boost::asio::ip::address const> const& addresses); -TC_COMMON_API void ScanLocalNetworks(); +TC_NETWORK_API void ScanLocalNetworks(); } -#endif // IpNetwork_h__ +#endif // TRINITYCORE_IP_NETWORK_H diff --git a/src/common/network/NetworkThread.h b/src/common/network/NetworkThread.h new file mode 100644 index 00000000000..d16da442149 --- /dev/null +++ b/src/common/network/NetworkThread.h @@ -0,0 +1,179 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_NETWORK_THREAD_H +#define TRINITYCORE_NETWORK_THREAD_H + +#include "Containers.h" +#include "DeadlineTimer.h" +#include "Define.h" +#include "Errors.h" +#include "IoContext.h" +#include "Log.h" +#include "Socket.h" +#include <boost/asio/ip/tcp.hpp> +#include <atomic> +#include <memory> +#include <mutex> +#include <thread> + +namespace Trinity::Net +{ +template<class SocketType> +class NetworkThread +{ +public: + NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1), + _acceptSocket(_ioContext), _updateTimer(_ioContext) + { + } + + NetworkThread(NetworkThread const&) = delete; + NetworkThread(NetworkThread&&) = delete; + NetworkThread& operator=(NetworkThread const&) = delete; + NetworkThread& operator=(NetworkThread&&) = delete; + + virtual ~NetworkThread() + { + Stop(); + if (_thread) + Wait(); + } + + void Stop() + { + _stopped = true; + _ioContext.stop(); + } + + bool Start() + { + if (_thread) + return false; + + _thread = std::make_unique<std::thread>(&NetworkThread::Run, this); + return true; + } + + void Wait() + { + ASSERT(_thread); + + _thread->join(); + _thread = nullptr; + } + + int32 GetConnectionCount() const + { + return _connections; + } + + void AddSocket(std::shared_ptr<SocketType> sock) + { + std::lock_guard<std::mutex> lock(_newSocketsLock); + + ++_connections; + SocketAdded(_newSockets.emplace_back(std::move(sock))); + } + + Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; } + +protected: + virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { } + virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { } + + void AddNewSockets() + { + std::lock_guard<std::mutex> lock(_newSocketsLock); + + if (_newSockets.empty()) + return; + + for (std::shared_ptr<SocketType>& sock : _newSockets) + { + if (!sock->IsOpen()) + { + SocketRemoved(sock); + --_connections; + } + else + _sockets.emplace_back(std::move(sock)); + } + + _newSockets.clear(); + } + + void Run() + { + TC_LOG_DEBUG("misc", "Network Thread Starting"); + + _updateTimer.expires_after(1ms); + _updateTimer.async_wait([this](boost::system::error_code const&) { Update(); }); + _ioContext.run(); + + TC_LOG_DEBUG("misc", "Network Thread exits"); + _newSockets.clear(); + _sockets.clear(); + } + + void Update() + { + if (_stopped) + return; + + _updateTimer.expires_after(1ms); + _updateTimer.async_wait([this](boost::system::error_code const&) { Update(); }); + + AddNewSockets(); + + Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock) + { + if (!sock->Update()) + { + if (sock->IsOpen()) + sock->CloseSocket(); + + this->SocketRemoved(sock); + + --this->_connections; + return true; + } + + return false; + }); + } + +private: + typedef std::vector<std::shared_ptr<SocketType>> SocketContainer; + + std::atomic<int32> _connections; + std::atomic<bool> _stopped; + + std::unique_ptr<std::thread> _thread; + + SocketContainer _sockets; + + std::mutex _newSocketsLock; + SocketContainer _newSockets; + + Trinity::Asio::IoContext _ioContext; + Trinity::Net::IoContextTcpSocket _acceptSocket; + Trinity::Asio::DeadlineTimer _updateTimer; +}; +} + +#endif // TRINITYCORE_NETWORK_THREAD_H diff --git a/src/common/network/Resolver.cpp b/src/common/network/Resolver.cpp new file mode 100644 index 00000000000..9bcfcf28e78 --- /dev/null +++ b/src/common/network/Resolver.cpp @@ -0,0 +1,47 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "Resolver.h" +#include <algorithm> + +Optional<boost::asio::ip::tcp::endpoint> Trinity::Net::Resolver::Resolve(boost::asio::ip::tcp const& protocol, std::string_view host, std::string_view service) +{ + boost::system::error_code ec; + boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching; + boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(protocol, host, service, flagsResolver, ec); + Optional<boost::asio::ip::tcp::endpoint> result; + if (!ec) + if (auto itr = results.begin(); itr != results.end()) + result.emplace(itr->endpoint()); + + return result; +} + +std::vector<boost::asio::ip::tcp::endpoint> Trinity::Net::Resolver::ResolveAll(std::string_view host, std::string_view service) +{ + boost::system::error_code ec; + boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching; + boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(host, service, flagsResolver, ec); + std::vector<boost::asio::ip::tcp::endpoint> result; + if (!ec) + { + result.resize(results.size()); + std::ranges::transform(results, result.begin(), [](boost::asio::ip::tcp::resolver::results_type::value_type const& entry) { return entry.endpoint(); }); + } + + return result; +} diff --git a/src/common/network/Resolver.h b/src/common/network/Resolver.h new file mode 100644 index 00000000000..c7d24658aa5 --- /dev/null +++ b/src/common/network/Resolver.h @@ -0,0 +1,47 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_RESOLVER_H +#define TRINITYCORE_RESOLVER_H + +#include "Define.h" +#include "IoContext.h" +#include "Optional.h" +#include <boost/asio/ip/tcp.hpp> +#include <string_view> +#include <vector> + +namespace Trinity::Net +{ +/** + Hack to make it possible to forward declare resolver (one of its template arguments is a typedef to something super long and using nested classes) +*/ +class TC_NETWORK_API Resolver +{ +public: + explicit Resolver(Asio::IoContext& ioContext) : _impl(ioContext) { } + + Optional<boost::asio::ip::tcp::endpoint> Resolve(boost::asio::ip::tcp const& protocol, std::string_view host, std::string_view service); + + std::vector<boost::asio::ip::tcp::endpoint> ResolveAll(std::string_view host, std::string_view service); + +private: + boost::asio::ip::tcp::resolver _impl; +}; +} + +#endif // TRINITYCORE_RESOLVER_H diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h new file mode 100644 index 00000000000..565cc175318 --- /dev/null +++ b/src/common/network/Socket.h @@ -0,0 +1,362 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SOCKET_H +#define TRINITYCORE_SOCKET_H + +#include "Concepts.h" +#include "Log.h" +#include "MessageBuffer.h" +#include "SocketConnectionInitializer.h" +#include <boost/asio/io_context.hpp> +#include <boost/asio/ip/tcp.hpp> +#include <atomic> +#include <memory> +#include <queue> +#include <type_traits> + +#define READ_BLOCK_SIZE 4096 +#ifdef BOOST_ASIO_HAS_IOCP +#define TC_SOCKET_USE_IOCP +#endif + +namespace Trinity::Net +{ +using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>; + +enum class SocketReadCallbackResult +{ + KeepReading, + Stop +}; + +template <typename Callable> +concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>; + +template <typename SocketType> +struct InvokeReadHandlerCallback +{ + SocketReadCallbackResult operator()() const + { + return this->Socket->ReadHandler(); + } + + SocketType* Socket; +}; + +template <typename SocketType> +struct ReadConnectionInitializer final : SocketConnectionInitializer +{ + explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { } + + void Start() override + { + ReadCallback.Socket->AsyncRead(std::move(ReadCallback)); + + if (this->next) + this->next->Start(); + } + + InvokeReadHandlerCallback<SocketType> ReadCallback; +}; + +/** + @class Socket + + Base async socket implementation + + @tparam Stream stream type used for operations on socket + Stream must implement the following methods: + + void close(boost::system::error_code& error); + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError); + + template<typename MutableBufferSequence, typename ReadHandlerType> + void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler); + + template<typename ConstBufferSequence, typename WriteHandlerType> + void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler); + + template<typename ConstBufferSequence> + std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error); + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler); + + template<typename SettableSocketOption> + void set_option(SettableSocketOption const& option, boost::system::error_code& error); + + tcp::socket::endpoint_type remote_endpoint() const; +*/ +template<class Stream = IoContextTcpSocket> +class Socket : public std::enable_shared_from_this<Socket<Stream>> +{ +public: + template<typename... Args> + explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...), + _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open) + { + } + + template<typename... Args> + explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed) + { + } + + Socket(Socket const& other) = delete; + Socket(Socket&& other) = delete; + Socket& operator=(Socket const& other) = delete; + Socket& operator=(Socket&& other) = delete; + + virtual ~Socket() + { + _openState = OpenState_Closed; + boost::system::error_code error; + _socket.close(error); + } + + virtual void Start() { } + + virtual bool Update() + { + if (_openState == OpenState_Closed) + return false; + +#ifndef TC_SOCKET_USE_IOCP + if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open)) + return true; + + for (; HandleQueue();) + ; +#endif + + return true; + } + + boost::asio::ip::address const& GetRemoteIpAddress() const + { + return _remoteAddress; + } + + uint16 GetRemotePort() const + { + return _remotePort; + } + + template <SocketReadCallback Callback> + void AsyncRead(Callback&& callback) + { + if (!IsOpen()) + return; + + _readBuffer.Normalize(); + _readBuffer.EnsureFreeSpace(); + _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()), + [self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable + { + if (self->ReadHandlerInternal(error, transferredBytes)) + if (callback() == SocketReadCallbackResult::KeepReading) + self->AsyncRead(std::forward<Callback>(callback)); + }); + } + + void QueuePacket(MessageBuffer&& buffer) + { + _writeQueue.push(std::move(buffer)); + +#ifdef TC_SOCKET_USE_IOCP + AsyncProcessQueue(); +#endif + } + + bool IsOpen() const { return _openState == OpenState_Open; } + + void CloseSocket() + { + if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0) + return; + + boost::system::error_code shutdownError; + _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError); + if (shutdownError) + TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(), + shutdownError.value(), shutdownError.message()); + + this->OnClose(); + } + + /// Marks the socket for closing after write buffer becomes empty + void DelayedCloseSocket() + { + if (_openState.fetch_or(OpenState_Closing) != 0) + return; + + if (_writeQueue.empty()) + CloseSocket(); + } + + MessageBuffer& GetReadBuffer() { return _readBuffer; } + + Stream& underlying_stream() + { + return _socket; + } + +protected: + virtual void OnClose() { } + + virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; } + + bool AsyncProcessQueue() + { + if (_isWritingAsync) + return false; + + _isWritingAsync = true; + +#ifdef TC_SOCKET_USE_IOCP + MessageBuffer& buffer = _writeQueue.front(); + _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()), + [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes) + { + self->WriteHandler(error, transferedBytes); + }); +#else + _socket.async_wait(boost::asio::socket_base::wait_type::wait_write, + [self = this->shared_from_this()](boost::system::error_code const& error) + { + self->WriteHandlerWrapper(error); + }); +#endif + + return false; + } + + void SetNoDelay(bool enable) + { + boost::system::error_code err; + _socket.set_option(boost::asio::ip::tcp::no_delay(enable), err); + if (err) + TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for {} - {} ({})", + GetRemoteIpAddress().to_string(), err.value(), err.message()); + } + +private: + bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes) + { + if (error) + { + CloseSocket(); + return false; + } + + _readBuffer.WriteCompleted(transferredBytes); + return IsOpen(); + } + +#ifdef TC_SOCKET_USE_IOCP + + void WriteHandler(boost::system::error_code const& error, std::size_t transferedBytes) + { + if (!error) + { + _isWritingAsync = false; + _writeQueue.front().ReadCompleted(transferedBytes); + if (!_writeQueue.front().GetActiveSize()) + _writeQueue.pop(); + + if (!_writeQueue.empty()) + AsyncProcessQueue(); + else if (_openState == OpenState_Closing) + CloseSocket(); + } + else + CloseSocket(); + } + +#else + + void WriteHandlerWrapper(boost::system::error_code const& /*error*/) + { + _isWritingAsync = false; + HandleQueue(); + } + + bool HandleQueue() + { + if (_writeQueue.empty()) + return false; + + MessageBuffer& queuedMessage = _writeQueue.front(); + + std::size_t bytesToSend = queuedMessage.GetActiveSize(); + + boost::system::error_code error; + std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error); + + if (error) + { + if (error == boost::asio::error::would_block || error == boost::asio::error::try_again) + return AsyncProcessQueue(); + + _writeQueue.pop(); + if (_openState == OpenState_Closing && _writeQueue.empty()) + CloseSocket(); + return false; + } + else if (bytesSent == 0) + { + _writeQueue.pop(); + if (_openState == OpenState_Closing && _writeQueue.empty()) + CloseSocket(); + return false; + } + else if (bytesSent < bytesToSend) // now n > 0 + { + queuedMessage.ReadCompleted(bytesSent); + return AsyncProcessQueue(); + } + + _writeQueue.pop(); + if (_openState == OpenState_Closing && _writeQueue.empty()) + CloseSocket(); + return !_writeQueue.empty(); + } + +#endif + + Stream _socket; + + boost::asio::ip::address _remoteAddress; + uint16 _remotePort = 0; + + MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE); + std::queue<MessageBuffer> _writeQueue; + + // Socket open state "enum" (not enum to enable integral std::atomic api) + static constexpr uint8 OpenState_Open = 0x0; + static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data + static constexpr uint8 OpenState_Closed = 0x2; + + std::atomic<uint8> _openState; + + bool _isWritingAsync = false; +}; +} + +#endif // TRINITYCORE_SOCKET_H diff --git a/src/common/network/SocketMgr.h b/src/common/network/SocketMgr.h new file mode 100644 index 00000000000..07252355308 --- /dev/null +++ b/src/common/network/SocketMgr.h @@ -0,0 +1,146 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SOCKET_MGR_H +#define TRINITYCORE_SOCKET_MGR_H + +#include "AsyncAcceptor.h" +#include "Errors.h" +#include "NetworkThread.h" +#include "Socket.h" +#include <boost/asio/ip/tcp.hpp> +#include <memory> + +namespace Trinity::Net +{ +template<class SocketType> +class SocketMgr +{ +public: + SocketMgr(SocketMgr const&) = delete; + SocketMgr(SocketMgr&&) = delete; + SocketMgr& operator=(SocketMgr const&) = delete; + SocketMgr& operator=(SocketMgr&&) = delete; + + virtual ~SocketMgr() + { + ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction"); + } + + virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount) + { + ASSERT(threadCount > 0); + + std::unique_ptr<AsyncAcceptor> acceptor = nullptr; + try + { + acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port); + } + catch (boost::system::system_error const& err) + { + TC_LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork ({}:{}): {}", bindIp, port, err.what()); + return false; + } + + if (!acceptor->Bind()) + { + TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor"); + return false; + } + + _acceptor = std::move(acceptor); + _threadCount = threadCount; + _threads.reset(CreateThreads()); + + ASSERT(_threads); + + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Start(); + + _acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); }); + + return true; + } + + virtual void StopNetwork() + { + _acceptor->Close(); + + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Stop(); + + Wait(); + + _acceptor = nullptr; + _threads = nullptr; + _threadCount = 0; + } + + void Wait() + { + for (int32 i = 0; i < _threadCount; ++i) + _threads[i].Wait(); + } + + virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex) + { + try + { + std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock)); + newSocket->Start(); + + _threads[threadIndex].AddSocket(newSocket); + } + catch (boost::system::system_error const& err) + { + TC_LOG_WARN("network", "Failed to retrieve client's remote address {}", err.what()); + } + } + + int32 GetNetworkThreadCount() const { return _threadCount; } + + uint32 SelectThreadWithMinConnections() const + { + uint32 min = 0; + + for (int32 i = 1; i < _threadCount; ++i) + if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount()) + min = i; + + return min; + } + + std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept() + { + uint32 threadIndex = SelectThreadWithMinConnections(); + return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex); + } + +protected: + SocketMgr() : _threadCount(0) + { + } + + virtual NetworkThread<SocketType>* CreateThreads() const = 0; + + std::unique_ptr<AsyncAcceptor> _acceptor; + std::unique_ptr<NetworkThread<SocketType>[]> _threads; + int32 _threadCount; +}; +} + +#endif // TRINITYCORE_SOCKET_MGR_H diff --git a/src/common/network/SslStream.h b/src/common/network/SslStream.h new file mode 100644 index 00000000000..2cced44e5ff --- /dev/null +++ b/src/common/network/SslStream.h @@ -0,0 +1,131 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_SSL_STREAM_H +#define TRINITYCORE_SSL_STREAM_H + +#include "SocketConnectionInitializer.h" +#include <boost/asio/ip/tcp.hpp> +#include <boost/asio/ssl/stream.hpp> +#include <boost/system/error_code.hpp> + +namespace Trinity::Net +{ +template <typename SocketImpl> +struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer +{ + explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { } + + void Start() override + { + _socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, + [socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error) + { + std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock()); + if (!socket) + return; + + if (error) + { + TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message()); + socket->CloseSocket(); + return; + } + + if (self->next) + self->next->Start(); + }); + } + +private: + SocketImpl* _socket; +}; + +template<class WrappedStream = IoContextTcpSocket> +class SslStream +{ +public: + explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext) + { + _sslSocket.set_verify_mode(boost::asio::ssl::verify_none); + } + + // adapting tcp::socket api + void close(boost::system::error_code& error) + { + _sslSocket.next_layer().close(error); + } + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + _sslSocket.shutdown(shutdownError); + _sslSocket.next_layer().shutdown(what, shutdownError); + } + + template<typename MutableBufferSequence, typename ReadHandlerType> + void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler) + { + _sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler)); + } + + template<typename ConstBufferSequence, typename WriteHandlerType> + void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler) + { + _sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler)); + } + + template<typename ConstBufferSequence> + std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error) + { + return _sslSocket.write_some(buffers, error); + } + + template<typename WaitHandlerType> + void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler) + { + _sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler)); + } + + template<typename SettableSocketOption> + void set_option(SettableSocketOption const& option, boost::system::error_code& error) + { + _sslSocket.next_layer().set_option(option, error); + } + + IoContextTcpSocket::endpoint_type remote_endpoint() const + { + return _sslSocket.next_layer().remote_endpoint(); + } + + // ssl api + template<typename HandshakeHandlerType> + void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler) + { + _sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler)); + } + +protected: + boost::asio::ssl::stream<WrappedStream> _sslSocket; +}; +} + +#endif // TRINITYCORE_SSL_STREAM_H |