aboutsummaryrefslogtreecommitdiff
path: root/src/common
diff options
context:
space:
mode:
authorShauren <shauren.trinity@gmail.com>2025-04-09 21:02:31 +0200
committerOvahlord <dreadkiller@gmx.de>2025-04-09 21:09:48 +0200
commit48c8c93ec4791002e473e4ea7af2bea9d693be0a (patch)
tree55c8c295698f9d2542ede02d0c237e582908749d /src/common
parent00482e96553ad578dc32591c1b207f769f1d4eb9 (diff)
Core/Network: Move to separate project
(cherry picked from commit 71b681bbf0f5189cd87a6cea66ef51667223f54a)
Diffstat (limited to 'src/common')
-rw-r--r--src/common/Asio/AsioHacksFwd.h6
-rw-r--r--src/common/Asio/Resolver.h69
-rw-r--r--src/common/CMakeLists.txt14
-rw-r--r--src/common/Define.h6
-rw-r--r--src/common/IPLocation/IPLocation.cpp150
-rw-r--r--src/common/IPLocation/IPLocation.h58
-rw-r--r--src/common/Utilities/Util.cpp16
-rw-r--r--src/common/Utilities/Util.h2
-rw-r--r--src/common/network/AsyncAcceptor.h137
-rw-r--r--src/common/network/CMakeLists.txt53
-rw-r--r--src/common/network/ConnectionInitializers/SocketConnectionInitializer.h51
-rw-r--r--src/common/network/Http/BaseHttpSocket.cpp146
-rw-r--r--src/common/network/Http/BaseHttpSocket.h228
-rw-r--r--src/common/network/Http/HttpCommon.h55
-rw-r--r--src/common/network/Http/HttpService.cpp267
-rw-r--r--src/common/network/Http/HttpService.h188
-rw-r--r--src/common/network/Http/HttpSessionState.h35
-rw-r--r--src/common/network/Http/HttpSocket.h53
-rw-r--r--src/common/network/Http/HttpSslSocket.h58
-rw-r--r--src/common/network/IpAddress.h (renamed from src/common/Asio/IpAddress.h)21
-rw-r--r--src/common/network/IpNetwork.cpp (renamed from src/common/Asio/IpNetwork.cpp)4
-rw-r--r--src/common/network/IpNetwork.h (renamed from src/common/Asio/IpNetwork.h)16
-rw-r--r--src/common/network/NetworkThread.h179
-rw-r--r--src/common/network/Resolver.cpp47
-rw-r--r--src/common/network/Resolver.h47
-rw-r--r--src/common/network/Socket.h362
-rw-r--r--src/common/network/SocketMgr.h146
-rw-r--r--src/common/network/SslStream.h131
28 files changed, 2229 insertions, 316 deletions
diff --git a/src/common/Asio/AsioHacksFwd.h b/src/common/Asio/AsioHacksFwd.h
index a8f04b16d81..06f5e531c20 100644
--- a/src/common/Asio/AsioHacksFwd.h
+++ b/src/common/Asio/AsioHacksFwd.h
@@ -60,9 +60,13 @@ namespace Trinity
{
class DeadlineTimer;
class IoContext;
- class Resolver;
class Strand;
}
+
+ namespace Net
+ {
+ class Resolver;
+ }
}
#endif // AsioHacksFwd_h__
diff --git a/src/common/Asio/Resolver.h b/src/common/Asio/Resolver.h
deleted file mode 100644
index 84dedd21bfa..00000000000
--- a/src/common/Asio/Resolver.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation; either version 2 of the License, or (at your
- * option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- * more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef Resolver_h__
-#define Resolver_h__
-
-#include "IoContext.h"
-#include "Optional.h"
-#include <boost/asio/ip/tcp.hpp>
-#include <algorithm>
-#include <string_view>
-#include <vector>
-
-namespace Trinity
-{
- namespace Asio
- {
- /**
- Hack to make it possible to forward declare resolver (one of its template arguments is a typedef to something super long and using nested classes)
- */
- class Resolver
- {
- public:
- explicit Resolver(IoContext& ioContext) : _impl(ioContext) { }
-
- Optional<boost::asio::ip::tcp::endpoint> Resolve(boost::asio::ip::tcp const& protocol, std::string_view host, std::string_view service)
- {
- boost::system::error_code ec;
- boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching;
- boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(protocol, host, service, flagsResolver, ec);
- if (results.begin() == results.end() || ec)
- return {};
-
- return results.begin()->endpoint();
- }
-
- std::vector<boost::asio::ip::tcp::endpoint> ResolveAll(std::string_view host, std::string_view service)
- {
- boost::system::error_code ec;
- boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching;
- boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(host, service, flagsResolver, ec);
- std::vector<boost::asio::ip::tcp::endpoint> result;
- if (!ec)
- std::ranges::transform(results, std::back_inserter(result), [](boost::asio::ip::tcp::resolver::results_type::value_type const& entry) { return entry.endpoint(); });
-
- return result;
- }
-
- private:
- boost::asio::ip::tcp::resolver _impl;
- };
- }
-}
-
-#endif // Resolver_h__
diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt
index b202dc5f9a4..8b8ecc0f471 100644
--- a/src/common/CMakeLists.txt
+++ b/src/common/CMakeLists.txt
@@ -14,7 +14,8 @@ CollectSourceFiles(
# Exclude
${CMAKE_CURRENT_SOURCE_DIR}/Debugging/Windows
${CMAKE_CURRENT_SOURCE_DIR}/Platform
- ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders)
+ ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders
+ ${CMAKE_CURRENT_SOURCE_DIR}/network)
if(WIN32)
CollectSourceFiles(
@@ -27,6 +28,8 @@ if(WIN32)
WINDOWS_PLATFORM_SOURCES)
list(APPEND PRIVATE_SOURCES
${WINDOWS_PLATFORM_SOURCES})
+ unset(WINDOWS_DEBUGGING_SOURCES)
+ unset(WINDOWS_PLATFORM_SOURCES)
endif()
if(USE_COREPCH)
@@ -43,7 +46,8 @@ CollectIncludeDirectories(
${CMAKE_CURRENT_SOURCE_DIR}
PUBLIC_INCLUDES
# Exclude
- ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders)
+ ${CMAKE_CURRENT_SOURCE_DIR}/PrecompiledHeaders
+ ${CMAKE_CURRENT_SOURCE_DIR}/network)
target_include_directories(common
PUBLIC
@@ -95,3 +99,9 @@ endif()
if(USE_COREPCH)
add_cxx_pch(common ${PRIVATE_PCH_HEADER})
endif()
+
+unset(PRIVATE_SOURCES)
+unset(PRIVATE_PCH_HEADER)
+unset(PUBLIC_INCLUDES)
+
+add_subdirectory(network)
diff --git a/src/common/Define.h b/src/common/Define.h
index e4a2333c66d..f918db84314 100644
--- a/src/common/Define.h
+++ b/src/common/Define.h
@@ -111,6 +111,12 @@
# define TC_DATABASE_API TC_API_IMPORT
#endif
+#ifdef TRINITY_API_EXPORT_NETWORK
+# define TC_NETWORK_API TC_API_EXPORT
+#else
+# define TC_NETWORK_API TC_API_IMPORT
+#endif
+
#ifdef TRINITY_API_EXPORT_SHARED
# define TC_SHARED_API TC_API_EXPORT
#else
diff --git a/src/common/IPLocation/IPLocation.cpp b/src/common/IPLocation/IPLocation.cpp
deleted file mode 100644
index 72383bd3554..00000000000
--- a/src/common/IPLocation/IPLocation.cpp
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation; either version 2 of the License, or (at your
- * option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- * more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#include "IPLocation.h"
-#include "BigNumber.h"
-#include "Config.h"
-#include "Errors.h"
-#include "IpAddress.h"
-#include "Log.h"
-#include "Util.h"
-#include <algorithm>
-#include <fstream>
-
-IpLocationStore::IpLocationStore() = default;
-IpLocationStore::~IpLocationStore() = default;
-
-void IpLocationStore::Load()
-{
- _ipLocationStore.clear();
- TC_LOG_INFO("server.loading", "Loading IP Location Database...");
-
- std::string databaseFilePath = sConfigMgr->GetStringDefault("IPLocationFile", "");
- if (databaseFilePath.empty())
- return;
-
- // Check if file exists
- std::ifstream databaseFile(databaseFilePath);
- if (!databaseFile)
- {
- TC_LOG_ERROR("server.loading", "IPLocation: No ip database file exists ({}).", databaseFilePath);
- return;
- }
-
- if (!databaseFile.is_open())
- {
- TC_LOG_ERROR("server.loading", "IPLocation: Ip database file ({}) can not be opened.", databaseFilePath);
- return;
- }
-
- std::string ipFrom;
- std::string ipTo;
- std::string countryCode;
- std::string countryName;
- BigNumber bnParser;
- BigNumber ipv4Max(0xFFFFFFFF);
- BigNumber ipv6MappedMask(0xFFFF);
- ipv6MappedMask <<= 32;
-
- auto parseStringToIPv6 = [&](std::string const& str) -> Optional<std::array<uint8, 16>>
- {
- bnParser.SetDecStr(str);
- if (!bnParser.SetDecStr(str))
- return {};
- // convert ipv4 to ipv6 v4 mapped value
- if (bnParser <= ipv4Max)
- bnParser += ipv6MappedMask;
- return bnParser.ToByteArray<16>(false);
- };
-
- while (databaseFile.good())
- {
- // Read lines
- if (!std::getline(databaseFile, ipFrom, ','))
- break;
- if (!std::getline(databaseFile, ipTo, ','))
- break;
- if (!std::getline(databaseFile, countryCode, ','))
- break;
- if (!std::getline(databaseFile, countryName, '\n'))
- break;
-
- // Remove new lines and return
- std::erase_if(countryName, [](char c) { return c == '\r' || c == '\n'; });
-
- // Remove quotation marks
- std::erase(ipFrom, '"');
- std::erase(ipTo, '"');
- std::erase(countryCode, '"');
- std::erase(countryName, '"');
-
- if (countryCode == "-")
- continue;
-
- // Convert country code to lowercase
- strToLower(countryCode);
-
- Optional<std::array<uint8, 16>> from = parseStringToIPv6(ipFrom);
- if (!from)
- continue;
-
- Optional<std::array<uint8, 16>> to = parseStringToIPv6(ipTo);
- if (!to)
- continue;
-
- _ipLocationStore.emplace_back(*from, *to, std::move(countryCode), std::move(countryName));
- }
-
- std::ranges::sort(_ipLocationStore, {}, &IpLocationRecord::IpFrom);
- ASSERT(std::ranges::is_sorted(_ipLocationStore, [](IpLocationRecord const& a, IpLocationRecord const& b) { return a.IpFrom < b.IpTo; }),
- "Overlapping IP ranges detected in database file");
-
- databaseFile.close();
-
- TC_LOG_INFO("server.loading", ">> Loaded {} ip location entries.", _ipLocationStore.size());
-}
-
-IpLocationRecord const* IpLocationStore::GetLocationRecord(std::string const& ipAddress) const
-{
- boost::system::error_code error;
- boost::asio::ip::address address = Trinity::Net::make_address(ipAddress, error);
- if (error)
- return nullptr;
-
- std::array<uint8, 16> bytes = [&]() -> std::array<uint8, 16>
- {
- if (address.is_v6())
- return address.to_v6().to_bytes();
- if (address.is_v4())
- return Trinity::Net::make_address_v6(Trinity::Net::v4_mapped, address.to_v4()).to_bytes();
- return {};
- }();
- auto itr = std::ranges::upper_bound(_ipLocationStore, bytes, {}, &IpLocationRecord::IpTo);
- if (itr == _ipLocationStore.end())
- return nullptr;
-
- if (bytes < itr->IpFrom)
- return nullptr;
-
- return &(*itr);
-}
-
-IpLocationStore* IpLocationStore::Instance()
-{
- static IpLocationStore instance;
- return &instance;
-}
diff --git a/src/common/IPLocation/IPLocation.h b/src/common/IPLocation/IPLocation.h
deleted file mode 100644
index 5471948586b..00000000000
--- a/src/common/IPLocation/IPLocation.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms of the GNU General Public License as published by the
- * Free Software Foundation; either version 2 of the License, or (at your
- * option) any later version.
- *
- * This program is distributed in the hope that it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- * more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-#ifndef IPLOCATION_H
-#define IPLOCATION_H
-
-#include "Define.h"
-#include <array>
-#include <string>
-#include <vector>
-
-struct IpLocationRecord
-{
- IpLocationRecord() : IpFrom(), IpTo() { }
- IpLocationRecord(std::array<uint8, 16> ipFrom, std::array<uint8, 16> ipTo, std::string&& countryCode, std::string&& countryName)
- : IpFrom(ipFrom), IpTo(ipTo), CountryCode(std::move(countryCode)), CountryName(std::move(countryName)) { }
-
- std::array<uint8, 16> IpFrom;
- std::array<uint8, 16> IpTo;
- std::string CountryCode;
- std::string CountryName;
-};
-
-class TC_COMMON_API IpLocationStore
-{
- public:
- IpLocationStore();
- IpLocationStore(IpLocationStore const&) = delete;
- IpLocationStore(IpLocationStore&&) = delete;
- IpLocationStore& operator=(IpLocationStore const&) = delete;
- IpLocationStore& operator=(IpLocationStore&&) = delete;
- ~IpLocationStore();
- static IpLocationStore* Instance();
-
- void Load();
- IpLocationRecord const* GetLocationRecord(std::string const& ipAddress) const;
-
- private:
- std::vector<IpLocationRecord> _ipLocationStore;
-};
-
-#define sIPLocation IpLocationStore::Instance()
-
-#endif
diff --git a/src/common/Utilities/Util.cpp b/src/common/Utilities/Util.cpp
index 93bc3b853a5..c374bff9cf1 100644
--- a/src/common/Utilities/Util.cpp
+++ b/src/common/Utilities/Util.cpp
@@ -18,7 +18,6 @@
#include "Util.h"
#include "Common.h"
#include "Containers.h"
-#include "IpAddress.h"
#include "StringConvert.h"
#include "StringFormat.h"
#include <boost/core/demangle.hpp>
@@ -29,6 +28,10 @@
#include <cstdarg>
#include <ctime>
+#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
+#include <Windows.h>
+#endif
+
void Trinity::VerifyOsVersion()
{
#if TRINITY_PLATFORM == TRINITY_PLATFORM_WINDOWS
@@ -270,17 +273,6 @@ std::string TimeToHumanReadable(time_t t)
return std::string(buf);
}
-/// Check if the string is a valid ip address representation
-bool IsIPAddress(char const* ipaddress)
-{
- if (!ipaddress)
- return false;
-
- boost::system::error_code error;
- Trinity::Net::make_address(ipaddress, error);
- return !error;
-}
-
/// create PID file
uint32 CreatePIDFile(std::string const& filename)
{
diff --git a/src/common/Utilities/Util.h b/src/common/Utilities/Util.h
index c85fd858806..4907b4edbae 100644
--- a/src/common/Utilities/Util.h
+++ b/src/common/Utilities/Util.h
@@ -390,8 +390,6 @@ TC_COMMON_API bool WriteWinConsole(std::string_view str, bool error = false);
TC_COMMON_API Optional<std::size_t> RemoveCRLF(std::string& str);
-TC_COMMON_API bool IsIPAddress(char const* ipaddress);
-
TC_COMMON_API uint32 CreatePIDFile(std::string const& filename);
TC_COMMON_API uint32 GetPID();
diff --git a/src/common/network/AsyncAcceptor.h b/src/common/network/AsyncAcceptor.h
new file mode 100644
index 00000000000..dd0857c2b38
--- /dev/null
+++ b/src/common/network/AsyncAcceptor.h
@@ -0,0 +1,137 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_ASYNC_ACCEPTOR_H
+#define TRINITYCORE_ASYNC_ACCEPTOR_H
+
+#include "IoContext.h"
+#include "IpAddress.h"
+#include "Log.h"
+#include "Socket.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <boost/asio/ip/v6_only.hpp>
+#include <atomic>
+#include <functional>
+
+#define TRINITY_MAX_LISTEN_CONNECTIONS boost::asio::socket_base::max_listen_connections
+
+namespace Trinity::Net
+{
+template <typename Callable>
+concept AcceptCallback = std::invocable<Callable, IoContextTcpSocket&&, uint32>;
+
+class AsyncAcceptor
+{
+public:
+ AsyncAcceptor(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port) :
+ _acceptor(ioContext), _endpoint(make_address(bindIp), port),
+ _socket(ioContext), _closed(false), _socketFactory([this] { return DefeaultSocketFactory(); })
+ {
+ }
+
+ template <AcceptCallback Callback>
+ void AsyncAccept(Callback&& acceptCallback)
+ {
+ auto [tmpSocket, tmpThreadIndex] = _socketFactory();
+ // TODO: get rid of temporary variables (clang 15 cannot handle variables from structured bindings as lambda captures)
+ IoContextTcpSocket* socket = tmpSocket;
+ uint32 threadIndex = tmpThreadIndex;
+ _acceptor.async_accept(*socket, [this, socket, threadIndex, acceptCallback = std::forward<Callback>(acceptCallback)](boost::system::error_code const& error) mutable
+ {
+ if (!error)
+ {
+ try
+ {
+ socket->non_blocking(true);
+
+ acceptCallback(std::move(*socket), threadIndex);
+ }
+ catch (boost::system::system_error const& err)
+ {
+ TC_LOG_INFO("network", "Failed to initialize client's socket {}", err.what());
+ }
+ }
+
+ if (!_closed)
+ this->AsyncAccept(std::move(acceptCallback));
+ });
+ }
+
+ bool Bind()
+ {
+ boost::system::error_code errorCode;
+ _acceptor.open(_endpoint.protocol(), errorCode);
+ if (errorCode)
+ {
+ TC_LOG_INFO("network", "Failed to open acceptor {}", errorCode.message());
+ return false;
+ }
+
+#if TRINITY_PLATFORM != TRINITY_PLATFORM_WINDOWS
+ _acceptor.set_option(boost::asio::ip::tcp::acceptor::reuse_address(true), errorCode);
+ if (errorCode)
+ {
+ TC_LOG_INFO("network", "Failed to set reuse_address option on acceptor {}", errorCode.message());
+ return false;
+ }
+#endif
+
+ // v6_only is enabled on some *BSD distributions by default
+ // we want to allow both v4 and v6 connections to the same listener
+ if (_endpoint.protocol() == boost::asio::ip::tcp::v6())
+ _acceptor.set_option(boost::asio::ip::v6_only(false));
+
+ _acceptor.bind(_endpoint, errorCode);
+ if (errorCode)
+ {
+ TC_LOG_INFO("network", "Could not bind to {}:{} {}", _endpoint.address().to_string(), _endpoint.port(), errorCode.message());
+ return false;
+ }
+
+ _acceptor.listen(TRINITY_MAX_LISTEN_CONNECTIONS, errorCode);
+ if (errorCode)
+ {
+ TC_LOG_INFO("network", "Failed to start listening on {}:{} {}", _endpoint.address().to_string(), _endpoint.port(), errorCode.message());
+ return false;
+ }
+
+ return true;
+ }
+
+ void Close()
+ {
+ if (_closed.exchange(true))
+ return;
+
+ boost::system::error_code err;
+ _acceptor.close(err);
+ }
+
+ void SetSocketFactory(std::function<std::pair<IoContextTcpSocket*, uint32>()> func) { _socketFactory = std::move(func); }
+
+private:
+ std::pair<IoContextTcpSocket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
+
+ boost::asio::basic_socket_acceptor<boost::asio::ip::tcp, IoContextTcpSocket::executor_type> _acceptor;
+ boost::asio::ip::tcp::endpoint _endpoint;
+ IoContextTcpSocket _socket;
+ std::atomic<bool> _closed;
+ std::function<std::pair<IoContextTcpSocket*, uint32>()> _socketFactory;
+};
+}
+
+#endif // TRINITYCORE_ASYNC_ACCEPTOR_H
diff --git a/src/common/network/CMakeLists.txt b/src/common/network/CMakeLists.txt
new file mode 100644
index 00000000000..70408faed51
--- /dev/null
+++ b/src/common/network/CMakeLists.txt
@@ -0,0 +1,53 @@
+# This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+#
+# This file is free software; as a special exception the author gives
+# unlimited permission to copy and/or distribute it, with or without
+# modifications, as long as this notice is preserved.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY, to the extent permitted by law; without even the
+# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+
+CollectSourceFiles(
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ PRIVATE_SOURCES)
+
+GroupSources(${CMAKE_CURRENT_SOURCE_DIR})
+
+add_library(network
+ ${PRIVATE_SOURCES})
+
+CollectIncludeDirectories(
+ ${CMAKE_CURRENT_SOURCE_DIR}
+ PUBLIC_INCLUDES)
+
+target_include_directories(network
+ PUBLIC
+ ${PUBLIC_INCLUDES}
+ PRIVATE
+ ${CMAKE_CURRENT_BINARY_DIR})
+
+target_link_libraries(network
+ PRIVATE
+ trinity-core-interface
+ PUBLIC
+ common)
+
+set_target_properties(network
+ PROPERTIES
+ COMPILE_WARNING_AS_ERROR ${WITH_WARNINGS_AS_ERRORS}
+ DEFINE_SYMBOL TRINITY_API_EXPORT_NETWORK
+ FOLDER "server"
+ OUTPUT_NAME trinity_network)
+
+if(BUILD_SHARED_LIBS)
+ if(UNIX)
+ install(TARGETS network
+ LIBRARY
+ DESTINATION lib)
+ elseif(WIN32)
+ install(TARGETS network
+ RUNTIME
+ DESTINATION "${CMAKE_INSTALL_PREFIX}")
+ endif()
+endif()
diff --git a/src/common/network/ConnectionInitializers/SocketConnectionInitializer.h b/src/common/network/ConnectionInitializers/SocketConnectionInitializer.h
new file mode 100644
index 00000000000..d3f0bb16dbf
--- /dev/null
+++ b/src/common/network/ConnectionInitializers/SocketConnectionInitializer.h
@@ -0,0 +1,51 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
+#define TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
+
+#include <memory>
+#include <span>
+
+namespace Trinity::Net
+{
+struct SocketConnectionInitializer : public std::enable_shared_from_this<SocketConnectionInitializer>
+{
+ SocketConnectionInitializer() = default;
+
+ SocketConnectionInitializer(SocketConnectionInitializer const&) = delete;
+ SocketConnectionInitializer(SocketConnectionInitializer&&) noexcept = default;
+ SocketConnectionInitializer& operator=(SocketConnectionInitializer const&) = delete;
+ SocketConnectionInitializer& operator=(SocketConnectionInitializer&&) noexcept = default;
+
+ virtual ~SocketConnectionInitializer() = default;
+
+ virtual void Start() = 0;
+
+ std::shared_ptr<SocketConnectionInitializer> next;
+
+ static std::shared_ptr<SocketConnectionInitializer>& SetupChain(std::span<std::shared_ptr<SocketConnectionInitializer>> initializers)
+ {
+ for (std::size_t i = initializers.size(); i > 1; --i)
+ initializers[i - 2]->next.swap(initializers[i - 1]);
+
+ return initializers[0];
+ }
+};
+}
+
+#endif // TRINITYCORE_SOCKET_CONNECTION_INITIALIZER_H
diff --git a/src/common/network/Http/BaseHttpSocket.cpp b/src/common/network/Http/BaseHttpSocket.cpp
new file mode 100644
index 00000000000..33053399c14
--- /dev/null
+++ b/src/common/network/Http/BaseHttpSocket.cpp
@@ -0,0 +1,146 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "BaseHttpSocket.h"
+#include <boost/asio/buffers_iterator.hpp>
+#include <boost/beast/http/serializer.hpp>
+
+namespace Trinity::Net::Http
+{
+using RequestSerializer = boost::beast::http::request_serializer<ResponseBody>;
+using ResponseSerializer = boost::beast::http::response_serializer<ResponseBody>;
+
+bool AbstractSocket::ParseRequest(MessageBuffer& packet, RequestParser& parser)
+{
+ if (!parser.is_done())
+ {
+ // need more data in the payload
+ boost::system::error_code ec = {};
+ std::size_t readDataSize = parser.put(boost::asio::const_buffer(packet.GetReadPointer(), packet.GetActiveSize()), ec);
+ packet.ReadCompleted(readDataSize);
+ }
+
+ return parser.is_done();
+}
+
+std::string AbstractSocket::SerializeRequest(Request const& request)
+{
+ RequestSerializer serializer(request);
+
+ std::string buffer;
+ while (!serializer.is_done())
+ {
+ size_t totalBytes = 0;
+ boost::system::error_code ec = {};
+ serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code const&, ConstBufferSequence const& buffers)
+ {
+ size_t totalBytesInBuffers = boost::asio::buffer_size(buffers);
+
+ buffer.reserve(buffer.size() + totalBytes);
+
+ auto begin = boost::asio::buffers_begin(buffers);
+ auto end = boost::asio::buffers_end(buffers);
+
+ std::copy(begin, end, std::back_inserter(buffer));
+ totalBytes += totalBytesInBuffers;
+ });
+
+ serializer.consume(totalBytes);
+ }
+
+ return buffer;
+}
+
+MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response& response)
+{
+ response.prepare_payload();
+
+ ResponseSerializer serializer(response);
+ bool (*serializerIsDone)(ResponseSerializer&);
+ if (request.method() != boost::beast::http::verb::head)
+ {
+ serializerIsDone = [](ResponseSerializer& s) { return s.is_done(); };
+ }
+ else
+ {
+ serializerIsDone = [](ResponseSerializer& s) { return s.is_header_done(); };
+ serializer.split(true);
+ }
+
+ MessageBuffer buffer;
+ while (!serializerIsDone(serializer))
+ {
+ serializer.limit(buffer.GetRemainingSpace());
+
+ size_t totalBytes = 0;
+ boost::system::error_code ec = {};
+ serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code& currentError, ConstBufferSequence const& buffers)
+ {
+ size_t totalBytesInBuffers = boost::asio::buffer_size(buffers);
+ if (totalBytesInBuffers > buffer.GetRemainingSpace())
+ {
+ currentError = boost::beast::http::error::need_more;
+ return;
+ }
+
+ auto begin = boost::asio::buffers_begin(buffers);
+ auto end = boost::asio::buffers_end(buffers);
+
+ std::copy(begin, end, buffer.GetWritePointer());
+ buffer.WriteCompleted(totalBytesInBuffers);
+ totalBytes += totalBytesInBuffers;
+ });
+
+ serializer.consume(totalBytes);
+
+ if (ec == boost::beast::http::error::need_more)
+ buffer.Resize(buffer.GetBufferSize() + 4096);
+ }
+
+ return buffer;
+}
+
+void AbstractSocket::LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const
+{
+ if (Logger const* logger = sLog->GetEnabledLogger("server.http", LOG_LEVEL_DEBUG))
+ {
+ std::string clientInfo = GetClientInfo();
+ sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_DEBUG, "{} Request {} {} done, status {}", clientInfo,
+ ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()), context.response.result_int());
+ if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
+ {
+ sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Request: {}", clientInfo,
+ CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
+ sLog->OutMessageTo(logger, "server.http", LOG_LEVEL_TRACE, "{} Response: {}", clientInfo,
+ CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
+ }
+ }
+}
+
+std::string AbstractSocket::GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state)
+{
+ std::string info = StringFormat("[{}:{}", address.to_string(), port);
+ if (state)
+ {
+ info.append(", Session Id: ");
+ info.append(boost::uuids::to_string(state->Id));
+ }
+
+ info += ']';
+ return info;
+}
+}
diff --git a/src/common/network/Http/BaseHttpSocket.h b/src/common/network/Http/BaseHttpSocket.h
new file mode 100644
index 00000000000..2d58b6d5aa1
--- /dev/null
+++ b/src/common/network/Http/BaseHttpSocket.h
@@ -0,0 +1,228 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_BASE_HTTP_SOCKET_H
+#define TRINITYCORE_BASE_HTTP_SOCKET_H
+
+#include "AsyncCallbackProcessor.h"
+#include "HttpCommon.h"
+#include "HttpSessionState.h"
+#include "Optional.h"
+#include "Socket.h"
+#include "SocketConnectionInitializer.h"
+#include <boost/beast/core/basic_stream.hpp>
+#include <boost/beast/http/parser.hpp>
+#include <boost/beast/http/string_body.hpp>
+#include <boost/uuid/uuid_io.hpp>
+
+namespace Trinity::Net::Http
+{
+using IoContextHttpSocket = boost::beast::basic_stream<boost::asio::ip::tcp, boost::asio::io_context::executor_type, boost::beast::unlimited_rate_policy>;
+
+namespace Impl
+{
+class BoostBeastSocketWrapper : public IoContextHttpSocket
+{
+public:
+ using IoContextHttpSocket::basic_stream;
+
+ void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
+ {
+ socket().shutdown(what, shutdownError);
+ }
+
+ void close(boost::system::error_code& /*error*/)
+ {
+ IoContextHttpSocket::close();
+ }
+
+ template<typename WaitHandlerType>
+ void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
+ {
+ socket().async_wait(type, std::forward<WaitHandlerType>(handler));
+ }
+
+ IoContextTcpSocket::endpoint_type remote_endpoint() const
+ {
+ return socket().remote_endpoint();
+ }
+};
+}
+
+using RequestParser = boost::beast::http::request_parser<RequestBody>;
+
+class TC_NETWORK_API AbstractSocket
+{
+public:
+ AbstractSocket() = default;
+ AbstractSocket(AbstractSocket const& other) = default;
+ AbstractSocket(AbstractSocket&& other) = default;
+ AbstractSocket& operator=(AbstractSocket const& other) = default;
+ AbstractSocket& operator=(AbstractSocket&& other) = default;
+ virtual ~AbstractSocket() = default;
+
+ static bool ParseRequest(MessageBuffer& packet, RequestParser& parser);
+
+ static std::string SerializeRequest(Request const& request);
+ static MessageBuffer SerializeResponse(Request const& request, Response& response);
+
+ virtual void SendResponse(RequestContext& context) = 0;
+
+ void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const;
+
+ virtual std::string GetClientInfo() const = 0;
+
+ static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state);
+
+ virtual SessionState* GetSessionState() const = 0;
+
+ Optional<boost::uuids::uuid> GetSessionId() const
+ {
+ if (SessionState* state = this->GetSessionState())
+ return state->Id;
+
+ return {};
+ }
+
+ virtual void Start() = 0;
+
+ virtual bool Update() = 0;
+
+ virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0;
+
+ virtual bool IsOpen() const = 0;
+
+ virtual void CloseSocket() = 0;
+};
+
+template <typename SocketImpl>
+struct HttpConnectionInitializer final : SocketConnectionInitializer
+{
+ explicit HttpConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
+
+ void Start() override
+ {
+ _socket->ResetHttpParser();
+
+ if (this->next)
+ this->next->Start();
+ }
+
+private:
+ SocketImpl* _socket;
+};
+
+template<typename Stream>
+class BaseSocket : public Trinity::Net::Socket<Stream>, public AbstractSocket
+{
+ using Base = Trinity::Net::Socket<Stream>;
+
+public:
+ using Base::Base;
+
+ BaseSocket(BaseSocket const& other) = delete;
+ BaseSocket(BaseSocket&& other) = delete;
+ BaseSocket& operator=(BaseSocket const& other) = delete;
+ BaseSocket& operator=(BaseSocket&& other) = delete;
+
+ ~BaseSocket() = default;
+
+ SocketReadCallbackResult ReadHandler() final
+ {
+ MessageBuffer& packet = this->GetReadBuffer();
+ while (packet.GetActiveSize() > 0)
+ {
+ if (!ParseRequest(packet, *_httpParser))
+ {
+ // Couldn't receive the whole data this time.
+ break;
+ }
+
+ if (!HandleMessage(_httpParser->get()))
+ {
+ this->CloseSocket();
+ return SocketReadCallbackResult::Stop;
+ }
+
+ this->ResetHttpParser();
+ }
+
+ return SocketReadCallbackResult::KeepReading;
+ }
+
+ bool HandleMessage(Request& request)
+ {
+ RequestContext context { .request = std::move(request) };
+
+ if (!_state)
+ _state = this->ObtainSessionState(context);
+
+ RequestHandlerResult status = this->RequestHandler(context);
+
+ if (status != RequestHandlerResult::Async)
+ this->SendResponse(context);
+
+ return status != RequestHandlerResult::Error;
+ }
+
+ virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0;
+
+ void SendResponse(RequestContext& context) final
+ {
+ MessageBuffer buffer = SerializeResponse(context.request, context.response);
+
+ this->LogRequestAndResponse(context, buffer);
+
+ this->QueuePacket(std::move(buffer));
+
+ if (!context.response.keep_alive())
+ this->DelayedCloseSocket();
+ }
+
+ void Start() override { return this->Base::Start(); }
+
+ bool Update() override { return this->Base::Update(); }
+
+ boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); }
+
+ bool IsOpen() const final { return this->Base::IsOpen(); }
+
+ void CloseSocket() final { return this->Base::CloseSocket(); }
+
+ std::string GetClientInfo() const override
+ {
+ return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get());
+ }
+
+ SessionState* GetSessionState() const override { return _state.get(); }
+
+ void ResetHttpParser()
+ {
+ this->_httpParser.reset();
+ this->_httpParser.emplace();
+ this->_httpParser->eager(true);
+ }
+
+protected:
+ virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0;
+
+ Optional<RequestParser> _httpParser;
+ std::shared_ptr<SessionState> _state;
+};
+}
+
+#endif // TRINITYCORE_BASE_HTTP_SOCKET_H
diff --git a/src/common/network/Http/HttpCommon.h b/src/common/network/Http/HttpCommon.h
new file mode 100644
index 00000000000..274b59d7536
--- /dev/null
+++ b/src/common/network/Http/HttpCommon.h
@@ -0,0 +1,55 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_HTTP_COMMON_H
+#define TRINITYCORE_HTTP_COMMON_H
+
+#include "Define.h"
+#include <boost/beast/http/message.hpp>
+#include <boost/beast/http/string_body.hpp>
+
+namespace Trinity::Net::Http
+{
+using RequestBody = boost::beast::http::string_body;
+using ResponseBody = boost::beast::http::string_body;
+
+using Request = boost::beast::http::request<RequestBody>;
+using Response = boost::beast::http::response<ResponseBody>;
+
+struct RequestContext
+{
+ Request request;
+ Response response;
+ struct RequestHandler const* handler = nullptr;
+};
+
+TC_NETWORK_API bool CanLogRequestContent(RequestContext const& context);
+TC_NETWORK_API bool CanLogResponseContent(RequestContext const& context);
+
+inline std::string_view ToStdStringView(boost::beast::string_view bsw)
+{
+ return { bsw.data(), bsw.size() };
+}
+
+enum class RequestHandlerResult
+{
+ Handled,
+ Error,
+ Async,
+};
+}
+#endif // TRINITYCORE_HTTP_COMMON_H
diff --git a/src/common/network/Http/HttpService.cpp b/src/common/network/Http/HttpService.cpp
new file mode 100644
index 00000000000..b01e27e296a
--- /dev/null
+++ b/src/common/network/Http/HttpService.cpp
@@ -0,0 +1,267 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "HttpService.h"
+#include "BaseHttpSocket.h"
+#include "CryptoRandom.h"
+#include "Timezone.h"
+#include <boost/beast/version.hpp>
+#include <boost/uuid/string_generator.hpp>
+#include <fmt/chrono.h>
+
+namespace Trinity::Net::Http
+{
+bool CanLogRequestContent(RequestContext const& context)
+{
+ return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogRequestContent);
+}
+
+bool CanLogResponseContent(RequestContext const& context)
+{
+ return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogResponseContent);
+}
+
+RequestHandlerResult DispatcherService::HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context)
+{
+ TC_LOG_DEBUG(_logger, "{} Starting request {} {}", session->GetClientInfo(),
+ ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()));
+
+ std::string_view path = [&]
+ {
+ std::string_view path = ToStdStringView(context.request.target());
+ size_t queryIndex = path.find('?');
+ if (queryIndex != std::string_view::npos)
+ path = path.substr(0, queryIndex);
+ return path;
+ }();
+
+ context.handler = [&]() -> HttpMethodHandlerMap::mapped_type const*
+ {
+ switch (context.request.method())
+ {
+ case boost::beast::http::verb::get:
+ case boost::beast::http::verb::head:
+ {
+ auto itr = _getHandlers.find(path);
+ return itr != _getHandlers.end() ? &itr->second : nullptr;
+ }
+ case boost::beast::http::verb::post:
+ {
+ auto itr = _postHandlers.find(path);
+ return itr != _postHandlers.end() ? &itr->second : nullptr;
+ }
+ default:
+ break;
+ }
+ return nullptr;
+ }();
+
+ SystemTimePoint responseDate = SystemTimePoint::clock::now();
+ context.response.set(boost::beast::http::field::date, StringFormat("{:%a, %d %b %Y %T GMT}", responseDate - Timezone::GetSystemZoneOffsetAt(responseDate)));
+ context.response.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING);
+ context.response.keep_alive(context.request.keep_alive());
+
+ if (!context.handler)
+ return HandlePathNotFound(std::move(session), context);
+
+ return context.handler->Func(std::move(session), context);
+}
+
+RequestHandlerResult DispatcherService::HandleBadRequest(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
+{
+ context.response.result(boost::beast::http::status::bad_request);
+ return RequestHandlerResult::Handled;
+}
+
+RequestHandlerResult DispatcherService::HandleUnauthorized(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
+{
+ context.response.result(boost::beast::http::status::unauthorized);
+ return RequestHandlerResult::Handled;
+}
+
+RequestHandlerResult DispatcherService::HandlePathNotFound(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
+{
+ context.response.result(boost::beast::http::status::not_found);
+ return RequestHandlerResult::Handled;
+}
+
+void DispatcherService::RegisterHandler(boost::beast::http::verb method, std::string_view path,
+ std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler,
+ RequestHandlerFlag flags)
+{
+ HttpMethodHandlerMap& handlerMap = [&]() -> HttpMethodHandlerMap&
+ {
+ switch (method)
+ {
+ case boost::beast::http::verb::get:
+ return _getHandlers;
+ case boost::beast::http::verb::post:
+ return _postHandlers;
+ default:
+ {
+ std::string_view methodString = ToStdStringView(boost::beast::http::to_string(method));
+ ABORT_MSG("Tried to register a handler for unsupported HTTP method " STRING_VIEW_FMT, STRING_VIEW_FMT_ARG(methodString));
+ }
+ }
+ }();
+
+ handlerMap[std::string(path)] = { .Func = std::move(handler), .Flags = flags };
+ TC_LOG_INFO(_logger, "Registered new handler for {} {}", ToStdStringView(boost::beast::http::to_string(method)), path);
+}
+
+void SessionService::InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address)
+{
+ state->RemoteAddress = address;
+
+ // Generate session id
+ {
+ std::unique_lock lock{ _sessionsMutex };
+
+ while (state->Id.is_nil() || _sessions.contains(state->Id))
+ std::copy_n(Trinity::Crypto::GetRandomBytes<16>().begin(), 16, state->Id.begin());
+
+ TC_LOG_DEBUG(_logger, "Client at {} created new session {}", address.to_string(), boost::uuids::to_string(state->Id));
+ _sessions[state->Id] = std::move(state);
+ }
+}
+
+void SessionService::Start(Asio::IoContext& ioContext)
+{
+ _inactiveSessionsKillTimer = std::make_unique<Asio::DeadlineTimer>(ioContext);
+ _inactiveSessionsKillTimer->expires_after(1min);
+ _inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err)
+ {
+ if (err)
+ return;
+
+ KillInactiveSessions();
+ });
+}
+
+void SessionService::Stop()
+{
+ _inactiveSessionsKillTimer = nullptr;
+ {
+ std::unique_lock lock{ _sessionsMutex };
+ _sessions.clear();
+ }
+ {
+ std::unique_lock lock{ _inactiveSessionsMutex };
+ _inactiveSessions.clear();
+ }
+}
+
+std::shared_ptr<SessionState> SessionService::FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address)
+{
+ std::shared_ptr<SessionState> state;
+
+ {
+ std::shared_lock lock{ _sessionsMutex };
+ auto itr = _sessions.find(boost::uuids::string_generator()(id.begin(), id.end()));
+ if (itr == _sessions.end())
+ {
+ TC_LOG_DEBUG(_logger, "Client at {} attempted to use a session {} that was expired", address.to_string(), id);
+ return nullptr; // no session
+ }
+
+ state = itr->second;
+ }
+
+ if (state->RemoteAddress != address)
+ {
+ TC_LOG_ERROR(_logger, "Client at {} attempted to use a session {} that was last accessed from {}, denied access",
+ address.to_string(), id, state->RemoteAddress.to_string());
+ return nullptr;
+ }
+
+ {
+ std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex };
+ _inactiveSessions.erase(state->Id);
+ }
+
+ return state;
+}
+
+void SessionService::MarkSessionInactive(boost::uuids::uuid const& id)
+{
+ bool wasActive = true;
+ {
+ std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex };
+ wasActive = _inactiveSessions.insert(id).second;
+ }
+
+ if (wasActive)
+ {
+ std::shared_lock lock{ _sessionsMutex };
+ auto itr = _sessions.find(id);
+ if (itr != _sessions.end())
+ {
+ itr->second->InactiveTimestamp = TimePoint::clock::now() + Minutes(5);
+ TC_LOG_TRACE(_logger, "Session {} marked as inactive", boost::uuids::to_string(id));
+ }
+ }
+}
+
+void SessionService::KillInactiveSessions()
+{
+ std::set<boost::uuids::uuid> inactiveSessions;
+
+ {
+ std::unique_lock lock{ _inactiveSessionsMutex };
+ std::swap(_inactiveSessions, inactiveSessions);
+ }
+
+ {
+ TimePoint now = TimePoint::clock::now();
+ std::size_t inactiveSessionsCount = inactiveSessions.size();
+
+ std::unique_lock lock{ _sessionsMutex };
+ for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); )
+ {
+ auto sessionItr = _sessions.find(*itr);
+ if (sessionItr == _sessions.end() || sessionItr->second->InactiveTimestamp < now)
+ {
+ _sessions.erase(sessionItr);
+ itr = inactiveSessions.erase(itr);
+ }
+ else
+ ++itr;
+ }
+
+ TC_LOG_DEBUG(_logger, "Killed {} inactive sessions", inactiveSessionsCount - inactiveSessions.size());
+ }
+
+ {
+ // restore sessions not killed to inactive queue
+ std::unique_lock lock{ _inactiveSessionsMutex };
+ for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); )
+ {
+ auto node = inactiveSessions.extract(itr++);
+ _inactiveSessions.insert(std::move(node));
+ }
+ }
+
+ _inactiveSessionsKillTimer->expires_after(1min);
+ _inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err)
+ {
+ if (err)
+ return;
+
+ KillInactiveSessions();
+ });
+}
+}
diff --git a/src/common/network/Http/HttpService.h b/src/common/network/Http/HttpService.h
new file mode 100644
index 00000000000..1549893576f
--- /dev/null
+++ b/src/common/network/Http/HttpService.h
@@ -0,0 +1,188 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_HTTP_SERVICE_H
+#define TRINITYCORE_HTTP_SERVICE_H
+
+#include "AsioHacksFwd.h"
+#include "Concepts.h"
+#include "Define.h"
+#include "EnumFlag.h"
+#include "HttpCommon.h"
+#include "HttpSessionState.h"
+#include "Optional.h"
+#include "SocketMgr.h"
+#include <boost/uuid/uuid.hpp>
+#include <functional>
+#include <map>
+#include <set>
+#include <shared_mutex>
+
+namespace Trinity::Net::Http
+{
+class AbstractSocket;
+
+enum class RequestHandlerFlag
+{
+ None = 0x0,
+ DoNotLogRequestContent = 0x1,
+ DoNotLogResponseContent = 0x2,
+};
+
+DEFINE_ENUM_FLAG(RequestHandlerFlag);
+
+struct RequestHandler
+{
+ std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> Func;
+ EnumFlag<RequestHandlerFlag> Flags = RequestHandlerFlag::None;
+};
+
+class TC_NETWORK_API DispatcherService
+{
+public:
+ explicit DispatcherService(std::string_view loggerSuffix) : _logger("server.http.dispatcher.")
+ {
+ _logger.append(loggerSuffix);
+ }
+
+ RequestHandlerResult HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context);
+
+ static RequestHandlerResult HandleBadRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context);
+ static RequestHandlerResult HandleUnauthorized(std::shared_ptr<AbstractSocket> session, RequestContext& context);
+ static RequestHandlerResult HandlePathNotFound(std::shared_ptr<AbstractSocket> session, RequestContext& context);
+
+protected:
+ void RegisterHandler(boost::beast::http::verb method, std::string_view path,
+ std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler,
+ RequestHandlerFlag flags = RequestHandlerFlag::None);
+
+private:
+ using HttpMethodHandlerMap = std::map<std::string, RequestHandler, std::less<>>;
+
+ HttpMethodHandlerMap _getHandlers;
+ HttpMethodHandlerMap _postHandlers;
+
+ std::string _logger;
+};
+
+class TC_NETWORK_API SessionService
+{
+public:
+ explicit SessionService(std::string_view loggerSuffix) : _logger("server.http.session.")
+ {
+ _logger.append(loggerSuffix);
+ }
+
+ void Start(Asio::IoContext& ioContext);
+ void Stop();
+
+ std::shared_ptr<SessionState> FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address);
+ void MarkSessionInactive(boost::uuids::uuid const& id);
+
+protected:
+ void InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address);
+
+ void KillInactiveSessions();
+
+private:
+ std::shared_mutex _sessionsMutex;
+ std::map<boost::uuids::uuid, std::shared_ptr<SessionState>> _sessions;
+
+ std::mutex _inactiveSessionsMutex;
+ std::set<boost::uuids::uuid> _inactiveSessions;
+ std::unique_ptr<Asio::DeadlineTimer> _inactiveSessionsKillTimer;
+
+ std::string _logger;
+};
+
+template<typename Callable, typename SessionImpl>
+concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>;
+
+template<typename SessionImpl>
+class HttpService : public SocketMgr<SessionImpl>, public DispatcherService, public SessionService
+{
+public:
+ HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.")
+ {
+ _logger.append(loggerSuffix);
+ }
+
+ bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override
+ {
+ if (!SocketMgr<SessionImpl>::StartNetwork(ioContext, bindIp, port, threadCount))
+ return false;
+
+ SessionService::Start(ioContext);
+ return true;
+ }
+
+ void StopNetwork() override
+ {
+ SessionService::Stop();
+ SocketMgr<SessionImpl>::StopNetwork();
+ }
+
+ // http handling
+ using DispatcherService::RegisterHandler;
+
+ template<HttpRequestHandler<SessionImpl> Callable>
+ void RegisterHandler(boost::beast::http::verb method, std::string_view path, Callable handler, RequestHandlerFlag flags = RequestHandlerFlag::None)
+ {
+ this->DispatcherService::RegisterHandler(method, path, [handler = std::move(handler)](std::shared_ptr<AbstractSocket> session, RequestContext& context) -> RequestHandlerResult
+ {
+ return handler(std::static_pointer_cast<SessionImpl>(std::move(session)), context);
+ }, flags);
+ }
+
+ // session tracking
+ virtual std::shared_ptr<SessionState> CreateNewSessionState(boost::asio::ip::address const& address)
+ {
+ std::shared_ptr<SessionState> state = std::make_shared<SessionState>();
+ InitAndStoreSessionState(state, address);
+ return state;
+ }
+
+protected:
+ class Thread : public NetworkThread<SessionImpl>
+ {
+ protected:
+ void SocketRemoved(std::shared_ptr<SessionImpl> const& session) override
+ {
+ if (Optional<boost::uuids::uuid> id = session->GetSessionId())
+ _service->MarkSessionInactive(*id);
+ }
+
+ private:
+ friend HttpService;
+
+ SessionService* _service;
+ };
+
+ NetworkThread<SessionImpl>* CreateThreads() const override
+ {
+ Thread* threads = new Thread[this->GetNetworkThreadCount()];
+ for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i)
+ threads[i]._service = const_cast<HttpService*>(this);
+ return threads;
+ }
+
+ Asio::IoContext* _ioContext;
+ std::string _logger;
+};
+}
+
+#endif // TRINITYCORE_HTTP_SERVICE_H
diff --git a/src/common/network/Http/HttpSessionState.h b/src/common/network/Http/HttpSessionState.h
new file mode 100644
index 00000000000..3012a2efc65
--- /dev/null
+++ b/src/common/network/Http/HttpSessionState.h
@@ -0,0 +1,35 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_HTTP_SESSION_STATE_H
+#define TRINITYCORE_HTTP_SESSION_STATE_H
+
+#include "Duration.h"
+#include <boost/asio/ip/address.hpp>
+#include <boost/uuid/uuid.hpp>
+
+namespace Trinity::Net::Http
+{
+struct SessionState
+{
+ boost::uuids::uuid Id = { };
+ boost::asio::ip::address RemoteAddress;
+ TimePoint InactiveTimestamp = TimePoint::max();
+};
+}
+
+#endif // TRINITYCORE_HTTP_SESSION_STATE_H
diff --git a/src/common/network/Http/HttpSocket.h b/src/common/network/Http/HttpSocket.h
new file mode 100644
index 00000000000..2cfc3ba8ed8
--- /dev/null
+++ b/src/common/network/Http/HttpSocket.h
@@ -0,0 +1,53 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_HTTP_SOCKET_H
+#define TRINITYCORE_HTTP_SOCKET_H
+
+#include "BaseHttpSocket.h"
+#include <array>
+
+namespace Trinity::Net::Http
+{
+class Socket : public BaseSocket<Impl::BoostBeastSocketWrapper>
+{
+ using SocketBase = BaseSocket<Impl::BoostBeastSocketWrapper>;
+
+public:
+ using SocketBase::SocketBase;
+
+ Socket(Socket const& other) = delete;
+ Socket(Socket&& other) = delete;
+ Socket& operator=(Socket const& other) = delete;
+ Socket& operator=(Socket&& other) = delete;
+
+ ~Socket() = default;
+
+ void Start() override
+ {
+ std::array<std::shared_ptr<SocketConnectionInitializer>, 2> initializers =
+ { {
+ std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
+ std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
+ } };
+
+ SocketConnectionInitializer::SetupChain(initializers)->Start();
+ }
+};
+}
+
+#endif // TRINITYCORE_HTTP_SOCKET_H
diff --git a/src/common/network/Http/HttpSslSocket.h b/src/common/network/Http/HttpSslSocket.h
new file mode 100644
index 00000000000..c789cbfefaf
--- /dev/null
+++ b/src/common/network/Http/HttpSslSocket.h
@@ -0,0 +1,58 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_HTTP_SSL_SOCKET_H
+#define TRINITYCORE_HTTP_SSL_SOCKET_H
+
+#include "BaseHttpSocket.h"
+#include "SslStream.h"
+
+namespace Trinity::Net::Http
+{
+class SslSocket : public BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>
+{
+ using SocketBase = BaseSocket<SslStream<Impl::BoostBeastSocketWrapper>>;
+
+public:
+ explicit SslSocket(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext)
+ : SocketBase(std::move(socket), sslContext) { }
+
+ explicit SslSocket(boost::asio::io_context& context, boost::asio::ssl::context& sslContext)
+ : SocketBase(context, sslContext) { }
+
+ SslSocket(SslSocket const& other) = delete;
+ SslSocket(SslSocket&& other) = delete;
+ SslSocket& operator=(SslSocket const& other) = delete;
+ SslSocket& operator=(SslSocket&& other) = delete;
+
+ ~SslSocket() = default;
+
+ void Start() override
+ {
+ std::array<std::shared_ptr<SocketConnectionInitializer>, 3> initializers =
+ { {
+ std::make_shared<SslHandshakeConnectionInitializer<SocketBase>>(this),
+ std::make_shared<HttpConnectionInitializer<SocketBase>>(this),
+ std::make_shared<ReadConnectionInitializer<SocketBase>>(this),
+ } };
+
+ SocketConnectionInitializer::SetupChain(initializers)->Start();
+ }
+};
+}
+
+#endif // TRINITYCORE_HTTP_SSL_SOCKET_H
diff --git a/src/common/Asio/IpAddress.h b/src/common/network/IpAddress.h
index 7d85b0028ac..b856d7f6340 100644
--- a/src/common/Asio/IpAddress.h
+++ b/src/common/network/IpAddress.h
@@ -15,22 +15,19 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef IpAddress_h__
-#define IpAddress_h__
+#ifndef TRINITYCORE_IP_ADDRESS_H
+#define TRINITYCORE_IP_ADDRESS_H
#include "Define.h"
#include <boost/asio/ip/address.hpp>
-namespace Trinity
+namespace Trinity::Net
{
- namespace Net
- {
- using boost::asio::ip::make_address;
- using boost::asio::ip::make_address_v4;
- using boost::asio::ip::make_address_v6;
- using boost::asio::ip::v4_mapped_t::v4_mapped;
- inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_uint(); }
- }
+ using boost::asio::ip::make_address;
+ using boost::asio::ip::make_address_v4;
+ using boost::asio::ip::make_address_v6;
+ using boost::asio::ip::v4_mapped_t::v4_mapped;
+ inline uint32 address_to_uint(boost::asio::ip::address_v4 const& address) { return address.to_uint(); }
}
-#endif // IpAddress_h__
+#endif // TRINITYCORE_IP_ADDRESS_H
diff --git a/src/common/Asio/IpNetwork.cpp b/src/common/network/IpNetwork.cpp
index 85f176d21e4..b6235a5a947 100644
--- a/src/common/Asio/IpNetwork.cpp
+++ b/src/common/network/IpNetwork.cpp
@@ -33,7 +33,7 @@ bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress)
{
if (clientAddress.is_v4())
{
- return std::any_of(LocalV4Networks.begin(), LocalV4Networks.end(), [clientAddressV4 = clientAddress.to_v4()](boost::asio::ip::network_v4 const& network)
+ return std::ranges::any_of(LocalV4Networks, [clientAddressV4 = clientAddress.to_v4()](boost::asio::ip::network_v4 const& network)
{
return IsInNetwork(network, clientAddressV4);
});
@@ -41,7 +41,7 @@ bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress)
if (clientAddress.is_v6())
{
- return std::any_of(LocalV6Networks.begin(), LocalV6Networks.end(), [clientAddressV6 = clientAddress.to_v6()](boost::asio::ip::network_v6 const& network)
+ return std::ranges::any_of(LocalV6Networks, [clientAddressV6 = clientAddress.to_v6()](boost::asio::ip::network_v6 const& network)
{
return IsInNetwork(network, clientAddressV6);
});
diff --git a/src/common/Asio/IpNetwork.h b/src/common/network/IpNetwork.h
index 08db7d49f84..c05c3076598 100644
--- a/src/common/Asio/IpNetwork.h
+++ b/src/common/network/IpNetwork.h
@@ -15,8 +15,8 @@
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-#ifndef IpNetwork_h__
-#define IpNetwork_h__
+#ifndef TRINITYCORE_IP_NETWORK_H
+#define TRINITYCORE_IP_NETWORK_H
#include "AsioHacksFwd.h"
#include "Define.h"
@@ -25,15 +25,15 @@
namespace Trinity::Net
{
-TC_COMMON_API bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress);
+TC_NETWORK_API bool IsInLocalNetwork(boost::asio::ip::address const& clientAddress);
-TC_COMMON_API bool IsInNetwork(boost::asio::ip::network_v4 const& network, boost::asio::ip::address_v4 const& clientAddress);
+TC_NETWORK_API bool IsInNetwork(boost::asio::ip::network_v4 const& network, boost::asio::ip::address_v4 const& clientAddress);
-TC_COMMON_API bool IsInNetwork(boost::asio::ip::network_v6 const& network, boost::asio::ip::address_v6 const& clientAddress);
+TC_NETWORK_API bool IsInNetwork(boost::asio::ip::network_v6 const& network, boost::asio::ip::address_v6 const& clientAddress);
-TC_COMMON_API Optional<std::size_t> SelectAddressForClient(boost::asio::ip::address const& clientAddress, std::span<boost::asio::ip::address const> const& addresses);
+TC_NETWORK_API Optional<std::size_t> SelectAddressForClient(boost::asio::ip::address const& clientAddress, std::span<boost::asio::ip::address const> const& addresses);
-TC_COMMON_API void ScanLocalNetworks();
+TC_NETWORK_API void ScanLocalNetworks();
}
-#endif // IpNetwork_h__
+#endif // TRINITYCORE_IP_NETWORK_H
diff --git a/src/common/network/NetworkThread.h b/src/common/network/NetworkThread.h
new file mode 100644
index 00000000000..d16da442149
--- /dev/null
+++ b/src/common/network/NetworkThread.h
@@ -0,0 +1,179 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_NETWORK_THREAD_H
+#define TRINITYCORE_NETWORK_THREAD_H
+
+#include "Containers.h"
+#include "DeadlineTimer.h"
+#include "Define.h"
+#include "Errors.h"
+#include "IoContext.h"
+#include "Log.h"
+#include "Socket.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <thread>
+
+namespace Trinity::Net
+{
+template<class SocketType>
+class NetworkThread
+{
+public:
+ NetworkThread() : _connections(0), _stopped(false), _thread(nullptr), _ioContext(1),
+ _acceptSocket(_ioContext), _updateTimer(_ioContext)
+ {
+ }
+
+ NetworkThread(NetworkThread const&) = delete;
+ NetworkThread(NetworkThread&&) = delete;
+ NetworkThread& operator=(NetworkThread const&) = delete;
+ NetworkThread& operator=(NetworkThread&&) = delete;
+
+ virtual ~NetworkThread()
+ {
+ Stop();
+ if (_thread)
+ Wait();
+ }
+
+ void Stop()
+ {
+ _stopped = true;
+ _ioContext.stop();
+ }
+
+ bool Start()
+ {
+ if (_thread)
+ return false;
+
+ _thread = std::make_unique<std::thread>(&NetworkThread::Run, this);
+ return true;
+ }
+
+ void Wait()
+ {
+ ASSERT(_thread);
+
+ _thread->join();
+ _thread = nullptr;
+ }
+
+ int32 GetConnectionCount() const
+ {
+ return _connections;
+ }
+
+ void AddSocket(std::shared_ptr<SocketType> sock)
+ {
+ std::lock_guard<std::mutex> lock(_newSocketsLock);
+
+ ++_connections;
+ SocketAdded(_newSockets.emplace_back(std::move(sock)));
+ }
+
+ Trinity::Net::IoContextTcpSocket* GetSocketForAccept() { return &_acceptSocket; }
+
+protected:
+ virtual void SocketAdded(std::shared_ptr<SocketType> const& /*sock*/) { }
+ virtual void SocketRemoved(std::shared_ptr<SocketType> const& /*sock*/) { }
+
+ void AddNewSockets()
+ {
+ std::lock_guard<std::mutex> lock(_newSocketsLock);
+
+ if (_newSockets.empty())
+ return;
+
+ for (std::shared_ptr<SocketType>& sock : _newSockets)
+ {
+ if (!sock->IsOpen())
+ {
+ SocketRemoved(sock);
+ --_connections;
+ }
+ else
+ _sockets.emplace_back(std::move(sock));
+ }
+
+ _newSockets.clear();
+ }
+
+ void Run()
+ {
+ TC_LOG_DEBUG("misc", "Network Thread Starting");
+
+ _updateTimer.expires_after(1ms);
+ _updateTimer.async_wait([this](boost::system::error_code const&) { Update(); });
+ _ioContext.run();
+
+ TC_LOG_DEBUG("misc", "Network Thread exits");
+ _newSockets.clear();
+ _sockets.clear();
+ }
+
+ void Update()
+ {
+ if (_stopped)
+ return;
+
+ _updateTimer.expires_after(1ms);
+ _updateTimer.async_wait([this](boost::system::error_code const&) { Update(); });
+
+ AddNewSockets();
+
+ Trinity::Containers::EraseIf(_sockets, [this](std::shared_ptr<SocketType> const& sock)
+ {
+ if (!sock->Update())
+ {
+ if (sock->IsOpen())
+ sock->CloseSocket();
+
+ this->SocketRemoved(sock);
+
+ --this->_connections;
+ return true;
+ }
+
+ return false;
+ });
+ }
+
+private:
+ typedef std::vector<std::shared_ptr<SocketType>> SocketContainer;
+
+ std::atomic<int32> _connections;
+ std::atomic<bool> _stopped;
+
+ std::unique_ptr<std::thread> _thread;
+
+ SocketContainer _sockets;
+
+ std::mutex _newSocketsLock;
+ SocketContainer _newSockets;
+
+ Trinity::Asio::IoContext _ioContext;
+ Trinity::Net::IoContextTcpSocket _acceptSocket;
+ Trinity::Asio::DeadlineTimer _updateTimer;
+};
+}
+
+#endif // TRINITYCORE_NETWORK_THREAD_H
diff --git a/src/common/network/Resolver.cpp b/src/common/network/Resolver.cpp
new file mode 100644
index 00000000000..9bcfcf28e78
--- /dev/null
+++ b/src/common/network/Resolver.cpp
@@ -0,0 +1,47 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "Resolver.h"
+#include <algorithm>
+
+Optional<boost::asio::ip::tcp::endpoint> Trinity::Net::Resolver::Resolve(boost::asio::ip::tcp const& protocol, std::string_view host, std::string_view service)
+{
+ boost::system::error_code ec;
+ boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching;
+ boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(protocol, host, service, flagsResolver, ec);
+ Optional<boost::asio::ip::tcp::endpoint> result;
+ if (!ec)
+ if (auto itr = results.begin(); itr != results.end())
+ result.emplace(itr->endpoint());
+
+ return result;
+}
+
+std::vector<boost::asio::ip::tcp::endpoint> Trinity::Net::Resolver::ResolveAll(std::string_view host, std::string_view service)
+{
+ boost::system::error_code ec;
+ boost::asio::ip::resolver_base::flags flagsResolver = boost::asio::ip::resolver_base::all_matching;
+ boost::asio::ip::tcp::resolver::results_type results = _impl.resolve(host, service, flagsResolver, ec);
+ std::vector<boost::asio::ip::tcp::endpoint> result;
+ if (!ec)
+ {
+ result.resize(results.size());
+ std::ranges::transform(results, result.begin(), [](boost::asio::ip::tcp::resolver::results_type::value_type const& entry) { return entry.endpoint(); });
+ }
+
+ return result;
+}
diff --git a/src/common/network/Resolver.h b/src/common/network/Resolver.h
new file mode 100644
index 00000000000..c7d24658aa5
--- /dev/null
+++ b/src/common/network/Resolver.h
@@ -0,0 +1,47 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_RESOLVER_H
+#define TRINITYCORE_RESOLVER_H
+
+#include "Define.h"
+#include "IoContext.h"
+#include "Optional.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <string_view>
+#include <vector>
+
+namespace Trinity::Net
+{
+/**
+ Hack to make it possible to forward declare resolver (one of its template arguments is a typedef to something super long and using nested classes)
+*/
+class TC_NETWORK_API Resolver
+{
+public:
+ explicit Resolver(Asio::IoContext& ioContext) : _impl(ioContext) { }
+
+ Optional<boost::asio::ip::tcp::endpoint> Resolve(boost::asio::ip::tcp const& protocol, std::string_view host, std::string_view service);
+
+ std::vector<boost::asio::ip::tcp::endpoint> ResolveAll(std::string_view host, std::string_view service);
+
+private:
+ boost::asio::ip::tcp::resolver _impl;
+};
+}
+
+#endif // TRINITYCORE_RESOLVER_H
diff --git a/src/common/network/Socket.h b/src/common/network/Socket.h
new file mode 100644
index 00000000000..565cc175318
--- /dev/null
+++ b/src/common/network/Socket.h
@@ -0,0 +1,362 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_SOCKET_H
+#define TRINITYCORE_SOCKET_H
+
+#include "Concepts.h"
+#include "Log.h"
+#include "MessageBuffer.h"
+#include "SocketConnectionInitializer.h"
+#include <boost/asio/io_context.hpp>
+#include <boost/asio/ip/tcp.hpp>
+#include <atomic>
+#include <memory>
+#include <queue>
+#include <type_traits>
+
+#define READ_BLOCK_SIZE 4096
+#ifdef BOOST_ASIO_HAS_IOCP
+#define TC_SOCKET_USE_IOCP
+#endif
+
+namespace Trinity::Net
+{
+using IoContextTcpSocket = boost::asio::basic_stream_socket<boost::asio::ip::tcp, boost::asio::io_context::executor_type>;
+
+enum class SocketReadCallbackResult
+{
+ KeepReading,
+ Stop
+};
+
+template <typename Callable>
+concept SocketReadCallback = Trinity::invocable_r<Callable, SocketReadCallbackResult>;
+
+template <typename SocketType>
+struct InvokeReadHandlerCallback
+{
+ SocketReadCallbackResult operator()() const
+ {
+ return this->Socket->ReadHandler();
+ }
+
+ SocketType* Socket;
+};
+
+template <typename SocketType>
+struct ReadConnectionInitializer final : SocketConnectionInitializer
+{
+ explicit ReadConnectionInitializer(SocketType* socket) : ReadCallback({ .Socket = socket }) { }
+
+ void Start() override
+ {
+ ReadCallback.Socket->AsyncRead(std::move(ReadCallback));
+
+ if (this->next)
+ this->next->Start();
+ }
+
+ InvokeReadHandlerCallback<SocketType> ReadCallback;
+};
+
+/**
+ @class Socket
+
+ Base async socket implementation
+
+ @tparam Stream stream type used for operations on socket
+ Stream must implement the following methods:
+
+ void close(boost::system::error_code& error);
+
+ void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
+
+ template<typename MutableBufferSequence, typename ReadHandlerType>
+ void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler);
+
+ template<typename ConstBufferSequence, typename WriteHandlerType>
+ void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler);
+
+ template<typename ConstBufferSequence>
+ std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error);
+
+ template<typename WaitHandlerType>
+ void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler);
+
+ template<typename SettableSocketOption>
+ void set_option(SettableSocketOption const& option, boost::system::error_code& error);
+
+ tcp::socket::endpoint_type remote_endpoint() const;
+*/
+template<class Stream = IoContextTcpSocket>
+class Socket : public std::enable_shared_from_this<Socket<Stream>>
+{
+public:
+ template<typename... Args>
+ explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
+ _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
+ {
+ }
+
+ template<typename... Args>
+ explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward<Args>(args)...), _openState(OpenState_Closed)
+ {
+ }
+
+ Socket(Socket const& other) = delete;
+ Socket(Socket&& other) = delete;
+ Socket& operator=(Socket const& other) = delete;
+ Socket& operator=(Socket&& other) = delete;
+
+ virtual ~Socket()
+ {
+ _openState = OpenState_Closed;
+ boost::system::error_code error;
+ _socket.close(error);
+ }
+
+ virtual void Start() { }
+
+ virtual bool Update()
+ {
+ if (_openState == OpenState_Closed)
+ return false;
+
+#ifndef TC_SOCKET_USE_IOCP
+ if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open))
+ return true;
+
+ for (; HandleQueue();)
+ ;
+#endif
+
+ return true;
+ }
+
+ boost::asio::ip::address const& GetRemoteIpAddress() const
+ {
+ return _remoteAddress;
+ }
+
+ uint16 GetRemotePort() const
+ {
+ return _remotePort;
+ }
+
+ template <SocketReadCallback Callback>
+ void AsyncRead(Callback&& callback)
+ {
+ if (!IsOpen())
+ return;
+
+ _readBuffer.Normalize();
+ _readBuffer.EnsureFreeSpace();
+ _socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
+ [self = this->shared_from_this(), callback = std::forward<Callback>(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
+ {
+ if (self->ReadHandlerInternal(error, transferredBytes))
+ if (callback() == SocketReadCallbackResult::KeepReading)
+ self->AsyncRead(std::forward<Callback>(callback));
+ });
+ }
+
+ void QueuePacket(MessageBuffer&& buffer)
+ {
+ _writeQueue.push(std::move(buffer));
+
+#ifdef TC_SOCKET_USE_IOCP
+ AsyncProcessQueue();
+#endif
+ }
+
+ bool IsOpen() const { return _openState == OpenState_Open; }
+
+ void CloseSocket()
+ {
+ if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
+ return;
+
+ boost::system::error_code shutdownError;
+ _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
+ if (shutdownError)
+ TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress().to_string(),
+ shutdownError.value(), shutdownError.message());
+
+ this->OnClose();
+ }
+
+ /// Marks the socket for closing after write buffer becomes empty
+ void DelayedCloseSocket()
+ {
+ if (_openState.fetch_or(OpenState_Closing) != 0)
+ return;
+
+ if (_writeQueue.empty())
+ CloseSocket();
+ }
+
+ MessageBuffer& GetReadBuffer() { return _readBuffer; }
+
+ Stream& underlying_stream()
+ {
+ return _socket;
+ }
+
+protected:
+ virtual void OnClose() { }
+
+ virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; }
+
+ bool AsyncProcessQueue()
+ {
+ if (_isWritingAsync)
+ return false;
+
+ _isWritingAsync = true;
+
+#ifdef TC_SOCKET_USE_IOCP
+ MessageBuffer& buffer = _writeQueue.front();
+ _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()),
+ [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes)
+ {
+ self->WriteHandler(error, transferedBytes);
+ });
+#else
+ _socket.async_wait(boost::asio::socket_base::wait_type::wait_write,
+ [self = this->shared_from_this()](boost::system::error_code const& error)
+ {
+ self->WriteHandlerWrapper(error);
+ });
+#endif
+
+ return false;
+ }
+
+ void SetNoDelay(bool enable)
+ {
+ boost::system::error_code err;
+ _socket.set_option(boost::asio::ip::tcp::no_delay(enable), err);
+ if (err)
+ TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for {} - {} ({})",
+ GetRemoteIpAddress().to_string(), err.value(), err.message());
+ }
+
+private:
+ bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
+ {
+ if (error)
+ {
+ CloseSocket();
+ return false;
+ }
+
+ _readBuffer.WriteCompleted(transferredBytes);
+ return IsOpen();
+ }
+
+#ifdef TC_SOCKET_USE_IOCP
+
+ void WriteHandler(boost::system::error_code const& error, std::size_t transferedBytes)
+ {
+ if (!error)
+ {
+ _isWritingAsync = false;
+ _writeQueue.front().ReadCompleted(transferedBytes);
+ if (!_writeQueue.front().GetActiveSize())
+ _writeQueue.pop();
+
+ if (!_writeQueue.empty())
+ AsyncProcessQueue();
+ else if (_openState == OpenState_Closing)
+ CloseSocket();
+ }
+ else
+ CloseSocket();
+ }
+
+#else
+
+ void WriteHandlerWrapper(boost::system::error_code const& /*error*/)
+ {
+ _isWritingAsync = false;
+ HandleQueue();
+ }
+
+ bool HandleQueue()
+ {
+ if (_writeQueue.empty())
+ return false;
+
+ MessageBuffer& queuedMessage = _writeQueue.front();
+
+ std::size_t bytesToSend = queuedMessage.GetActiveSize();
+
+ boost::system::error_code error;
+ std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);
+
+ if (error)
+ {
+ if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
+ return AsyncProcessQueue();
+
+ _writeQueue.pop();
+ if (_openState == OpenState_Closing && _writeQueue.empty())
+ CloseSocket();
+ return false;
+ }
+ else if (bytesSent == 0)
+ {
+ _writeQueue.pop();
+ if (_openState == OpenState_Closing && _writeQueue.empty())
+ CloseSocket();
+ return false;
+ }
+ else if (bytesSent < bytesToSend) // now n > 0
+ {
+ queuedMessage.ReadCompleted(bytesSent);
+ return AsyncProcessQueue();
+ }
+
+ _writeQueue.pop();
+ if (_openState == OpenState_Closing && _writeQueue.empty())
+ CloseSocket();
+ return !_writeQueue.empty();
+ }
+
+#endif
+
+ Stream _socket;
+
+ boost::asio::ip::address _remoteAddress;
+ uint16 _remotePort = 0;
+
+ MessageBuffer _readBuffer = MessageBuffer(READ_BLOCK_SIZE);
+ std::queue<MessageBuffer> _writeQueue;
+
+ // Socket open state "enum" (not enum to enable integral std::atomic api)
+ static constexpr uint8 OpenState_Open = 0x0;
+ static constexpr uint8 OpenState_Closing = 0x1; ///< Transition to Closed state after sending all queued data
+ static constexpr uint8 OpenState_Closed = 0x2;
+
+ std::atomic<uint8> _openState;
+
+ bool _isWritingAsync = false;
+};
+}
+
+#endif // TRINITYCORE_SOCKET_H
diff --git a/src/common/network/SocketMgr.h b/src/common/network/SocketMgr.h
new file mode 100644
index 00000000000..07252355308
--- /dev/null
+++ b/src/common/network/SocketMgr.h
@@ -0,0 +1,146 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_SOCKET_MGR_H
+#define TRINITYCORE_SOCKET_MGR_H
+
+#include "AsyncAcceptor.h"
+#include "Errors.h"
+#include "NetworkThread.h"
+#include "Socket.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <memory>
+
+namespace Trinity::Net
+{
+template<class SocketType>
+class SocketMgr
+{
+public:
+ SocketMgr(SocketMgr const&) = delete;
+ SocketMgr(SocketMgr&&) = delete;
+ SocketMgr& operator=(SocketMgr const&) = delete;
+ SocketMgr& operator=(SocketMgr&&) = delete;
+
+ virtual ~SocketMgr()
+ {
+ ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
+ }
+
+ virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
+ {
+ ASSERT(threadCount > 0);
+
+ std::unique_ptr<AsyncAcceptor> acceptor = nullptr;
+ try
+ {
+ acceptor = std::make_unique<AsyncAcceptor>(ioContext, bindIp, port);
+ }
+ catch (boost::system::system_error const& err)
+ {
+ TC_LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork ({}:{}): {}", bindIp, port, err.what());
+ return false;
+ }
+
+ if (!acceptor->Bind())
+ {
+ TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
+ return false;
+ }
+
+ _acceptor = std::move(acceptor);
+ _threadCount = threadCount;
+ _threads.reset(CreateThreads());
+
+ ASSERT(_threads);
+
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Start();
+
+ _acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
+
+ return true;
+ }
+
+ virtual void StopNetwork()
+ {
+ _acceptor->Close();
+
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Stop();
+
+ Wait();
+
+ _acceptor = nullptr;
+ _threads = nullptr;
+ _threadCount = 0;
+ }
+
+ void Wait()
+ {
+ for (int32 i = 0; i < _threadCount; ++i)
+ _threads[i].Wait();
+ }
+
+ virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
+ {
+ try
+ {
+ std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
+ newSocket->Start();
+
+ _threads[threadIndex].AddSocket(newSocket);
+ }
+ catch (boost::system::system_error const& err)
+ {
+ TC_LOG_WARN("network", "Failed to retrieve client's remote address {}", err.what());
+ }
+ }
+
+ int32 GetNetworkThreadCount() const { return _threadCount; }
+
+ uint32 SelectThreadWithMinConnections() const
+ {
+ uint32 min = 0;
+
+ for (int32 i = 1; i < _threadCount; ++i)
+ if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
+ min = i;
+
+ return min;
+ }
+
+ std::pair<IoContextTcpSocket*, uint32> GetSocketForAccept()
+ {
+ uint32 threadIndex = SelectThreadWithMinConnections();
+ return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
+ }
+
+protected:
+ SocketMgr() : _threadCount(0)
+ {
+ }
+
+ virtual NetworkThread<SocketType>* CreateThreads() const = 0;
+
+ std::unique_ptr<AsyncAcceptor> _acceptor;
+ std::unique_ptr<NetworkThread<SocketType>[]> _threads;
+ int32 _threadCount;
+};
+}
+
+#endif // TRINITYCORE_SOCKET_MGR_H
diff --git a/src/common/network/SslStream.h b/src/common/network/SslStream.h
new file mode 100644
index 00000000000..2cced44e5ff
--- /dev/null
+++ b/src/common/network/SslStream.h
@@ -0,0 +1,131 @@
+/*
+ * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License as published by the
+ * Free Software Foundation; either version 2 of the License, or (at your
+ * option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef TRINITYCORE_SSL_STREAM_H
+#define TRINITYCORE_SSL_STREAM_H
+
+#include "SocketConnectionInitializer.h"
+#include <boost/asio/ip/tcp.hpp>
+#include <boost/asio/ssl/stream.hpp>
+#include <boost/system/error_code.hpp>
+
+namespace Trinity::Net
+{
+template <typename SocketImpl>
+struct SslHandshakeConnectionInitializer final : SocketConnectionInitializer
+{
+ explicit SslHandshakeConnectionInitializer(SocketImpl* socket) : _socket(socket) { }
+
+ void Start() override
+ {
+ _socket->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
+ [socketRef = _socket->weak_from_this(), self = this->shared_from_this()](boost::system::error_code const& error)
+ {
+ std::shared_ptr<SocketImpl> socket = static_pointer_cast<SocketImpl>(socketRef.lock());
+ if (!socket)
+ return;
+
+ if (error)
+ {
+ TC_LOG_ERROR("session", "{} SSL Handshake failed {}", socket->GetClientInfo(), error.message());
+ socket->CloseSocket();
+ return;
+ }
+
+ if (self->next)
+ self->next->Start();
+ });
+ }
+
+private:
+ SocketImpl* _socket;
+};
+
+template<class WrappedStream = IoContextTcpSocket>
+class SslStream
+{
+public:
+ explicit SslStream(IoContextTcpSocket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
+ {
+ _sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
+ }
+
+ explicit SslStream(boost::asio::io_context& context, boost::asio::ssl::context& sslContext) : _sslSocket(context, sslContext)
+ {
+ _sslSocket.set_verify_mode(boost::asio::ssl::verify_none);
+ }
+
+ // adapting tcp::socket api
+ void close(boost::system::error_code& error)
+ {
+ _sslSocket.next_layer().close(error);
+ }
+
+ void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
+ {
+ _sslSocket.shutdown(shutdownError);
+ _sslSocket.next_layer().shutdown(what, shutdownError);
+ }
+
+ template<typename MutableBufferSequence, typename ReadHandlerType>
+ void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler)
+ {
+ _sslSocket.async_read_some(buffers, std::forward<ReadHandlerType>(handler));
+ }
+
+ template<typename ConstBufferSequence, typename WriteHandlerType>
+ void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler)
+ {
+ _sslSocket.async_write_some(buffers, std::forward<WriteHandlerType>(handler));
+ }
+
+ template<typename ConstBufferSequence>
+ std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error)
+ {
+ return _sslSocket.write_some(buffers, error);
+ }
+
+ template<typename WaitHandlerType>
+ void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
+ {
+ _sslSocket.next_layer().async_wait(type, std::forward<WaitHandlerType>(handler));
+ }
+
+ template<typename SettableSocketOption>
+ void set_option(SettableSocketOption const& option, boost::system::error_code& error)
+ {
+ _sslSocket.next_layer().set_option(option, error);
+ }
+
+ IoContextTcpSocket::endpoint_type remote_endpoint() const
+ {
+ return _sslSocket.next_layer().remote_endpoint();
+ }
+
+ // ssl api
+ template<typename HandshakeHandlerType>
+ void async_handshake(boost::asio::ssl::stream_base::handshake_type type, HandshakeHandlerType&& handler)
+ {
+ _sslSocket.async_handshake(type, std::forward<HandshakeHandlerType>(handler));
+ }
+
+protected:
+ boost::asio::ssl::stream<WrappedStream> _sslSocket;
+};
+}
+
+#endif // TRINITYCORE_SSL_STREAM_H