Core/Bnet: Rewrite LoginRESTService using boost::beast instead of gsoap as http backend and extract generic http code to be reusable elsewhere

This commit is contained in:
Shauren
2023-12-17 23:21:10 +01:00
parent 5f00ac4b2b
commit acb5fbd48b
27 changed files with 1439 additions and 1262 deletions

View File

@@ -0,0 +1,115 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "BaseHttpSocket.h"
#include <boost/beast/http/serializer.hpp>
namespace Trinity::Net::Http
{
using RequestSerializer = boost::beast::http::request_serializer<ResponseBody>;
using ResponseSerializer = boost::beast::http::response_serializer<ResponseBody>;
bool AbstractSocket::ParseRequest(MessageBuffer& packet, RequestParser& parser)
{
if (!parser.is_done())
{
// need more data in the payload
boost::system::error_code ec = {};
std::size_t readDataSize = parser.put(boost::asio::const_buffer(packet.GetReadPointer(), packet.GetActiveSize()), ec);
packet.ReadCompleted(readDataSize);
}
return parser.is_done();
}
std::string AbstractSocket::SerializeRequest(Request const& request)
{
RequestSerializer serializer(request);
std::string buffer;
while (!serializer.is_done())
{
size_t totalBytes = 0;
boost::system::error_code ec = {};
serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code const&, ConstBufferSequence const& buffers)
{
size_t totalBytesInBuffers = boost::asio::buffer_size(buffers);
buffer.reserve(buffer.size() + totalBytes);
auto begin = boost::asio::buffers_begin(buffers);
auto end = boost::asio::buffers_end(buffers);
std::copy(begin, end, std::back_inserter(buffer));
totalBytes += totalBytesInBuffers;
});
serializer.consume(totalBytes);
}
return buffer;
}
MessageBuffer AbstractSocket::SerializeResponse(Request const& request, Response& response)
{
response.prepare_payload();
ResponseSerializer serializer(response);
bool (*serializerIsDone)(ResponseSerializer&);
if (request.method() != boost::beast::http::verb::head)
{
serializerIsDone = [](ResponseSerializer& s) { return s.is_done(); };
}
else
{
serializerIsDone = [](ResponseSerializer& s) { return s.is_header_done(); };
serializer.split(true);
}
MessageBuffer buffer;
while (!serializerIsDone(serializer))
{
serializer.limit(buffer.GetRemainingSpace());
size_t totalBytes = 0;
boost::system::error_code ec = {};
serializer.next(ec, [&]<typename ConstBufferSequence>(boost::system::error_code& currentError, ConstBufferSequence const& buffers)
{
size_t totalBytesInBuffers = boost::asio::buffer_size(buffers);
if (totalBytesInBuffers > buffer.GetRemainingSpace())
{
currentError = boost::beast::http::error::need_more;
return;
}
auto begin = boost::asio::buffers_begin(buffers);
auto end = boost::asio::buffers_end(buffers);
std::copy(begin, end, buffer.GetWritePointer());
buffer.WriteCompleted(totalBytesInBuffers);
totalBytes += totalBytesInBuffers;
});
serializer.consume(totalBytes);
if (ec == boost::beast::http::error::need_more)
buffer.Resize(buffer.GetBufferSize() + 4096);
}
return buffer;
}
}

View File

@@ -0,0 +1,191 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_BASE_HTTP_SOCKET_H
#define TRINITYCORE_BASE_HTTP_SOCKET_H
#include "AsyncCallbackProcessor.h"
#include "DatabaseEnvFwd.h"
#include "HttpCommon.h"
#include "HttpSessionState.h"
#include "Optional.h"
#include "QueryCallback.h"
#include "Socket.h"
#include <boost/asio/buffers_iterator.hpp>
#include <boost/beast/http/parser.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/uuid/uuid_io.hpp>
namespace Trinity::Net::Http
{
using RequestParser = boost::beast::http::request_parser<RequestBody>;
class TC_SHARED_API AbstractSocket
{
public:
AbstractSocket() = default;
AbstractSocket(AbstractSocket const& other) = default;
AbstractSocket(AbstractSocket&& other) = default;
AbstractSocket& operator=(AbstractSocket const& other) = default;
AbstractSocket& operator=(AbstractSocket&& other) = default;
virtual ~AbstractSocket() = default;
static bool ParseRequest(MessageBuffer& packet, RequestParser& parser);
static std::string SerializeRequest(Request const& request);
static MessageBuffer SerializeResponse(Request const& request, Response& response);
virtual void SendResponse(RequestContext& context) = 0;
virtual void QueueQuery(QueryCallback&& queryCallback) = 0;
virtual std::string GetClientInfo() const = 0;
virtual Optional<boost::uuids::uuid> GetSessionId() const = 0;
};
template<typename Derived, typename Stream>
class BaseSocket : public ::Socket<Derived, Stream>, public AbstractSocket
{
using Base = ::Socket<Derived, Stream>;
public:
template<typename... Args>
explicit BaseSocket(boost::asio::ip::tcp::socket&& socket, Args&&... args)
: Base(std::move(socket), std::forward<Args>(args)...) { }
BaseSocket(BaseSocket const& other) = delete;
BaseSocket(BaseSocket&& other) = delete;
BaseSocket& operator=(BaseSocket const& other) = delete;
BaseSocket& operator=(BaseSocket&& other) = delete;
~BaseSocket() = default;
void ReadHandler() override
{
if (!this->IsOpen())
return;
MessageBuffer& packet = this->GetReadBuffer();
while (packet.GetActiveSize() > 0)
{
if (!ParseRequest(packet, *_httpParser))
{
// Couldn't receive the whole data this time.
break;
}
if (!HandleMessage(_httpParser->get()))
{
this->CloseSocket();
break;
}
this->ResetHttpParser();
}
this->AsyncRead();
}
bool HandleMessage(Request& request)
{
RequestContext context { .request = std::move(request) };
if (!_state)
_state = this->ObtainSessionState(context);
RequestHandlerResult status = this->RequestHandler(context);
if (status != RequestHandlerResult::Async)
this->SendResponse(context);
return status != RequestHandlerResult::Error;
}
virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0;
void SendResponse(RequestContext& context) override
{
MessageBuffer buffer = SerializeResponse(context.request, context.response);
TC_LOG_DEBUG("server.http", "{} Request {} {} done, status {}", this->GetClientInfo(), ToStdStringView(context.request.method_string()),
ToStdStringView(context.request.target()), context.response.result_int());
if (sLog->ShouldLog("server.http", LOG_LEVEL_TRACE))
{
sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Request: ", this->GetClientInfo(),
CanLogRequestContent(context) ? SerializeRequest(context.request) : "<REDACTED>");
sLog->OutMessage("server.http", LOG_LEVEL_TRACE, "{} Response: ", this->GetClientInfo(),
CanLogResponseContent(context) ? std::string_view(reinterpret_cast<char const*>(buffer.GetBasePointer()), buffer.GetActiveSize()) : "<REDACTED>");
}
this->QueuePacket(std::move(buffer));
if (!context.response.keep_alive())
this->DelayedCloseSocket();
}
void QueueQuery(QueryCallback&& queryCallback) override
{
this->_queryProcessor.AddCallback(std::move(queryCallback));
}
bool Update() override
{
if (!this->Base::Update())
return false;
this->_queryProcessor.ProcessReadyCallbacks();
return true;
}
std::string GetClientInfo() const override
{
std::string info;
info.reserve(500);
auto itr = StringFormatTo(std::back_inserter(info), "[{}:{}", this->GetRemoteIpAddress().to_string(), this->GetRemotePort());
if (_state)
itr = StringFormatTo(itr, ", Session Id: {}", boost::uuids::to_string(_state->Id));
StringFormatTo(itr, "]");
return info;
}
Optional<boost::uuids::uuid> GetSessionId() const final
{
if (this->_state)
return this->_state->Id;
return {};
}
protected:
void ResetHttpParser()
{
this->_httpParser.reset();
this->_httpParser.emplace();
this->_httpParser->eager(true);
}
virtual std::shared_ptr<SessionState> ObtainSessionState(RequestContext& context) const = 0;
QueryCallbackProcessor _queryProcessor;
Optional<RequestParser> _httpParser;
std::shared_ptr<SessionState> _state;
};
}
#endif // TRINITYCORE_BASE_HTTP_SOCKET_H

View File

@@ -0,0 +1,55 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_COMMON_H
#define TRINITYCORE_HTTP_COMMON_H
#include "Define.h"
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
namespace Trinity::Net::Http
{
using RequestBody = boost::beast::http::string_body;
using ResponseBody = boost::beast::http::string_body;
using Request = boost::beast::http::request<RequestBody>;
using Response = boost::beast::http::response<ResponseBody>;
struct RequestContext
{
Request request;
Response response;
struct RequestHandler const* handler = nullptr;
};
TC_SHARED_API bool CanLogRequestContent(RequestContext const& context);
TC_SHARED_API bool CanLogResponseContent(RequestContext const& context);
inline std::string_view ToStdStringView(boost::beast::string_view bsw)
{
return { bsw.data(), bsw.size() };
}
enum class RequestHandlerResult
{
Handled,
Error,
Async,
};
}
#endif // TRINITYCORE_HTTP_COMMON_H

View File

@@ -0,0 +1,258 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "HttpService.h"
#include "BaseHttpSocket.h"
#include "CryptoRandom.h"
#include "Timezone.h"
#include <boost/beast/version.hpp>
#include <boost/uuid/string_generator.hpp>
#include <fmt/chrono.h>
namespace Trinity::Net::Http
{
bool CanLogRequestContent(RequestContext const& context)
{
return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogRequestContent);
}
bool CanLogResponseContent(RequestContext const& context)
{
return !context.handler || !context.handler->Flags.HasFlag(RequestHandlerFlag::DoNotLogResponseContent);
}
RequestHandlerResult DispatcherService::HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context)
{
TC_LOG_DEBUG(_logger, "{} Starting request {} {}", session->GetClientInfo(),
ToStdStringView(context.request.method_string()), ToStdStringView(context.request.target()));
std::string_view path = [&]
{
std::string_view path = ToStdStringView(context.request.target());
size_t queryIndex = path.find('?');
if (queryIndex != std::string_view::npos)
path = path.substr(0, queryIndex);
return path;
}();
context.handler = [&]() -> HttpMethodHandlerMap::mapped_type const*
{
switch (context.request.method())
{
case boost::beast::http::verb::get:
case boost::beast::http::verb::head:
{
auto itr = _getHandlers.find(path);
return itr != _getHandlers.end() ? &itr->second : nullptr;
}
case boost::beast::http::verb::post:
{
auto itr = _postHandlers.find(path);
return itr != _postHandlers.end() ? &itr->second : nullptr;
}
default:
break;
}
return nullptr;
}();
SystemTimePoint responseDate = SystemTimePoint::clock::now();
context.response.set(boost::beast::http::field::date, StringFormat("{:%a, %d %b %Y %T GMT}", responseDate - Timezone::GetSystemZoneOffsetAt(responseDate)));
context.response.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING);
context.response.keep_alive(context.response.keep_alive());
if (!context.handler)
return HandlePathNotFound(std::move(session), context);
return context.handler->Func(std::move(session), context);
}
RequestHandlerResult DispatcherService::HandlePathNotFound(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
{
context.response.result(boost::beast::http::status::not_found);
return RequestHandlerResult::Handled;
}
RequestHandlerResult DispatcherService::HandleUnauthorized(std::shared_ptr<AbstractSocket> /*session*/, RequestContext& context)
{
context.response.result(boost::beast::http::status::unauthorized);
return RequestHandlerResult::Handled;
}
void DispatcherService::RegisterHandler(boost::beast::http::verb method, std::string_view path,
std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler,
RequestHandlerFlag flags)
{
HttpMethodHandlerMap& handlerMap = [&]() -> HttpMethodHandlerMap&
{
switch (method)
{
case boost::beast::http::verb::get:
return _getHandlers;
case boost::beast::http::verb::post:
return _postHandlers;
default:
{
std::string_view methodString = ToStdStringView(boost::beast::http::to_string(method));
ABORT_MSG("Tried to register a handler for unsupported HTTP method " STRING_VIEW_FMT, STRING_VIEW_FMT_ARG(methodString));
}
}
}();
handlerMap[std::string(path)] = { .Func = std::move(handler), .Flags = flags };
TC_LOG_INFO(_logger, "Registered new handler for {} {}", ToStdStringView(boost::beast::http::to_string(method)), path);
}
void SessionService::InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address)
{
state->RemoteAddress = address;
// Generate session id
{
std::unique_lock lock{ _sessionsMutex };
while (state->Id.is_nil() || _sessions.contains(state->Id))
std::copy_n(Trinity::Crypto::GetRandomBytes<16>().begin(), 16, state->Id.begin());
TC_LOG_DEBUG(_logger, "Client at {} created new session {}", address.to_string(), boost::uuids::to_string(state->Id));
_sessions[state->Id] = std::move(state);
}
}
void SessionService::Start(Asio::IoContext& ioContext)
{
_inactiveSessionsKillTimer = std::make_unique<Asio::DeadlineTimer>(ioContext);
_inactiveSessionsKillTimer->expires_from_now(boost::posix_time::minutes(1));
_inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err)
{
if (err)
return;
KillInactiveSessions();
});
}
void SessionService::Stop()
{
_inactiveSessionsKillTimer = nullptr;
{
std::unique_lock lock{ _sessionsMutex };
_sessions.clear();
}
{
std::unique_lock lock{ _inactiveSessionsMutex };
_inactiveSessions.clear();
}
}
std::shared_ptr<SessionState> SessionService::FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address)
{
std::shared_ptr<SessionState> state;
{
std::shared_lock lock{ _sessionsMutex };
auto itr = _sessions.find(boost::uuids::string_generator()(id.begin(), id.end()));
if (itr == _sessions.end())
{
TC_LOG_DEBUG(_logger, "Client at {} attempted to use a session {} that was expired", address.to_string(), id);
return nullptr; // no session
}
state = itr->second;
}
if (state->RemoteAddress != address)
{
TC_LOG_ERROR(_logger, "Client at {} attempted to use a session {} that was last accessed from {}, denied access",
address.to_string(), id, state->RemoteAddress.to_string());
return nullptr;
}
{
std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex };
_inactiveSessions.erase(state->Id);
}
return state;
}
void SessionService::MarkSessionInactive(boost::uuids::uuid const& id)
{
{
std::unique_lock inactiveSessionsLock{ _inactiveSessionsMutex };
_inactiveSessions.insert(id);
}
{
auto itr = _sessions.find(id);
if (itr != _sessions.end())
{
itr->second->InactiveTimestamp = TimePoint::clock::now() + Minutes(5);
TC_LOG_TRACE(_logger, "Session {} marked as inactive", boost::uuids::to_string(id));
}
}
}
void SessionService::KillInactiveSessions()
{
std::set<boost::uuids::uuid> inactiveSessions;
{
std::unique_lock lock{ _inactiveSessionsMutex };
std::swap(_inactiveSessions, inactiveSessions);
}
{
TimePoint now = TimePoint::clock::now();
std::size_t inactiveSessionsCount = inactiveSessions.size();
std::unique_lock lock{ _sessionsMutex };
for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); )
{
auto sessionItr = _sessions.find(*itr);
if (sessionItr == _sessions.end() || sessionItr->second->InactiveTimestamp < now)
{
_sessions.erase(sessionItr);
itr = inactiveSessions.erase(itr);
}
else
++itr;
}
TC_LOG_DEBUG(_logger, "Killed {} inactive sessions", inactiveSessionsCount - inactiveSessions.size());
}
{
// restore sessions not killed to inactive queue
std::unique_lock lock{ _inactiveSessionsMutex };
for (auto itr = inactiveSessions.begin(); itr != inactiveSessions.end(); )
{
auto node = inactiveSessions.extract(itr++);
_inactiveSessions.insert(std::move(node));
}
}
_inactiveSessionsKillTimer->expires_from_now(boost::posix_time::minutes(1));
_inactiveSessionsKillTimer->async_wait([this](boost::system::error_code const& err)
{
if (err)
return;
KillInactiveSessions();
});
}
}

View File

@@ -0,0 +1,188 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SERVICE_H
#define TRINITYCORE_HTTP_SERVICE_H
#include "AsioHacksFwd.h"
#include "Concepts.h"
#include "Define.h"
#include "EnumFlag.h"
#include "HttpCommon.h"
#include "HttpSessionState.h"
#include "Optional.h"
#include "SocketMgr.h"
#include <boost/uuid/uuid.hpp>
#include <functional>
#include <map>
#include <set>
#include <shared_mutex>
namespace Trinity::Net::Http
{
class AbstractSocket;
enum class RequestHandlerFlag
{
None = 0x0,
DoNotLogRequestContent = 0x1,
DoNotLogResponseContent = 0x2,
};
DEFINE_ENUM_FLAG(RequestHandlerFlag);
struct RequestHandler
{
std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> Func;
EnumFlag<RequestHandlerFlag> Flags = RequestHandlerFlag::None;
};
class TC_SHARED_API DispatcherService
{
public:
explicit DispatcherService(std::string_view loggerSuffix) : _logger("server.http.dispatcher.")
{
_logger.append(loggerSuffix);
}
RequestHandlerResult HandleRequest(std::shared_ptr<AbstractSocket> session, RequestContext& context);
RequestHandlerResult HandlePathNotFound(std::shared_ptr<AbstractSocket> session, RequestContext& context);
RequestHandlerResult HandleUnauthorized(std::shared_ptr<AbstractSocket> session, RequestContext& context);
protected:
void RegisterHandler(boost::beast::http::verb method, std::string_view path,
std::function<RequestHandlerResult(std::shared_ptr<AbstractSocket> session, RequestContext& context)> handler,
RequestHandlerFlag flags = RequestHandlerFlag::None);
private:
using HttpMethodHandlerMap = std::map<std::string, RequestHandler, std::less<>>;
HttpMethodHandlerMap _getHandlers;
HttpMethodHandlerMap _postHandlers;
std::string _logger;
};
class TC_SHARED_API SessionService
{
public:
explicit SessionService(std::string_view loggerSuffix) : _logger("server.http.session.")
{
_logger.append(loggerSuffix);
}
void Start(Asio::IoContext& ioContext);
void Stop();
std::shared_ptr<SessionState> FindAndRefreshSessionState(std::string_view id, boost::asio::ip::address const& address);
void MarkSessionInactive(boost::uuids::uuid const& id);
protected:
void InitAndStoreSessionState(std::shared_ptr<SessionState> state, boost::asio::ip::address const& address);
void KillInactiveSessions();
private:
std::shared_mutex _sessionsMutex;
std::map<boost::uuids::uuid, std::shared_ptr<SessionState>> _sessions;
std::mutex _inactiveSessionsMutex;
std::set<boost::uuids::uuid> _inactiveSessions;
std::unique_ptr<Asio::DeadlineTimer> _inactiveSessionsKillTimer;
std::string _logger;
};
template<typename Callable, typename SessionImpl>
concept HttpRequestHandler = invocable_r<Callable, RequestHandlerResult, std::shared_ptr<SessionImpl>, RequestContext&>;
template<typename SessionImpl>
class HttpService : public SocketMgr<SessionImpl>, public DispatcherService, public SessionService
{
public:
HttpService(std::string_view loggerSuffix) : DispatcherService(loggerSuffix), SessionService(loggerSuffix), _ioContext(nullptr), _logger("server.http.")
{
_logger.append(loggerSuffix);
}
bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int32 threadCount = 1) override
{
if (!SocketMgr<SessionImpl>::StartNetwork(ioContext, bindIp, port, threadCount))
return false;
SessionService::Start(ioContext);
return true;
}
void StopNetwork() override
{
SessionService::Stop();
SocketMgr<SessionImpl>::StopNetwork();
}
// http handling
using DispatcherService::RegisterHandler;
template<HttpRequestHandler<SessionImpl> Callable>
void RegisterHandler(boost::beast::http::verb method, std::string_view path, Callable handler, RequestHandlerFlag flags = RequestHandlerFlag::None)
{
this->DispatcherService::RegisterHandler(method, path, [handler = std::move(handler)](std::shared_ptr<AbstractSocket> session, RequestContext& context) -> RequestHandlerResult
{
return handler(std::static_pointer_cast<SessionImpl>(std::move(session)), context);
}, flags);
}
// session tracking
virtual std::shared_ptr<SessionState> CreateNewSessionState(boost::asio::ip::address const& address)
{
std::shared_ptr<SessionState> state = std::make_shared<SessionState>();
InitAndStoreSessionState(state, address);
return state;
}
protected:
class Thread : public NetworkThread<SessionImpl>
{
protected:
void SocketRemoved(std::shared_ptr<SessionImpl> session) override
{
if (Optional<boost::uuids::uuid> id = session->GetSessionId())
_service->MarkSessionInactive(*id);
}
private:
friend HttpService;
SessionService* _service;
};
NetworkThread<SessionImpl>* CreateThreads() const override
{
Thread* threads = new Thread[this->GetNetworkThreadCount()];
for (int32 i = 0; i < this->GetNetworkThreadCount(); ++i)
threads[i]._service = const_cast<HttpService*>(this);
return threads;
}
private:
Asio::IoContext* _ioContext;
std::string _logger;
};
}
#endif // TRINITYCORE_HTTP_SERVICE_H

View File

@@ -0,0 +1,35 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SESSION_STATE_H
#define TRINITYCORE_HTTP_SESSION_STATE_H
#include "Duration.h"
#include <boost/asio/ip/address.hpp>
#include <boost/uuid/uuid.hpp>
namespace Trinity::Net::Http
{
struct SessionState
{
boost::uuids::uuid Id = { };
boost::asio::ip::address RemoteAddress;
TimePoint InactiveTimestamp = TimePoint::max();
};
}
#endif // TRINITYCORE_HTTP_SESSION_STATE_H

View File

@@ -0,0 +1,75 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SOCKET_H
#define TRINITYCORE_HTTP_SOCKET_H
#include "BaseHttpSocket.h"
#include <boost/beast/core/tcp_stream.hpp>
namespace Trinity::Net::Http
{
namespace Impl
{
class BoostBeastSocketWrapper : public boost::beast::tcp_stream
{
public:
using boost::beast::tcp_stream::tcp_stream;
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
socket().shutdown(what, shutdownError);
}
void close(boost::system::error_code& /*error*/)
{
boost::beast::tcp_stream::close();
}
boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
{
return socket().remote_endpoint();
}
};
}
template <typename Derived>
class Socket : public BaseSocket<Derived, Impl::BoostBeastSocketWrapper>
{
using SocketBase = BaseSocket<Derived, Impl::BoostBeastSocketWrapper>;
public:
explicit Socket(boost::asio::ip::tcp::socket&& socket)
: SocketBase(std::move(socket)) { }
Socket(Socket const& other) = delete;
Socket(Socket&& other) = delete;
Socket& operator=(Socket const& other) = delete;
Socket& operator=(Socket&& other) = delete;
~Socket() = default;
void Start() override
{
this->ResetHttpParser();
this->AsyncRead();
}
};
}
#endif // TRINITYCORE_HTTP_SOCKET_H

View File

@@ -0,0 +1,97 @@
/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_HTTP_SSL_SOCKET_H
#define TRINITYCORE_HTTP_SSL_SOCKET_H
#include "BaseHttpSocket.h"
#include "SslSocket.h"
#include <boost/beast/core/stream_traits.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/ssl/ssl_stream.hpp>
namespace Trinity::Net::Http
{
namespace Impl
{
class BoostBeastSslSocketWrapper : public ::SslSocket<boost::beast::ssl_stream<boost::beast::tcp_stream>>
{
public:
using SslSocket::SslSocket;
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
_sslSocket.shutdown(shutdownError);
boost::beast::get_lowest_layer(_sslSocket).socket().shutdown(what, shutdownError);
}
void close(boost::system::error_code& /*error*/)
{
boost::beast::get_lowest_layer(_sslSocket).close();
}
boost::asio::ip::tcp::socket::endpoint_type remote_endpoint() const
{
return boost::beast::get_lowest_layer(_sslSocket).socket().remote_endpoint();
}
};
}
template <typename Derived>
class SslSocket : public BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>
{
using SocketBase = BaseSocket<Derived, Impl::BoostBeastSslSocketWrapper>;
public:
explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext)
: SocketBase(std::move(socket), sslContext) { }
SslSocket(SslSocket const& other) = delete;
SslSocket(SslSocket&& other) = delete;
SslSocket& operator=(SslSocket const& other) = delete;
SslSocket& operator=(SslSocket&& other) = delete;
~SslSocket() = default;
void Start() override
{
this->AsyncHandshake();
}
void AsyncHandshake()
{
this->underlying_stream().async_handshake(boost::asio::ssl::stream_base::server,
[self = this->shared_from_this()](boost::system::error_code const& error) { self->HandshakeHandler(error); });
}
void HandshakeHandler(boost::system::error_code const& error)
{
if (error)
{
TC_LOG_ERROR("server.http.session.ssl", "{} SSL Handshake failed {}", this->GetClientInfo(), error.message());
this->CloseSocket();
return;
}
this->ResetHttpParser();
this->AsyncRead();
}
};
}
#endif // TRINITYCORE_HTTP_SSL_SOCKET_H

View File

@@ -77,7 +77,7 @@ public:
return _connections;
}
virtual void AddSocket(std::shared_ptr<SocketType> sock)
void AddSocket(std::shared_ptr<SocketType> sock)
{
std::lock_guard<std::mutex> lock(_newSocketsLock);

View File

@@ -18,14 +18,13 @@
#ifndef __SOCKET_H__
#define __SOCKET_H__
#include "MessageBuffer.h"
#include "Log.h"
#include <atomic>
#include <queue>
#include <memory>
#include <functional>
#include <type_traits>
#include "MessageBuffer.h"
#include <boost/asio/ip/tcp.hpp>
#include <atomic>
#include <memory>
#include <queue>
#include <type_traits>
#define READ_BLOCK_SIZE 4096
#ifdef BOOST_ASIO_HAS_IOCP
@@ -63,12 +62,19 @@ template<class T, class Stream = boost::asio::ip::tcp::socket>
class Socket : public std::enable_shared_from_this<T>
{
public:
explicit Socket(boost::asio::ip::tcp::socket&& socket) : _socket(std::move(socket)), _remoteAddress(_socket.remote_endpoint().address()),
_remotePort(_socket.remote_endpoint().port()), _readBuffer(), _closed(false), _closing(false), _isWritingAsync(false)
template<typename... Args>
explicit Socket(boost::asio::ip::tcp::socket&& socket, Args&&... args) : _socket(std::move(socket), std::forward<Args>(args)...),
_remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()),
_closed(false), _closing(false), _isWritingAsync(false)
{
_readBuffer.Resize(READ_BLOCK_SIZE);
}
Socket(Socket const& other) = delete;
Socket(Socket&& other) = delete;
Socket& operator=(Socket const& other) = delete;
Socket& operator=(Socket&& other) = delete;
virtual ~Socket()
{
_closed = true;

View File

@@ -24,11 +24,11 @@
namespace boostssl = boost::asio::ssl;
template<class SslContext, class Stream = boostssl::stream<boost::asio::ip::tcp::socket>>
template<class Stream = boostssl::stream<boost::asio::ip::tcp::socket>>
class SslSocket
{
public:
explicit SslSocket(boost::asio::ip::tcp::socket&& socket) : _sslSocket(std::move(socket), SslContext::instance())
explicit SslSocket(boost::asio::ip::tcp::socket&& socket, boost::asio::ssl::context& sslContext) : _sslSocket(std::move(socket), sslContext)
{
_sslSocket.set_verify_mode(boostssl::verify_none);
}