diff options
| author | Shauren <shauren.trinity@gmail.com> | 2023-12-17 23:21:10 +0100 |
|---|---|---|
| committer | funjoker <funjoker109@gmail.com> | 2023-12-19 14:15:00 +0100 |
| commit | 123f2ad97e9cc53504b9e0faa1c180f4d979126f (patch) | |
| tree | 801b4a4d6604ba70d60e4617567e5b057141acd3 /src | |
| parent | 5d6896de598283e00b70885868423eb72e9a53ca (diff) | |
Core/Bnet: Rewrite LoginRESTService using boost::beast instead of gsoap as http backend and extract generic http code to be reusable elsewhere
(cherry picked from commit acb5fbd48b5bd911dd0da6016a3d86d4c64724b6)
Diffstat (limited to 'src')
23 files changed, 1434 insertions, 427 deletions
diff --git a/src/common/Cryptography/Authentication/SRP6.cpp b/src/common/Cryptography/Authentication/SRP6.cpp index a2266f35f71..c172bbc1d3d 100644 --- a/src/common/Cryptography/Authentication/SRP6.cpp +++ b/src/common/Cryptography/Authentication/SRP6.cpp @@ -24,12 +24,7 @@ using SHA1 = Trinity::Crypto::SHA1; using SRP6 = Trinity::Crypto::SRP6; -/*static*/ std::array<uint8, 1> const SRP6::g = []() -{ - std::array<uint8, 1> g_temp; - g_temp[0] = 7; - return g_temp; -}(); +/*static*/ std::array<uint8, 1> const SRP6::g = { 7 }; /*static*/ std::array<uint8, 32> const SRP6::N = HexStrToByteArray<32>("894B645E89E1535BBDAD5B8B290650530801B18EBFBF5E8FAB3C82872A3E9BB7", true); /*static*/ BigNumber const SRP6::_g(SRP6::g); /*static*/ BigNumber const SRP6::_N(N); diff --git a/src/common/Utilities/Concepts.h b/src/common/Utilities/Concepts.h new file mode 100644 index 00000000000..4cfa52c029c --- /dev/null +++ b/src/common/Utilities/Concepts.h @@ -0,0 +1,33 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_CONCEPTS_H +#define TRINITYCORE_CONCEPTS_H + +#include <concepts> +#include <functional> // std::invoke + +namespace Trinity +{ +template <typename Callable, typename R, typename... Args> +concept invocable_r = requires(Callable && callable, Args&&... args) +{ + { std::invoke(static_cast<Callable&&>(callable), static_cast<Args&&>(args)...) } -> std::convertible_to<R>; +}; +} + +#endif // TRINITYCORE_CONCEPTS_H diff --git a/src/server/bnetserver/Main.cpp b/src/server/bnetserver/Main.cpp index e1168a1a90b..6d667cd48b1 100644 --- a/src/server/bnetserver/Main.cpp +++ b/src/server/bnetserver/Main.cpp @@ -200,21 +200,29 @@ int main(int argc, char** argv) Trinity::Net::ScanLocalNetworks(); - // Start the listening port (acceptor) for auth connections - int32 bnport = sConfigMgr->GetIntDefault("BattlenetPort", 1119); - if (bnport < 0 || bnport > 0xFFFF) + std::string httpBindIp = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0"); + int32 httpPort = sConfigMgr->GetIntDefault("LoginREST.Port", 8081); + if (httpPort <= 0 || httpPort > 0xFFFF) { - TC_LOG_ERROR("server.bnetserver", "Specified battle.net port ({}) out of allowed range (1-65535)", bnport); + TC_LOG_ERROR("server.bnetserver", "Specified login service port ({}) out of allowed range (1-65535)", httpPort); return 1; } - if (!sLoginService.Start(ioContext.get())) + if (!sLoginService.StartNetwork(*ioContext, httpBindIp, httpPort)) { TC_LOG_ERROR("server.bnetserver", "Failed to initialize login service"); return 1; } - std::shared_ptr<void> sLoginServiceHandle(nullptr, [](void*) { sLoginService.Stop(); }); + std::shared_ptr<void> sLoginServiceHandle(nullptr, [](void*) { sLoginService.StopNetwork(); }); + + // Start the listening port (acceptor) for auth connections + int32 bnport = sConfigMgr->GetIntDefault("BattlenetPort", 1119); + if (bnport <= 0 || bnport > 0xFFFF) + { + TC_LOG_ERROR("server.bnetserver", "Specified battle.net port ({}) out of allowed range (1-65535)", bnport); + return 1; + } // Get the list of realms for the server sRealmList->Initialize(*ioContext, sConfigMgr->GetIntDefault("RealmsStateUpdateDelay", 10)); diff --git a/src/server/bnetserver/REST/LoginHttpSession.cpp b/src/server/bnetserver/REST/LoginHttpSession.cpp new file mode 100644 index 00000000000..95112cb8836 --- /dev/null +++ b/src/server/bnetserver/REST/LoginHttpSession.cpp @@ -0,0 +1,118 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "LoginHttpSession.h" +#include "DatabaseEnv.h" +#include "LoginRESTService.h" +#include "SslContext.h" +#include "Util.h" + +namespace Battlenet +{ +LoginHttpSession::LoginHttpSession(boost::asio::ip::tcp::socket&& socket) + : SslSocket(std::move(socket), SslContext::instance()) +{ +} + +LoginHttpSession::~LoginHttpSession() = default; + +void LoginHttpSession::Start() +{ + std::string ip_address = GetRemoteIpAddress().to_string(); + TC_LOG_TRACE("server.http.session", "{} Accepted connection", GetClientInfo()); + + // Verify that this IP is not in the ip_banned table + LoginDatabase.Execute(LoginDatabase.GetPreparedStatement(LOGIN_DEL_EXPIRED_IP_BANS)); + + LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO); + stmt->setString(0, ip_address); + + _queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt) + .WithPreparedCallback([sess = shared_from_this()](PreparedQueryResult result) { sess->CheckIpCallback(std::move(result)); })); +} + +void LoginHttpSession::CheckIpCallback(PreparedQueryResult result) +{ + if (result) + { + bool banned = false; + do + { + Field* fields = result->Fetch(); + if (fields[0].GetUInt64() != 0) + banned = true; + + } while (result->NextRow()); + + if (banned) + { + TC_LOG_DEBUG("server.http.session", "{} tries to log in using banned IP!", GetClientInfo()); + CloseSocket(); + return; + } + } + + AsyncHandshake(); +} + +Trinity::Net::Http::RequestHandlerResult LoginHttpSession::RequestHandler(Trinity::Net::Http::RequestContext& context) +{ + return sLoginService.HandleRequest(shared_from_this(), context); +} + +std::shared_ptr<Trinity::Net::Http::SessionState> LoginHttpSession::ObtainSessionState(Trinity::Net::Http::RequestContext& context) const +{ + using namespace std::string_literals; + + std::shared_ptr<Trinity::Net::Http::SessionState> state; + + auto cookieItr = context.request.find(boost::beast::http::field::cookie); + if (cookieItr != context.request.end()) + { + std::vector<std::string_view> cookies = Trinity::Tokenize(Trinity::Net::Http::ToStdStringView(cookieItr->value()), ';', false); + std::size_t eq = 0; + auto sessionIdItr = std::find_if(cookies.begin(), cookies.end(), [&](std::string_view cookie) + { + std::string_view name = cookie; + eq = cookie.find('='); + if (eq != std::string_view::npos) + name = cookie.substr(0, eq); + + return name == SESSION_ID_COOKIE; + }); + if (sessionIdItr != cookies.end()) + { + std::string_view value = sessionIdItr->substr(eq + 1); + state = sLoginService.FindAndRefreshSessionState(value, GetRemoteIpAddress()); + } + } + + if (!state) + { + state = sLoginService.CreateNewSessionState(GetRemoteIpAddress()); + + std::string_view host = Trinity::Net::Http::ToStdStringView(context.request[boost::beast::http::field::host]); + if (std::size_t port = host.find(':'); port != std::string_view::npos) + host.remove_suffix(host.length() - port); + + context.response.insert(boost::beast::http::field::set_cookie, Trinity::StringFormat("{}={}; Path=/bnetserver; Domain={}; Secure; HttpOnly; SameSite=None", + SESSION_ID_COOKIE, boost::uuids::to_string(state->Id), host)); + } + + return state; +} +} diff --git a/src/server/bnetserver/REST/LoginHttpSession.h b/src/server/bnetserver/REST/LoginHttpSession.h new file mode 100644 index 00000000000..17c94b55bda --- /dev/null +++ b/src/server/bnetserver/REST/LoginHttpSession.h @@ -0,0 +1,43 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_LOGIN_HTTP_SESSION_H +#define TRINITYCORE_LOGIN_HTTP_SESSION_H + +#include "HttpSslSocket.h" + +namespace Battlenet +{ +class LoginHttpSession : public Trinity::Net::Http::SslSocket<LoginHttpSession> +{ +public: + static constexpr std::string_view SESSION_ID_COOKIE = "JSESSIONID"; + + explicit LoginHttpSession(boost::asio::ip::tcp::socket&& socket); + ~LoginHttpSession(); + + void Start() override; + + void CheckIpCallback(PreparedQueryResult result); + + Trinity::Net::Http::RequestHandlerResult RequestHandler(Trinity::Net::Http::RequestContext& context) override; + +protected: + std::shared_ptr<Trinity::Net::Http::SessionState> ObtainSessionState(Trinity::Net::Http::RequestContext& context) const override; +}; +} +#endif // TRINITYCORE_LOGIN_HTTP_SESSION_H diff --git a/src/server/bnetserver/REST/LoginRESTService.cpp b/src/server/bnetserver/REST/LoginRESTService.cpp index 6141dc702f9..602a38dceed 100644 --- a/src/server/bnetserver/REST/LoginRESTService.cpp +++ b/src/server/bnetserver/REST/LoginRESTService.cpp @@ -16,71 +16,63 @@ */ #include "LoginRESTService.h" +#include "Base64.h" #include "Configuration/Config.h" #include "CryptoHash.h" #include "CryptoRandom.h" #include "DatabaseEnv.h" -#include "Errors.h" #include "IpNetwork.h" +#include "IteratorPair.h" #include "ProtobufJSON.h" #include "Resolver.h" -#include "SslContext.h" #include "Util.h" -#include "httpget.h" -#include "httppost.h" -#include "soapH.h" +#include <boost/uuid/string_generator.hpp> +#include <fmt/chrono.h> -int ns1__executeCommand(soap*, char*, char**) { return SOAP_OK; } +namespace Battlenet +{ +LoginRESTService& LoginRESTService::Instance() +{ + static LoginRESTService instance; + return instance; +} -class AsyncRequest +bool LoginRESTService::StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount) { -public: - AsyncRequest(soap const& server) : _client(server), _responseStatus(0) { } + if (!HttpService::StartNetwork(ioContext, bindIp, port, threadCount)) + return false; - AsyncRequest(AsyncRequest const&) = delete; - AsyncRequest& operator=(AsyncRequest const&) = delete; - AsyncRequest(AsyncRequest&&) = default; - AsyncRequest& operator=(AsyncRequest&&) = default; + using Trinity::Net::Http::RequestHandlerFlag; - bool InvokeIfReady() + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { - ASSERT(_callback); - return _callback->InvokeIfReady(); - } - - soap* GetClient() { return &_client; } - void SetCallback(std::unique_ptr<QueryCallback> callback) { _callback = std::move(callback); } - int32 GetResponseStatus() const { return _responseStatus; } - void SetResponseStatus(int32 responseStatus) { _responseStatus = responseStatus; } + return HandleGetForm(std::move(session), context); + }); -private: - soap _client; - std::unique_ptr<QueryCallback> _callback; - int32 _responseStatus; -}; + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/gameAccounts/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) + { + return HandleGetGameAccounts(std::move(session), context); + }); -int32 handle_get_plugin(soap* soapClient) -{ - return sLoginService.HandleHttpRequest(soapClient, "GET", sLoginService._getHandlers); -} + RegisterHandler(boost::beast::http::verb::get, "/bnetserver/portal/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) + { + return HandleGetPortal(std::move(session), context); + }); -int32 handle_post_plugin(soap* soapClient) -{ - return sLoginService.HandleHttpRequest(soapClient, "POST", sLoginService._postHandlers); -} + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/login/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) + { + return HandlePostLogin(std::move(session), context); + }, RequestHandlerFlag::DoNotLogRequestContent); -bool LoginRESTService::Start(Trinity::Asio::IoContext* ioContext) -{ - _ioContext = ioContext; - _bindIP = sConfigMgr->GetStringDefault("BindIP", "0.0.0.0"); - _port = sConfigMgr->GetIntDefault("LoginREST.Port", 8081); - if (_port < 0 || _port > 0xFFFF) + RegisterHandler(boost::beast::http::verb::post, "/bnetserver/refreshLoginTicket/", [this](std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { - TC_LOG_ERROR("server.rest", "Specified login service port ({}) out of allowed range (1-65535), defaulting to 8081", _port); - _port = 8081; - } + return HandlePostRefreshLoginTicket(std::move(session), context); + }); - Trinity::Asio::Resolver resolver(*ioContext); + _bindIP = bindIp; + _port = port; + + Trinity::Asio::Resolver resolver(ioContext); _hostnames[0] = sConfigMgr->GetStringDefault("LoginREST.ExternalAddress", "127.0.0.1"); Optional<boost::asio::ip::tcp::endpoint> externalAddress = resolver.Resolve(boost::asio::ip::tcp::v4(), _hostnames[0], std::to_string(_port)); @@ -103,8 +95,8 @@ bool LoginRESTService::Start(Trinity::Asio::IoContext* ioContext) _addresses[1] = localAddress->address(); // set up form inputs - Battlenet::JSON::Login::FormInput* input; - _formInputs.set_type(Battlenet::JSON::Login::LOGIN_FORM); + JSON::Login::FormInput* input; + _formInputs.set_type(JSON::Login::LOGIN_FORM); input = _formInputs.add_inputs(); input->set_input_id("account_name"); input->set_type("text"); @@ -124,16 +116,10 @@ bool LoginRESTService::Start(Trinity::Asio::IoContext* ioContext) _loginTicketDuration = sConfigMgr->GetIntDefault("LoginREST.TicketDuration", 3600); - _thread = std::thread(&LoginRESTService::Run, this); + _acceptor->AsyncAcceptWithCallback<&LoginRESTService::OnSocketAccept>(); return true; } -void LoginRESTService::Stop() -{ - _stopped = true; - _thread.join(); -} - std::string const& LoginRESTService::GetHostnameForClient(boost::asio::ip::address const& address) const { if (auto addressIndex = Trinity::Net::SelectAddressForClient(address, _addresses)) @@ -145,117 +131,53 @@ std::string const& LoginRESTService::GetHostnameForClient(boost::asio::ip::addre return _hostnames[0]; } -void LoginRESTService::Run() +std::string LoginRESTService::ExtractAuthorization(HttpRequest const& request) { - soap soapServer(SOAP_C_UTFSTRING, SOAP_C_UTFSTRING); - - // check every 3 seconds if world ended - soapServer.accept_timeout = 3; - soapServer.recv_timeout = 5; - soapServer.send_timeout = 5; - if (!soap_valid_socket(soap_bind(&soapServer, _bindIP.c_str(), _port, 100))) - { - TC_LOG_ERROR("server.rest", "Couldn't bind to {}:{}", _bindIP, _port); - return; - } - - TC_LOG_INFO("server.rest", "Login service bound to http://{}:{}", _bindIP, _port); - - http_post_handlers handlers[] = - { - { "application/json;charset=utf-8", handle_post_plugin }, - { "application/json", handle_post_plugin }, - { nullptr, nullptr } - }; + using namespace std::string_view_literals; - _getHandlers["/bnetserver/login/"] = &LoginRESTService::HandleGetForm; - _getHandlers["/bnetserver/gameAccounts/"] = &LoginRESTService::HandleGetGameAccounts; - _getHandlers["/bnetserver/portal/"] = &LoginRESTService::HandleGetPortal; + std::string ticket; + auto itr = request.find(boost::beast::http::field::authorization); + if (itr == request.end()) + return ticket; - _postHandlers["/bnetserver/login/"] = &LoginRESTService::HandlePostLogin; - _postHandlers["/bnetserver/refreshLoginTicket/"] = &LoginRESTService::HandlePostRefreshLoginTicket; + std::string_view authorization = Trinity::Net::Http::ToStdStringView(itr->value()); + constexpr std::string_view BASIC_PREFIX = "Basic "sv; - soap_register_plugin_arg(&soapServer, &http_get, (void*)&handle_get_plugin); - soap_register_plugin_arg(&soapServer, &http_post, handlers); - soap_register_plugin_arg(&soapServer, &ContentTypePlugin::Init, (void*)"application/json;charset=utf-8"); - soap_register_plugin_arg(&soapServer, &ResponseCodePlugin::Init, nullptr); + if (authorization.starts_with(BASIC_PREFIX)) + authorization.remove_prefix(BASIC_PREFIX.length()); - // Use our already ready ssl context - soapServer.ctx = Battlenet::SslContext::instance().native_handle(); - soapServer.ssl_flags = SOAP_SSL_RSA; + Optional<std::vector<uint8>> decoded = Trinity::Encoding::Base64::Decode(authorization); + if (!decoded) + return ticket; - while (!_stopped) - { - if (!soap_valid_socket(soap_accept(&soapServer))) - continue; // ran into an accept timeout - - std::shared_ptr<AsyncRequest> soapClient = std::make_shared<AsyncRequest>(soapServer); - if (soap_ssl_accept(soapClient->GetClient()) != SOAP_OK) - { - TC_LOG_DEBUG("server.rest", "Failed SSL handshake from IP={}", boost::asio::ip::address_v4(soapClient->GetClient()->ip).to_string()); - continue; - } - - TC_LOG_DEBUG("server.rest", "Accepted connection from IP={}", boost::asio::ip::address_v4(soapClient->GetClient()->ip).to_string()); - - Trinity::Asio::post(*_ioContext, [soapClient]() - { - soapClient->GetClient()->user = (void*)&soapClient; // this allows us to make a copy of pointer inside GET/POST handlers to increment reference count - soap_begin(soapClient->GetClient()); - soap_begin_recv(soapClient->GetClient()); - }); - } - - // and release the context handle here - soap does not own it so it should not free it on exit - soapServer.ctx = nullptr; - - TC_LOG_INFO("server.rest", "Login service exiting..."); -} - -int32 LoginRESTService::HandleHttpRequest(soap* soapClient, char const* method, HttpMethodHandlerMap const& handlers) -{ - TC_LOG_DEBUG("server.rest", "[{}:{}] Handling {} request path=\"{}\"", - boost::asio::ip::address_v4(soapClient->ip).to_string(), soapClient->port, method, soapClient->path); - - size_t pathLength = strlen(soapClient->path); - if (char const* queryPart = strchr(soapClient->path, '?')) - pathLength = queryPart - soapClient->path; + std::string_view decodedHeader(reinterpret_cast<char const*>(decoded->data()), decoded->size()); - auto handler = handlers.find(std::string{ soapClient->path, pathLength }); - if (handler != handlers.end()) - { - int32 status = (this->*handler->second)(*reinterpret_cast<std::shared_ptr<AsyncRequest>*>(soapClient->user)); - if (status != SOAP_OK) - { - ResponseCodePlugin::GetForClient(soapClient)->ErrorCode = status; - return SendResponse(soapClient, Battlenet::JSON::Login::ErrorResponse()); - } + if (std::size_t ticketEnd = decodedHeader.find(':'); ticketEnd != std::string_view::npos) + decodedHeader.remove_suffix(decodedHeader.length() - ticketEnd); - return SOAP_OK; - } - - ResponseCodePlugin::GetForClient(soapClient)->ErrorCode = 404; - return SendResponse(soapClient, Battlenet::JSON::Login::ErrorResponse()); + ticket = decodedHeader; + return ticket; } -int32 LoginRESTService::HandleGetForm(std::shared_ptr<AsyncRequest> request) +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetForm(std::shared_ptr<LoginHttpSession> /*session*/, HttpRequestContext& context) { - return SendResponse(request->GetClient(), _formInputs); + context.response.set(boost::beast::http::field::content_type, "application/json;charset=utf-8"); + context.response.body() = ::JSON::Serialize(_formInputs); + return RequestHandlerResult::Handled; } -int32 LoginRESTService::HandleGetGameAccounts(std::shared_ptr<AsyncRequest> request) +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { - if (!request->GetClient()->userid) - return 401; - - request->SetCallback(std::make_unique<QueryCallback>(LoginDatabase.AsyncQuery([&] { - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_GAME_ACCOUNT_LIST); - stmt->setString(0, request->GetClient()->userid); - return stmt; - }()) - .WithPreparedCallback([this, request](PreparedQueryResult result) + std::string ticket = ExtractAuthorization(context.request); + if (ticket.empty()) + return HandleUnauthorized(std::move(session), context); + + LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_GAME_ACCOUNT_LIST); + stmt->setString(0, ticket); + session->QueueQuery(LoginDatabase.AsyncQuery(stmt) + .WithPreparedCallback([session, context = std::move(context)](PreparedQueryResult result) mutable { - Battlenet::JSON::Login::GameAccountList response; + JSON::Login::GameAccountList gameAccounts; if (result) { auto formatDisplayName = [](char const* name) -> std::string @@ -270,7 +192,7 @@ int32 LoginRESTService::HandleGetGameAccounts(std::shared_ptr<AsyncRequest> requ do { Field* fields = result->Fetch(); - Battlenet::JSON::Login::GameAccountInfo* gameAccount = response.add_game_accounts(); + JSON::Login::GameAccountInfo* gameAccount = gameAccounts.add_game_accounts(); gameAccount->set_display_name(formatDisplayName(fields[0].GetCString())); gameAccount->set_expansion(fields[1].GetUInt8()); if (!fields[2].IsNull()) @@ -285,40 +207,37 @@ int32 LoginRESTService::HandleGetGameAccounts(std::shared_ptr<AsyncRequest> requ } while (result->NextRow()); } - SendResponse(request->GetClient(), response); - }))); - - Trinity::Asio::post(*_ioContext, [this, request]() { HandleAsyncRequest(request); }); + context.response.set(boost::beast::http::field::content_type, "application/json;charset=utf-8"); + context.response.body() = ::JSON::Serialize(gameAccounts); + session->SendResponse(context); + })); - return SOAP_OK; + return RequestHandlerResult::Async; } -int32 LoginRESTService::HandleGetPortal(std::shared_ptr<AsyncRequest> request) +LoginRESTService::RequestHandlerResult LoginRESTService::HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { - std::string const& hostname = GetHostnameForClient(boost::asio::ip::address_v4(request->GetClient()->ip)); - std::string response = Trinity::StringFormat("{}:{}", hostname, sConfigMgr->GetIntDefault("BattlenetPort", 1119)); - - soap_response(request->GetClient(), SOAP_FILE); - soap_send_raw(request->GetClient(), response.c_str(), response.length()); - return soap_end_send(request->GetClient()); + context.response.set(boost::beast::http::field::content_type, "text/plain"); + context.response.body() = Trinity::StringFormat("{}:{}", GetHostnameForClient(session->GetRemoteIpAddress()), sConfigMgr->GetIntDefault("BattlenetPort", 1119)); + return RequestHandlerResult::Handled; } -int32 LoginRESTService::HandlePostLogin(std::shared_ptr<AsyncRequest> request) +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { - char* buf = nullptr; - size_t len = 0; - soap_http_body(request->GetClient(), &buf, &len); - - Battlenet::JSON::Login::LoginForm loginForm; - if (!buf || !JSON::Deserialize(buf, &loginForm)) + JSON::Login::LoginForm loginForm; + if (!::JSON::Deserialize(context.request.body(), &loginForm)) { - ResponseCodePlugin::GetForClient(request->GetClient())->ErrorCode = 400; - - Battlenet::JSON::Login::LoginResult loginResult; - loginResult.set_authentication_state(Battlenet::JSON::Login::LOGIN); + JSON::Login::LoginResult loginResult; + loginResult.set_authentication_state(JSON::Login::LOGIN); loginResult.set_error_code("UNABLE_TO_DECODE"); loginResult.set_error_message("There was an internal error while connecting to Battle.net. Please try again later."); - return SendResponse(request->GetClient(), loginResult); + + context.response.result(boost::beast::http::status::bad_request); + context.response.set(boost::beast::http::field::content_type, "application/json;charset=utf-8"); + context.response.body() = ::JSON::Serialize(loginResult); + session->SendResponse(context); + + return RequestHandlerResult::Handled; } std::string login; @@ -340,44 +259,32 @@ int32 LoginRESTService::HandlePostLogin(std::shared_ptr<AsyncRequest> request) std::string sentPasswordHash = CalculateShaPassHash(login, password); - request->SetCallback(std::make_unique<QueryCallback>(LoginDatabase.AsyncQuery(stmt) - .WithChainingPreparedCallback([request, login, sentPasswordHash, this](QueryCallback& callback, PreparedQueryResult result) + session->QueueQuery(LoginDatabase.AsyncQuery(stmt) + .WithChainingPreparedCallback([this, session, context = std::move(context), login = std::move(login), sentPasswordHash = std::move(sentPasswordHash)](QueryCallback& callback, PreparedQueryResult result) mutable { - if (result) + if (!result) { - Field* fields = result->Fetch(); - uint32 accountId = fields[0].GetUInt32(); - std::string pass_hash = fields[1].GetString(); - uint32 failedLogins = fields[2].GetUInt32(); - std::string loginTicket = fields[3].GetString(); - uint32 loginTicketExpiry = fields[4].GetUInt32(); - bool isBanned = fields[5].GetUInt64() != 0; - - if (sentPasswordHash == pass_hash) - { - if (loginTicket.empty() || loginTicketExpiry < time(nullptr)) - { - std::array<uint8, 20> ticket = Trinity::Crypto::GetRandomBytes<20>(); + JSON::Login::LoginResult loginResult; + loginResult.set_authentication_state(JSON::Login::DONE); + context.response.set(boost::beast::http::field::content_type, "application/json;charset=utf-8"); + context.response.body() = ::JSON::Serialize(loginResult); + session->SendResponse(context); + return; + } - loginTicket = "TC-" + ByteArrayToHexStr(ticket); - } + Field* fields = result->Fetch(); + uint32 accountId = fields[0].GetUInt32(); + std::string pass_hash = fields[1].GetString(); + uint32 failedLogins = fields[2].GetUInt32(); + std::string loginTicket = fields[3].GetString(); + uint32 loginTicketExpiry = fields[4].GetUInt32(); + bool isBanned = fields[5].GetUInt64() != 0; - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_BNET_AUTHENTICATION); - stmt->setString(0, loginTicket); - stmt->setUInt32(1, time(nullptr) + _loginTicketDuration); - stmt->setUInt32(2, accountId); - callback.WithPreparedCallback([request, loginTicket](PreparedQueryResult) - { - Battlenet::JSON::Login::LoginResult loginResult; - loginResult.set_authentication_state(Battlenet::JSON::Login::DONE); - loginResult.set_login_ticket(loginTicket); - sLoginService.SendResponse(request->GetClient(), loginResult); - }).SetNextQuery(LoginDatabase.AsyncQuery(stmt)); - return; - } - else if (!isBanned) + if (sentPasswordHash != pass_hash) + { + if (!isBanned) { - std::string ip_address = boost::asio::ip::address_v4(request->GetClient()->ip).to_string(); + std::string ip_address = session->GetRemoteIpAddress().to_string(); uint32 maxWrongPassword = uint32(sConfigMgr->GetIntDefault("WrongPass.MaxCount", 0)); if (sConfigMgr->GetBoolDefault("WrongPass.Logging", false)) @@ -421,31 +328,50 @@ int32 LoginRESTService::HandlePostLogin(std::shared_ptr<AsyncRequest> request) LoginDatabase.CommitTransaction(trans); } } + + JSON::Login::LoginResult loginResult; + loginResult.set_authentication_state(JSON::Login::DONE); + + context.response.set(boost::beast::http::field::content_type, "application/json;charset=utf-8"); + context.response.body() = ::JSON::Serialize(loginResult); + session->SendResponse(context); + return; } - Battlenet::JSON::Login::LoginResult loginResult; - loginResult.set_authentication_state(Battlenet::JSON::Login::DONE); - sLoginService.SendResponse(request->GetClient(), loginResult); - }))); + if (loginTicket.empty() || loginTicketExpiry < time(nullptr)) + loginTicket = "TC-" + ByteArrayToHexStr(Trinity::Crypto::GetRandomBytes<20>()); - Trinity::Asio::post(*_ioContext, [this, request]() { HandleAsyncRequest(request); }); + LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_BNET_AUTHENTICATION); + stmt->setString(0, loginTicket); + stmt->setUInt32(1, time(nullptr) + _loginTicketDuration); + stmt->setUInt32(2, accountId); + callback.WithPreparedCallback([session, context = std::move(context), loginTicket = std::move(loginTicket)](PreparedQueryResult) mutable + { + JSON::Login::LoginResult loginResult; + loginResult.set_authentication_state(JSON::Login::DONE); + loginResult.set_login_ticket(loginTicket); + + context.response.set(boost::beast::http::field::content_type, "application/json;charset=utf-8"); + context.response.body() = ::JSON::Serialize(loginResult); + session->SendResponse(context); + }).SetNextQuery(LoginDatabase.AsyncQuery(stmt)); + })); - return SOAP_OK; + return RequestHandlerResult::Async; } -int32 LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<AsyncRequest> request) +LoginRESTService::RequestHandlerResult LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context) { - if (!request->GetClient()->userid) - return 401; - - request->SetCallback(std::make_unique<QueryCallback>(LoginDatabase.AsyncQuery([&] { - LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION); - stmt->setString(0, request->GetClient()->userid); - return stmt; - }()) - .WithPreparedCallback([this, request](PreparedQueryResult result) + std::string ticket = ExtractAuthorization(context.request); + if (ticket.empty()) + return HandleUnauthorized(std::move(session), context); + + LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_BNET_EXISTING_AUTHENTICATION); + stmt->setString(0, ticket); + session->QueueQuery(LoginDatabase.AsyncQuery(stmt) + .WithPreparedCallback([this, session, context = std::move(context), ticket = std::move(ticket)](PreparedQueryResult result) mutable { - Battlenet::JSON::Login::LoginRefreshResult loginRefreshResult; + JSON::Login::LoginRefreshResult loginRefreshResult; if (result) { uint32 loginTicketExpiry = (*result)[0].GetUInt32(); @@ -456,7 +382,7 @@ int32 LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<AsyncReques LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_BNET_EXISTING_AUTHENTICATION); stmt->setUInt32(0, uint32(now + _loginTicketDuration)); - stmt->setString(1, request->GetClient()->userid); + stmt->setString(1, ticket); LoginDatabase.Execute(stmt); } else @@ -465,34 +391,12 @@ int32 LoginRESTService::HandlePostRefreshLoginTicket(std::shared_ptr<AsyncReques else loginRefreshResult.set_is_expired(true); - SendResponse(request->GetClient(), loginRefreshResult); - }))); - - Trinity::Asio::post(*_ioContext, [this, request]() { HandleAsyncRequest(request); }); - - return SOAP_OK; -} - -int32 LoginRESTService::SendResponse(soap* soapClient, google::protobuf::Message const& response) -{ - std::string jsonResponse = JSON::Serialize(response); - - soap_response(soapClient, SOAP_FILE); - soap_send_raw(soapClient, jsonResponse.c_str(), jsonResponse.length()); - return soap_end_send(soapClient); -} + context.response.set(boost::beast::http::field::content_type, "application/json;charset=utf-8"); + context.response.body() = ::JSON::Serialize(loginRefreshResult); + session->SendResponse(context); + })); -void LoginRESTService::HandleAsyncRequest(std::shared_ptr<AsyncRequest> request) -{ - if (!request->InvokeIfReady()) - { - Trinity::Asio::post(*_ioContext, [this, request]() { HandleAsyncRequest(request); }); - } - else if (request->GetResponseStatus()) - { - ResponseCodePlugin::GetForClient(request->GetClient())->ErrorCode = request->GetResponseStatus(); - SendResponse(request->GetClient(), Battlenet::JSON::Login::ErrorResponse()); - } + return RequestHandlerResult::Async; } std::string LoginRESTService::CalculateShaPassHash(std::string const& name, std::string const& password) @@ -510,85 +414,8 @@ std::string LoginRESTService::CalculateShaPassHash(std::string const& name, std: return ByteArrayToHexStr(sha.GetDigest(), true); } -Namespace namespaces[] = -{ - { nullptr, nullptr, nullptr, nullptr } -}; - -LoginRESTService& LoginRESTService::Instance() -{ - static LoginRESTService instance; - return instance; -} - -char const* const LoginRESTService::ResponseCodePlugin::PluginId = "bnet-error-code"; - -int32 LoginRESTService::ResponseCodePlugin::Init(soap* s, soap_plugin* p, void* /*arg*/) -{ - ResponseCodePlugin* data = new ResponseCodePlugin(); - data->fresponse = s->fresponse; - - p->id = PluginId; - p->fcopy = &Copy; - p->fdelete = &Destroy; - p->data = data; - - s->fresponse = &ChangeResponse; - return SOAP_OK; -} - -int32 LoginRESTService::ResponseCodePlugin::Copy(soap* /*s*/, soap_plugin* dst, soap_plugin* src) -{ - dst->data = new ResponseCodePlugin(*reinterpret_cast<ResponseCodePlugin*>(src->data)); - return SOAP_OK; -} - -void LoginRESTService::ResponseCodePlugin::Destroy(soap* s, soap_plugin* p) -{ - ResponseCodePlugin* data = reinterpret_cast<ResponseCodePlugin*>(p->data); - s->fresponse = data->fresponse; - delete data; -} - -int32 LoginRESTService::ResponseCodePlugin::ChangeResponse(soap* s, int32 originalResponse, uint64 contentLength) -{ - ResponseCodePlugin* self = reinterpret_cast<ResponseCodePlugin*>(soap_lookup_plugin(s, PluginId)); - return self->fresponse(s, self->ErrorCode && originalResponse == SOAP_FILE ? self->ErrorCode : originalResponse, contentLength); -} - -LoginRESTService::ResponseCodePlugin* LoginRESTService::ResponseCodePlugin::GetForClient(soap* s) -{ - return ASSERT_NOTNULL(reinterpret_cast<ResponseCodePlugin*>(soap_lookup_plugin(s, PluginId))); -} - -char const* const LoginRESTService::ContentTypePlugin::PluginId = "bnet-content-type"; - -int32 LoginRESTService::ContentTypePlugin::Init(soap* s, soap_plugin* p, void* arg) -{ - ContentTypePlugin* data = new ContentTypePlugin(); - data->fposthdr = s->fposthdr; - data->ContentType = reinterpret_cast<char const*>(arg); - - p->id = PluginId; - p->fdelete = &Destroy; - p->data = data; - - s->fposthdr = &OnSetHeader; - return SOAP_OK; -} - -void LoginRESTService::ContentTypePlugin::Destroy(soap* s, soap_plugin* p) +void LoginRESTService::OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex) { - ContentTypePlugin* data = reinterpret_cast<ContentTypePlugin*>(p->data); - s->fposthdr = data->fposthdr; - delete data; + sLoginService.OnSocketOpen(std::move(sock), threadIndex); } - -int32 LoginRESTService::ContentTypePlugin::OnSetHeader(soap* s, char const* key, char const* value) -{ - ContentTypePlugin* self = reinterpret_cast<ContentTypePlugin*>(soap_lookup_plugin(s, PluginId)); - if (key && !strcmp("Content-Type", key)) - value = self->ContentType; - - return self->fposthdr(s, key, value); } diff --git a/src/server/bnetserver/REST/LoginRESTService.h b/src/server/bnetserver/REST/LoginRESTService.h index f783fea243a..1313493e023 100644 --- a/src/server/bnetserver/REST/LoginRESTService.h +++ b/src/server/bnetserver/REST/LoginRESTService.h @@ -18,98 +18,59 @@ #ifndef LoginRESTService_h__ #define LoginRESTService_h__ -#include "Define.h" -#include "IoContext.h" +#include "HttpService.h" #include "Login.pb.h" -#include "Session.h" -#include <boost/asio/ip/tcp.hpp> -#include <atomic> -#include <thread> - -class AsyncRequest; -struct soap; -struct soap_plugin; +#include "LoginHttpSession.h" +namespace Battlenet +{ enum class BanMode { BAN_IP = 0, BAN_ACCOUNT = 1 }; -class LoginRESTService +class LoginRESTService : public Trinity::Net::Http::HttpService<LoginHttpSession> { public: - LoginRESTService() : _ioContext(nullptr), _stopped(false), _port(0), _loginTicketDuration(0) { } + using RequestHandlerResult = Trinity::Net::Http::RequestHandlerResult; + using HttpRequest = Trinity::Net::Http::Request; + using HttpResponse = Trinity::Net::Http::Response; + using HttpRequestContext = Trinity::Net::Http::RequestContext; + using HttpSessionState = Trinity::Net::Http::SessionState; + + LoginRESTService() : HttpService("login"), _port(0), _loginTicketDuration(0) { } static LoginRESTService& Instance(); - bool Start(Trinity::Asio::IoContext* ioContext); - void Stop(); + bool StartNetwork(Trinity::Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override; std::string const& GetHostnameForClient(boost::asio::ip::address const& address) const; - int32 GetPort() const { return _port; } + uint16 GetPort() const { return _port; } private: - void Run(); - - friend int32 handle_get_plugin(soap* soapClient); - friend int32 handle_post_plugin(soap* soapClient); - - using HttpMethodHandlerMap = std::unordered_map<std::string, int32(LoginRESTService::*)(std::shared_ptr<AsyncRequest>)>; - int32 HandleHttpRequest(soap* soapClient, char const* method, HttpMethodHandlerMap const& handlers); - - int32 HandleGetForm(std::shared_ptr<AsyncRequest> request); - int32 HandleGetGameAccounts(std::shared_ptr<AsyncRequest> request); - int32 HandleGetPortal(std::shared_ptr<AsyncRequest> request); + static void OnSocketAccept(boost::asio::ip::tcp::socket&& sock, uint32 threadIndex); - int32 HandlePostLogin(std::shared_ptr<AsyncRequest> request); - int32 HandlePostRefreshLoginTicket(std::shared_ptr<AsyncRequest> request); + static std::string ExtractAuthorization(HttpRequest const& request); - int32 SendResponse(soap* soapClient, google::protobuf::Message const& response); + RequestHandlerResult HandleGetForm(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); + RequestHandlerResult HandleGetGameAccounts(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); + RequestHandlerResult HandleGetPortal(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); - void HandleAsyncRequest(std::shared_ptr<AsyncRequest> request); + RequestHandlerResult HandlePostLogin(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); + RequestHandlerResult HandlePostRefreshLoginTicket(std::shared_ptr<LoginHttpSession> session, HttpRequestContext& context); - std::string CalculateShaPassHash(std::string const& name, std::string const& password); + static std::string CalculateShaPassHash(std::string const& name, std::string const& password); - struct ResponseCodePlugin - { - static char const* const PluginId; - static int32 Init(soap* s, soap_plugin*, void*); - static int32 Copy(soap* s, soap_plugin* dst, soap_plugin* src); - static void Destroy(soap* s, soap_plugin* p); - static int32 ChangeResponse(soap* s, int32 originalResponse, uint64 contentLength); - - static ResponseCodePlugin* GetForClient(soap* s); - - int32(*fresponse)(soap* s, int32 status, uint64 length); - int32 ErrorCode; - }; - - struct ContentTypePlugin - { - static char const* const PluginId; - static int32 Init(soap* s, soap_plugin* p, void*); - static void Destroy(soap* s, soap_plugin* p); - static int32 OnSetHeader(soap* s, char const* key, char const* value); - - int32(*fposthdr)(soap* s, char const* key, char const* value); - char const* ContentType; - }; - - Trinity::Asio::IoContext* _ioContext; - std::thread _thread; - std::atomic<bool> _stopped; - Battlenet::JSON::Login::FormInputs _formInputs; + JSON::Login::FormInputs _formInputs; std::string _bindIP; - int32 _port; + uint16 _port; std::array<std::string, 2> _hostnames; std::array<boost::asio::ip::address, 2> _addresses; uint32 _loginTicketDuration; - - HttpMethodHandlerMap _getHandlers; - HttpMethodHandlerMap _postHandlers; }; +} -#define sLoginService LoginRESTService::Instance() +#define sLoginService Battlenet::LoginRESTService::Instance() #endif // LoginRESTService_h__ diff --git a/src/server/bnetserver/Server/Session.cpp b/src/server/bnetserver/Server/Session.cpp index ea5256ed635..b450df14e1d 100644 --- a/src/server/bnetserver/Server/Session.cpp +++ b/src/server/bnetserver/Server/Session.cpp @@ -30,6 +30,7 @@ #include "RealmList.h" #include "RealmList.pb.h" #include "ServiceDispatcher.h" +#include "SslContext.h" #include "Timezone.h" #include <rapidjson/document.h> #include <zlib.h> @@ -73,7 +74,8 @@ void Battlenet::Session::GameAccountInfo::LoadResult(Field const* fields) DisplayName = Name; } -Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket)), _accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(), +Battlenet::Session::Session(boost::asio::ip::tcp::socket&& socket) : BattlenetSocket(std::move(socket), SslContext::instance()), + _accountInfo(new AccountInfo()), _gameAccountInfo(nullptr), _locale(), _os(), _build(0), _timezoneOffset(0min), _ipCountry(), _clientSecret(), _authed(false), _requestToken(0) { _headerLengthBuffer.Resize(2); diff --git a/src/server/bnetserver/Server/Session.h b/src/server/bnetserver/Server/Session.h index 6d22c202615..33a058f940c 100644 --- a/src/server/bnetserver/Server/Session.h +++ b/src/server/bnetserver/Server/Session.h @@ -20,11 +20,10 @@ #include "AsyncCallbackProcessor.h" #include "Duration.h" +#include "QueryResult.h" #include "Realm.h" -#include "SslContext.h" -#include "SslSocket.h" #include "Socket.h" -#include "QueryResult.h" +#include "SslSocket.h" #include <boost/asio/ip/tcp.hpp> #include <google/protobuf/message.h> #include <memory> @@ -65,9 +64,9 @@ using namespace bgs::protocol; namespace Battlenet { - class Session : public Socket<Session, SslSocket<SslContext>> + class Session : public Socket<Session, SslSocket<>> { - typedef Socket<Session, SslSocket<SslContext>> BattlenetSocket; + typedef Socket<Session, SslSocket<>> BattlenetSocket; public: struct LastPlayedCharacterInfo diff --git a/src/server/database/Database/DatabaseWorkerPool.cpp b/src/server/database/Database/DatabaseWorkerPool.cpp index 6e9f8a8df1a..9cc4f755d9d 100644 --- a/src/server/database/Database/DatabaseWorkerPool.cpp +++ b/src/server/database/Database/DatabaseWorkerPool.cpp @@ -35,6 +35,7 @@ #include "MySQLWorkaround.h" #include <boost/asio/use_future.hpp> #include <mysqld_error.h> +#include <utility> #ifdef TRINITY_DEBUG #include <sstream> #include <boost/stacktrace.hpp> diff --git a/src/server/shared/CMakeLists.txt b/src/server/shared/CMakeLists.txt index 66f4ff56fd2..040381b946f 100644 --- a/src/server/shared/CMakeLists.txt +++ b/src/server/shared/CMakeLists.txt @@ -45,8 +45,7 @@ target_link_libraries(shared database rapidjson proto - zlib - gsoap) + zlib) set_target_properties(shared PROPERTIES diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.cpp b/src/server/shared/Networking/Http/BaseHttpSocket.cpp new file mode 100644 index 00000000000..ca92c442aa2 --- /dev/null +++ b/src/server/shared/Networking/Http/BaseHttpSocket.cpp @@ -0,0 +1,115 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "BaseHttpSocket.h" +#include <boost/beast/http/serializer.hpp> + +namespace Trinity::Net::Http +{ +using RequestSerializer = boost::beast::http::request_serializer<ResponseBody>; +using ResponseSerializer = boost::beast::http::response_serializer<ResponseBody>; + +bool AbstractSocket::ParseRequest(MessageBuffer& packet, RequestParser& parser) +{ + if (!parser.is_done()) + { + // need more data in the payload + boost::system::error_code ec = {}; + std::size_t readDataSize = parser.put(boost::asio::const_buffer(packet.GetReadPointer(), packet.GetActiveSize()), ec); + packet.ReadCompleted(readDataSize); + } + + return parser.is_done(); +} + +std::string AbstractSocket::SerializeRequest(Request const& request) +{ + RequestSerializer serializer(request); + + std::string buffer; + while (!serializer.is_done()) + { + size_t totalBytes = 0; + boost::system::error_code ec = {}; + serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code const&, ConstBufferSequence const& buffers) + { + size_t totalBytesInBuffers = boost::asio::buffer_size(buffers); + + buffer.reserve(buffer.size() + totalBytes); + + auto begin = boost::asio::buffers_begin(buffers); + auto end = boost::asio::buffers_end(buffers); + + std::copy(begin, end, std::back_inserter(buffer)); + totalBytes += totalBytesInBuffers; + }); + + serializer.consume(totalBytes); + } + + return buffer; +} + +MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response& response) +{ + response.prepare_payload(); + + ResponseSerializer serializer(response); + bool (*serializerIsDone)(ResponseSerializer&); + if (request.method() != boost::beast::http::verb::head) + { + serializerIsDone = [](ResponseSerializer& s) { return s.is_done(); }; + } + else + { + serializerIsDone = [](ResponseSerializer& s) { return s.is_header_done(); }; + serializer.split(true); + } + + MessageBuffer buffer; + while (!serializerIsDone(serializer)) + { + serializer.limit(buffer.GetRemainingSpace()); + + size_t totalBytes = 0; + boost::system::error_code ec = {}; + serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code& currentError, ConstBufferSequence const& buffers) + { + size_t totalBytesInBuffers = boost::asio::buffer_size(buffers); + if (totalBytesInBuffers > buffer.GetRemainingSpace()) + { + currentError = boost::beast::http::error::need_more; + return; + } + + auto begin = boost::asio::buffers_begin(buffers); + auto end = boost::asio::buffers_end(buffers); + + std::copy(begin, end, buffer.GetWritePointer()); + buffer.WriteCompleted(totalBytesInBuffers); + totalBytes += totalBytesInBuffers; + }); + + serializer.consume(totalBytes); + + if (ec == boost::beast::http::error::need_more) + buffer.Resize(buffer.GetBufferSize() + 4096); + } + + return buffer; +} +} diff --git a/src/server/shared/Networking/Http/BaseHttpSocket.h b/src/server/shared/Networking/Http/BaseHttpSocket.h new file mode 100644 index 00000000000..330287252d2 --- /dev/null +++ b/src/server/shared/Networking/Http/BaseHttpSocket.h @@ -0,0 +1,191 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_BASE_HTTP_SOCKET_H +#define TRINITYCORE_BASE_HTTP_SOCKET_H + +#include "AsyncCallbackProcessor.h" +#include "DatabaseEnvFwd.h" +#include "HttpCommon.h" +#include "HttpSessionState.h" +#include "Optional.h" +#include "QueryCallback.h" +#include "Socket.h" +#include <boost/asio/buffers_iterator.hpp> +#include <boost/beast/http/parser.hpp> +#include <boost/beast/http/string_body.hpp> +#include <boost/uuid/uuid_io.hpp> + +namespace Trinity::Net::Http +{ +using RequestParser = boost::beast::http::request_parser<RequestBody>; + +class TC_SHARED_API AbstractSocket +{ +public: + AbstractSocket() = default; + AbstractSocket(AbstractSocket const& other) = default; + AbstractSocket(AbstractSocket&& other) = default; + AbstractSocket& operator=(AbstractSocket const& other) = default; + AbstractSocket& operator=(AbstractSocket&& other) = default; + virtual ~AbstractSocket() = default; + + static bool ParseRequest(MessageBuffer& packet, RequestParser& parser); + + static std::string SerializeRequest(Request const& request); + static MessageBuffer SerializeResponse(Request const& request, Response& response); + + virtual void SendResponse(RequestContext& context) = 0; + + virtual void QueueQuery(QueryCallback&& queryCallback) = 0; + + virtual std::string GetClientInfo() const = 0; + + virtual Optional<boost::uuids::uuid> GetSessionId() const = 0; +}; + +template<typename Derived, typename Stream> +class BaseSocket : public ::Socket<Derived, Stream>, public AbstractSocket +{ + using Base = ::Socket<Derived, Stream>; + +public: + template<typename... Args> + explicit BaseSocket(boost::asio::ip::tcp::socket&& socket, Args&&... args) + : Base(std::move(socket), std::forward<Args>(args)...) { } + + BaseSocket(BaseSocket const& other) = delete; + BaseSocket(BaseSocket&& other) = delete; + BaseSocket& operator=(BaseSocket const& other) = delete; + BaseSocket& operator=(BaseSocket&& other) = delete; + + ~BaseSocket() = default; + + void ReadHandler() override + { + if (!this->IsOpen()) + return; + + MessageBuffer& packet = this->GetReadBuffer(); + while (packet.GetActiveSize() > 0) + { + if (!ParseRequest(packet, *_httpParser)) + { + // Couldn't receive the whole data this time. + break; + } + + if (!HandleMessage(_httpParser->get())) + { + this->CloseSocket(); + break; + } + + this->ResetHttpParser(); + } + + this->AsyncRead(); + } + + bool HandleMessage(Request& request) + { + RequestContext context { .request = std::move(request) }; + + if (!_state) + _state = this->ObtainSessionState(context); + + RequestHandlerResult status = this->RequestHandler(context); + + if (status != RequestHandlerResult::Async) + this->SendResponse(context); + + return status != RequestHandlerResult::Error; + } + + virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0; + + void SendResponse(RequestContext& context) override + { + MessageBuffer buffer = SerializeResponse(context.request, context.response); + + TC_LOG_DEBUG("server.http", "{} Request {} {} done, status {}", this->GetClientInfo(), ToStdStringView(context.request.method_string()), + ToStdStringView(context.request.target()), context.response.result_int()); + if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE)) + { + sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Request: ", this->GetClientInfo(), + CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>"); + sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Response: ", this->GetClientInfo(), + CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>"); + } + + this->QueuePacket(std::move(buffer)); + + if (!context.response.keep_alive()) + this->DelayedCloseSocket(); + } + + void QueueQuery(QueryCallback&& queryCallback) override + { + this->_queryProcessor.AddCallback(std::move(queryCallback)); + } + + bool Update() override + { + if (!this->Base::Update()) + return false; + + this->_queryProcessor.ProcessReadyCallbacks(); + return true; + } + + std::string GetClientInfo() const override + { + std::string info; + info.reserve(500); + auto itr = StringFormatTo(std::back_inserter(info), "[{}:{}", this->GetRemoteIpAddress().to_string(), this->GetRemotePort()); + if (_state) + itr = StringFormatTo(itr, ", Session Id: {}", boost::uuids::to_string(_state->Id)); + + StringFormatTo(itr, "]"); + return info; + } + + Optional<boost::uuids::uuid> GetSessionId() const final + { + if (this->_state) + return this->_state->Id; + + return {}; + } + +protected: + void ResetHttpParser() + { + this->_httpParser.reset(); + this->_httpParser.emplace(); + this->_httpParser->eager(true); + } + + virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0; + + QueryCallbackProcessor _queryProcessor; + Optional<RequestParser> _httpParser; + std::shared_ptr<SessionState> _state; +}; +} + +#endif // TRINITYCORE_BASE_HTTP_SOCKET_H diff --git a/src/server/shared/Networking/Http/HttpCommon.h b/src/server/shared/Networking/Http/HttpCommon.h new file mode 100644 index 00000000000..5f6ecb6c147 --- /dev/null +++ b/src/server/shared/Networking/Http/HttpCommon.h @@ -0,0 +1,55 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_COMMON_H +#define TRINITYCORE_HTTP_COMMON_H + +#include "Define.h" +#include <boost/beast/http/message.hpp> +#include <boost/beast/http/string_body.hpp> + +namespace Trinity::Net::Http +{ +using RequestBody = boost::beast::http::string_body; +using ResponseBody = boost::beast::http::string_body; + +using Request = boost::beast::http::request<RequestBody>; +using Response = boost::beast::http::response<ResponseBody>; + +struct RequestContext +{ + Request request; + Response response; + struct RequestHandler const* handler = nullptr; +}; + +TC_SHARED_API bool CanLogRequestContent(RequestContext const& context); +TC_SHARED_API bool CanLogResponseContent(RequestContext const& context); + +inline std::string_view ToStdStringView(boost::beast::string_view bsw) +{ + return { bsw.data(), bsw.size() }; +} + +enum class RequestHandlerResult +{ + Handled, + Error, + Async, +}; +} +#endif // TRINITYCORE_HTTP_COMMON_H diff --git a/src/server/shared/Networking/Http/HttpService.cpp b/src/server/shared/Networking/Http/HttpService.cpp new file mode 100644 index 00000000000..8995b65612a --- /dev/null +++ b/src/server/shared/Networking/Http/HttpService.cpp @@ -0,0 +1,258 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "HttpService.h" +#include "BaseHttpSocket.h" +#include "CryptoRandom.h" +#include "Timezone.h" +#include <boost/beast/version.hpp> +#include <boost/uuid/string_generator.hpp> +#include <fmt/chrono.h> + +namespace Trinity::Net::Http +{ +bool CanLogRequestContent(RequestContext const& context) +{ + return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogRequestContent); +} + +bool CanLogResponseContent(RequestContext const& context) +{ + return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogResponseContent); +} + +RequestHandlerResult DispatcherService::HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context) +{ + TC_LOG_DEBUG(_logger, "{} Starting request {} {}", session->GetClientInfo(), + ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target())); + + std::string_view path = [&] + { + std::string_view path = ToStdStringView(context.request.target()); + size_t queryIndex = path.find('?'); + if (queryIndex != std::string_view::npos) + path = path.substr(0, queryIndex); + return path; + }(); + + context.handler = [&]() -> HttpMethodHandlerMap::mapped_type const* + { + switch (context.request.method()) + { + case boost::beast::http::verb::get: + case boost::beast::http::verb::head: + { + auto itr = _getHandlers.find(path); + return itr != _getHandlers.end() ? &itr->second : nullptr; + } + case boost::beast::http::verb::post: + { + auto itr = _postHandlers.find(path); + return itr != _postHandlers.end() ? &itr->second : nullptr; + } + default: + break; + } + return nullptr; + }(); + + SystemTimePoint responseDate = SystemTimePoint::clock::now(); + context.response.set(boost::beast::http::field::date, StringFormat("{:%a, %d %b %Y %T GMT}", responseDate - Timezone::GetSystemZoneOffsetAt(responseDate))); + context.response.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING); + context.response.keep_alive(context.response.keep_alive()); + + if (!context.handler) + return HandlePathNotFound(std::move(session), context); + + return context.handler->Func(std::move(session), context); +} + +RequestHandlerResult DispatcherService::HandlePathNotFound(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context) +{ + context.response.result(boost::beast::http::status::not_found); + return RequestHandlerResult::Handled; +} + +RequestHandlerResult DispatcherService::HandleUnauthorized(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context) +{ + context.response.result(boost::beast::http::status::unauthorized); + return RequestHandlerResult::Handled; +} + +void DispatcherService::RegisterHandler(boost::beast::http::verb method, std::string_view path, + std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler, + RequestHandlerFlag flags) +{ + HttpMethodHandlerMap& handlerMap = [&]() -> HttpMethodHandlerMap& + { + switch (method) + { + case boost::beast::http::verb::get: + return _getHandlers; + case boost::beast::http::verb::post: + return _postHandlers; + default: + { + std::string_view methodString = ToStdStringView(boost::beast::http::to_string(method)); + ABORT_MSG("Tried to register a handler for unsupported HTTP method " STRING_VIEW_FMT, STRING_VIEW_FMT_ARG(methodString)); + } + } + }(); + + handlerMap[std::string(path)] = { .Func = std::move(handler), .Flags = flags }; + TC_LOG_INFO(_logger, "Registered new handler for {} {}", ToStdStringView(boost::beast::http::to_string(method)), path); +} + +void SessionService::InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address) +{ + state->RemoteAddress = address; + + // Generate session id + { + std::unique_lock lock{ _sessionsMutex }; + + while (state->Id.is_nil() || _sessions.contains(state->Id)) + std::copy_n(Trinity::Crypto::GetRandomBytes<16>().begin(), 16, state->Id.begin()); + + TC_LOG_DEBUG(_logger, "Client at {} created new session {}", address.to_string(), boost::uuids::to_string(state->Id)); + _sessions[state->Id] = std::move(state); + } +} + +void SessionService::Start(Asio::IoContext& ioContext) +{ + _inactiveSessionsKillTimer = std::make_unique<Asio::DeadlineTimer>(ioContext); + _inactiveSessionsKillTimer->expires_from_now(boost::posix_time::minutes(1)); + _inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err) + { + if (err) + return; + + KillInactiveSessions(); + }); +} + +void SessionService::Stop() +{ + _inactiveSessionsKillTimer = nullptr; + { + std::unique_lock lock{ _sessionsMutex }; + _sessions.clear(); + } + { + std::unique_lock lock{ _inactiveSessionsMutex }; + _inactiveSessions.clear(); + } +} + +std::shared_ptr<SessionState> SessionService::FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address) +{ + std::shared_ptr<SessionState> state; + + { + std::shared_lock lock{ _sessionsMutex }; + auto itr = _sessions.find(boost::uuids::string_generator()(id.begin(), id.end())); + if (itr == _sessions.end()) + { + TC_LOG_DEBUG(_logger, "Client at {} attempted to use a session {} that was expired", address.to_string(), id); + return nullptr; // no session + } + + state = itr->second; + } + + if (state->RemoteAddress != address) + { + TC_LOG_ERROR(_logger, "Client at {} attempted to use a session {} that was last accessed from {}, denied access", + address.to_string(), id, state->RemoteAddress.to_string()); + return nullptr; + } + + { + std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex }; + _inactiveSessions.erase(state->Id); + } + + return state; +} + +void SessionService::MarkSessionInactive(boost::uuids::uuid const& id) +{ + { + std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex }; + _inactiveSessions.insert(id); + } + + { + auto itr = _sessions.find(id); + if (itr != _sessions.end()) + { + itr->second->InactiveTimestamp = TimePoint::clock::now() + Minutes(5); + TC_LOG_TRACE(_logger, "Session {} marked as inactive", boost::uuids::to_string(id)); + } + } +} + +void SessionService::KillInactiveSessions() +{ + std::set<boost::uuids::uuid> inactiveSessions; + + { + std::unique_lock lock{ _inactiveSessionsMutex }; + std::swap(_inactiveSessions, inactiveSessions); + } + + { + TimePoint now = TimePoint::clock::now(); + std::size_t inactiveSessionsCount = inactiveSessions.size(); + + std::unique_lock lock{ _sessionsMutex }; + for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); ) + { + auto sessionItr = _sessions.find(*itr); + if (sessionItr == _sessions.end() || sessionItr->second->InactiveTimestamp < now) + { + _sessions.erase(sessionItr); + itr = inactiveSessions.erase(itr); + } + else + ++itr; + } + + TC_LOG_DEBUG(_logger, "Killed {} inactive sessions", inactiveSessionsCount - inactiveSessions.size()); + } + + { + // restore sessions not killed to inactive queue + std::unique_lock lock{ _inactiveSessionsMutex }; + for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); ) + { + auto node = inactiveSessions.extract(itr++); + _inactiveSessions.insert(std::move(node)); + } + } + + _inactiveSessionsKillTimer->expires_from_now(boost::posix_time::minutes(1)); + _inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err) + { + if (err) + return; + + KillInactiveSessions(); + }); +} +} diff --git a/src/server/shared/Networking/Http/HttpService.h b/src/server/shared/Networking/Http/HttpService.h new file mode 100644 index 00000000000..01c66146ae3 --- /dev/null +++ b/src/server/shared/Networking/Http/HttpService.h @@ -0,0 +1,188 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SERVICE_H +#define TRINITYCORE_HTTP_SERVICE_H + +#include "AsioHacksFwd.h" +#include "Concepts.h" +#include "Define.h" +#include "EnumFlag.h" +#include "HttpCommon.h" +#include "HttpSessionState.h" +#include "Optional.h" +#include "SocketMgr.h" +#include <boost/uuid/uuid.hpp> +#include <functional> +#include <map> +#include <set> +#include <shared_mutex> + +namespace Trinity::Net::Http +{ +class AbstractSocket; + +enum class RequestHandlerFlag +{ + None = 0x0, + DoNotLogRequestContent = 0x1, + DoNotLogResponseContent = 0x2, +}; + +DEFINE_ENUM_FLAG(RequestHandlerFlag); + +struct RequestHandler +{ + std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> Func; + EnumFlag<RequestHandlerFlag> Flags = RequestHandlerFlag::None; +}; + +class TC_SHARED_API DispatcherService +{ +public: + explicit DispatcherService(std::string_view loggerSuffix) : _logger("server.http.dispatcher.") + { + _logger.append(loggerSuffix); + } + + RequestHandlerResult HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context); + + RequestHandlerResult HandlePathNotFound(std::shared_ptr<AbstractSocket> session, RequestContext& context); + RequestHandlerResult HandleUnauthorized(std::shared_ptr<AbstractSocket> session, RequestContext& context); + +protected: + void RegisterHandler(boost::beast::http::verb method, std::string_view path, + std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler, + RequestHandlerFlag flags = RequestHandlerFlag::None); + +private: + using HttpMethodHandlerMap = std::map<std::string, RequestHandler, std::less<>>; + + HttpMethodHandlerMap _getHandlers; + HttpMethodHandlerMap _postHandlers; + + std::string _logger; +}; + +class TC_SHARED_API SessionService +{ +public: + explicit SessionService(std::string_view loggerSuffix) : _logger("server.http.session.") + { + _logger.append(loggerSuffix); + } + + void Start(Asio::IoContext& ioContext); + void Stop(); + + std::shared_ptr<SessionState> FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address); + void MarkSessionInactive(boost::uuids::uuid const& id); + +protected: + void InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address); + + void KillInactiveSessions(); + +private: + std::shared_mutex _sessionsMutex; + std::map<boost::uuids::uuid, std::shared_ptr<SessionState>> _sessions; + + std::mutex _inactiveSessionsMutex; + std::set<boost::uuids::uuid> _inactiveSessions; + std::unique_ptr<Asio::DeadlineTimer> _inactiveSessionsKillTimer; + + std::string _logger; +}; + +template<typename Callable, typename SessionImpl> +concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>; + +template<typename SessionImpl> +class HttpService : public SocketMgr<SessionImpl>, public DispatcherService, public SessionService +{ +public: + HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.") + { + _logger.append(loggerSuffix); + } + + bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override + { + if (!SocketMgr<SessionImpl>::StartNetwork(ioContext, bindIp, port, threadCount)) + return false; + + SessionService::Start(ioContext); + return true; + } + + void StopNetwork() override + { + SessionService::Stop(); + SocketMgr<SessionImpl>::StopNetwork(); + } + + // http handling + using DispatcherService::RegisterHandler; + + template<HttpRequestHandler<SessionImpl> Callable> + void RegisterHandler(boost::beast::http::verb method, std::string_view path, Callable handler, RequestHandlerFlag flags = RequestHandlerFlag::None) + { + this->DispatcherService::RegisterHandler(method, path, [handler = std::move(handler)](std::shared_ptr<AbstractSocket> session, RequestContext& context) -> RequestHandlerResult + { + return handler(std::static_pointer_cast<SessionImpl>(std::move(session)), context); + }, flags); + } + + // session tracking + virtual std::shared_ptr<SessionState> CreateNewSessionState(boost::asio::ip::address const& address) + { + std::shared_ptr<SessionState> state = std::make_shared<SessionState>(); + InitAndStoreSessionState(state, address); + return state; + } + +protected: + class Thread : public NetworkThread<SessionImpl> + { + protected: + void SocketRemoved(std::shared_ptr<SessionImpl> session) override + { + if (Optional<boost::uuids::uuid> id = session->GetSessionId()) + _service->MarkSessionInactive(*id); + } + + private: + friend HttpService; + + SessionService* _service; + }; + + NetworkThread<SessionImpl>* CreateThreads() const override + { + Thread* threads = new Thread[this->GetNetworkThreadCount()]; + for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i) + threads[i]._service = const_cast<HttpService*>(this); + return threads; + } + +private: + Asio::IoContext* _ioContext; + std::string _logger; +}; +} + +#endif // TRINITYCORE_HTTP_SERVICE_H diff --git a/src/server/shared/Networking/Http/HttpSessionState.h b/src/server/shared/Networking/Http/HttpSessionState.h new file mode 100644 index 00000000000..3012a2efc65 --- /dev/null +++ b/src/server/shared/Networking/Http/HttpSessionState.h @@ -0,0 +1,35 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SESSION_STATE_H +#define TRINITYCORE_HTTP_SESSION_STATE_H + +#include "Duration.h" +#include <boost/asio/ip/address.hpp> +#include <boost/uuid/uuid.hpp> + +namespace Trinity::Net::Http +{ +struct SessionState +{ + boost::uuids::uuid Id = { }; + boost::asio::ip::address RemoteAddress; + TimePoint InactiveTimestamp = TimePoint::max(); +}; +} + +#endif // TRINITYCORE_HTTP_SESSION_STATE_H diff --git a/src/server/shared/Networking/Http/HttpSocket.h b/src/server/shared/Networking/Http/HttpSocket.h new file mode 100644 index 00000000000..2bd18efd565 --- /dev/null +++ b/src/server/shared/Networking/Http/HttpSocket.h @@ -0,0 +1,75 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SOCKET_H +#define TRINITYCORE_HTTP_SOCKET_H + +#include "BaseHttpSocket.h" +#include <boost/beast/core/tcp_stream.hpp> + +namespace Trinity::Net::Http +{ +namespace Impl +{ +class BoostBeastSocketWrapper : public boost::beast::tcp_stream +{ +public: + using boost::beast::tcp_stream::tcp_stream; + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + socket().shutdown(what, shutdownError); + } + + void close(boost::system::error_code& /*error*/) + { + boost::beast::tcp_stream::close(); + } + + boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const + { + return socket().remote_endpoint(); + } +}; +} + +template <typename Derived> +class Socket : public BaseSocket<Derived, Impl::BoostBeastSocketWrapper> +{ + using SocketBase = BaseSocket<Derived, Impl::BoostBeastSocketWrapper>; + +public: + explicit Socket(boost::asio::ip::tcp::socket&& socket) + : SocketBase(std::move(socket)) { } + + Socket(Socket const& other) = delete; + Socket(Socket&& other) = delete; + Socket& operator=(Socket const& other) = delete; + Socket& operator=(Socket&& other) = delete; + + ~Socket() = default; + + void Start() override + { + this->ResetHttpParser(); + + this->AsyncRead(); + } +}; +} + +#endif // TRINITYCORE_HTTP_SOCKET_H diff --git a/src/server/shared/Networking/Http/HttpSslSocket.h b/src/server/shared/Networking/Http/HttpSslSocket.h new file mode 100644 index 00000000000..cdb70645e05 --- /dev/null +++ b/src/server/shared/Networking/Http/HttpSslSocket.h @@ -0,0 +1,97 @@ +/* + * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License along + * with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#ifndef TRINITYCORE_HTTP_SSL_SOCKET_H +#define TRINITYCORE_HTTP_SSL_SOCKET_H + +#include "BaseHttpSocket.h" +#include "SslSocket.h" +#include <boost/beast/core/stream_traits.hpp> +#include <boost/beast/core/tcp_stream.hpp> +#include <boost/beast/ssl/ssl_stream.hpp> + +namespace Trinity::Net::Http +{ +namespace Impl +{ +class BoostBeastSslSocketWrapper : public ::SslSocket<boost::beast::ssl_stream<boost::beast::tcp_stream>> +{ +public: + using SslSocket::SslSocket; + + void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError) + { + _sslSocket.shutdown(shutdownError); + boost::beast::get_lowest_layer(_sslSocket).socket().shutdown(what, shutdownError); + } + + void close(boost::system::error_code& /*error*/) + { + boost::beast::get_lowest_layer(_sslSocket).close(); + } + + boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const + { + return boost::beast::get_lowest_layer(_sslSocket).socket().remote_endpoint(); + } +}; +} + +template <typename Derived> +class SslSocket : public BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper> +{ + using SocketBase = BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>; + +public: + explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) + : SocketBase(std::move(socket), sslContext) { } + + SslSocket(SslSocket const& other) = delete; + SslSocket(SslSocket&& other) = delete; + SslSocket& operator=(SslSocket const& other) = delete; + SslSocket& operator=(SslSocket&& other) = delete; + + ~SslSocket() = default; + + void Start() override + { + this->AsyncHandshake(); + } + + void AsyncHandshake() + { + this->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server, + [self = this->shared_from_this()](boost::system::error_code const& error) { self->HandshakeHandler(error); }); + } + + void HandshakeHandler(boost::system::error_code const& error) + { + if (error) + { + TC_LOG_ERROR("server.http.session.ssl", "{} SSL Handshake failed {}", this->GetClientInfo(), error.message()); + this->CloseSocket(); + return; + } + + this->ResetHttpParser(); + + this->AsyncRead(); + } +}; +} + +#endif // TRINITYCORE_HTTP_SSL_SOCKET_H diff --git a/src/server/shared/Networking/NetworkThread.h b/src/server/shared/Networking/NetworkThread.h index 69d62403249..0195c48b9fc 100644 --- a/src/server/shared/Networking/NetworkThread.h +++ b/src/server/shared/Networking/NetworkThread.h @@ -77,7 +77,7 @@ public: return _connections; } - virtual void AddSocket(std::shared_ptr<SocketType> sock) + void AddSocket(std::shared_ptr<SocketType> sock) { std::lock_guard<std::mutex> lock(_newSocketsLock); diff --git a/src/server/shared/Networking/Socket.h b/src/server/shared/Networking/Socket.h index a996ecb2cbe..511f94ed366 100644 --- a/src/server/shared/Networking/Socket.h +++ b/src/server/shared/Networking/Socket.h @@ -18,14 +18,13 @@ #ifndef __SOCKET_H__ #define __SOCKET_H__ -#include "MessageBuffer.h" #include "Log.h" +#include "MessageBuffer.h" +#include <boost/asio/ip/tcp.hpp> #include <atomic> -#include <queue> #include <memory> -#include <functional> +#include <queue> #include <type_traits> -#include <boost/asio/ip/tcp.hpp> #define READ_BLOCK_SIZE 4096 #ifdef BOOST_ASIO_HAS_IOCP @@ -63,12 +62,19 @@ template<class T, class Stream = boost::asio::ip::tcp::socket> class Socket : public std::enable_shared_from_this<T> { public: - explicit Socket(boost::asio::ip::tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()), - _remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false) + template<typename... Args> + explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...), + _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), + _closed(false), _closing(false), _isWritingAsync(false) { _readBuffer.Resize(READ_BLOCK_SIZE); } + Socket(Socket const& other) = delete; + Socket(Socket&& other) = delete; + Socket& operator=(Socket const& other) = delete; + Socket& operator=(Socket&& other) = delete; + virtual ~Socket() { _closed = true; diff --git a/src/server/shared/Networking/SslSocket.h b/src/server/shared/Networking/SslSocket.h index e00b1b6b65e..c19c8612edf 100644 --- a/src/server/shared/Networking/SslSocket.h +++ b/src/server/shared/Networking/SslSocket.h @@ -24,11 +24,11 @@ namespace boostssl = boost::asio::ssl; -template<class SslContext, class Stream = boostssl::stream<boost::asio::ip::tcp::socket>> +template<class Stream = boostssl::stream<boost::asio::ip::tcp::socket>> class SslSocket { public: - explicit SslSocket(boost::asio::ip::tcp::socket&& socket) : _sslSocket(std::move(socket), SslContext::instance()) + explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext) { _sslSocket.set_verify_mode(boostssl::verify_none); } diff --git a/src/server/worldserver/CMakeLists.txt b/src/server/worldserver/CMakeLists.txt index da3bd43c2d5..f6df3e524de 100644 --- a/src/server/worldserver/CMakeLists.txt +++ b/src/server/worldserver/CMakeLists.txt @@ -51,7 +51,8 @@ target_link_libraries(worldserver PUBLIC scripts game - readline) + readline + gsoap) CollectIncludeDirectories( ${CMAKE_CURRENT_SOURCE_DIR} |
