/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_BASE_HTTP_SOCKET_H
#define TRINITYCORE_BASE_HTTP_SOCKET_H
#include "HttpCommon.h"
#include "HttpSessionState.h"
#include "Optional.h"
#include "Socket.h"
#include "SocketConnectionInitializer.h"
#include
#include
#include
namespace Trinity::Net::Http
{
// Boost.Beast basic_stream specialization used as the transport for HTTP sockets;
// Impl::BoostBeastSocketWrapper derives from it.
using IoContextHttpSocket = boost::beast::basic_stream;
namespace Impl
{
class BoostBeastSocketWrapper : public IoContextHttpSocket
{
public:
using IoContextHttpSocket::basic_stream;
bool is_open() const
{
return socket().is_open();
}
void close(boost::system::error_code& /*error*/)
{
IoContextHttpSocket::close();
}
void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError)
{
socket().shutdown(what, shutdownError);
}
template
void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler)
{
socket().async_wait(type, std::forward(handler));
}
template
void set_option(SettableSocketOption const& option, boost::system::error_code& ec)
{
socket().set_option(option, ec);
}
IoContextTcpSocket::endpoint_type remote_endpoint() const
{
return socket().remote_endpoint();
}
};
}
// Beast HTTP parser types; ParseRequest/ParseResponse feed buffered bytes into them and
// BaseSocket keeps one RequestParser per connection for the request being received.
using RequestParser = boost::beast::http::request_parser;
using ResponseParser = boost::beast::http::response_parser;
/// Non-template base for HTTP sockets: static helpers that parse and serialize HTTP
/// messages, plus the virtual operations every concrete HTTP socket must provide.
class TC_NETWORK_API AbstractSocket
{
public:
    AbstractSocket() = default;
    AbstractSocket(AbstractSocket const& other) = default;
    AbstractSocket(AbstractSocket&& other) = default;
    AbstractSocket& operator=(AbstractSocket const& other) = default;
    AbstractSocket& operator=(AbstractSocket&& other) = default;
    // Virtual so derived sockets can be destroyed through an AbstractSocket pointer.
    virtual ~AbstractSocket() = default;

    /// Feeds bytes from @p packet into @p parser.
    /// @return true once @p parser holds a complete request; BaseSocket::ReadHandler
    ///         treats false as "the whole request has not been received yet".
    /// NOTE(review): presumably consumes the parsed bytes from @p packet -- confirm in the .cpp.
    static bool ParseRequest(MessageBuffer& packet, RequestParser& parser);
    /// Response-side counterpart of ParseRequest.
    static bool ParseResponse(MessageBuffer& packet, ResponseParser& parser);
    /// Serializes @p request into a new byte buffer.
    static MessageBuffer SerializeRequest(Request const& request);
    /// Serializes @p response into a new byte buffer; @p request is also supplied
    /// (presumably for request-dependent details -- TODO confirm in the .cpp).
    static MessageBuffer SerializeResponse(Request const& request, Response const& response);
    /// Sends the response stored in @p context; implemented by derived sockets.
    virtual void SendResponse(RequestContext& context) = 0;
    /// Logs the request/response pair in @p context; @p buffer holds the serialized
    /// response bytes (BaseSocket::SendResponse passes the result of SerializeResponse).
    void LogRequestAndResponse(RequestContext const& context, MessageBuffer& buffer) const;
    /// Human-readable description of the connected client.
    virtual std::string GetClientInfo() const = 0;
    /// Builds a client description from @p address, @p port and @p state;
    /// @p state may be null (no session attached yet).
    static std::string GetClientInfo(boost::asio::ip::address const& address, uint16 port, SessionState const* state);
    /// Session state attached to this connection; may be null.
    virtual SessionState* GetSessionState() const = 0;
    /// @return the Id of the attached session state, or an empty Optional when no
    ///         session state is attached.
    Optional GetSessionId() const
    {
        if (SessionState* state = this->GetSessionState())
            return state->Id;
        return {};
    }
    /// Starts the socket.
    virtual void Start() = 0;
    /// Periodic update hook; the meaning of the returned bool is defined by the
    /// implementation (BaseSocket forwards to its network-socket base).
    virtual bool Update() = 0;
    /// Address of the remote peer.
    virtual boost::asio::ip::address const& GetRemoteIpAddress() const = 0;
    /// Whether the underlying connection is open.
    virtual bool IsOpen() const = 0;
    /// Closes the underlying connection.
    virtual void CloseSocket() = 0;
};
/// Connection-setup step for HTTP sockets. When started, it creates a fresh HTTP
/// request parser on the socket and then calls InvokeNext() (presumably starting the
/// next initializer in the chain).
/// @tparam SocketImpl socket type providing ResetHttpParser(); the pointer is borrowed,
///         never deleted here.
template <typename SocketImpl>
struct HttpConnectionInitializer final : SocketConnectionInitializer
{
    explicit HttpConnectionInitializer(SocketImpl* owner) : _socket(owner) { }

    void Start() override
    {
        // BaseSocket::ReadHandler dereferences the parser unconditionally, so it is
        // (re)created here before handing off to the rest of the setup chain.
        SocketImpl& target = *_socket;
        target.ResetHttpParser();

        this->InvokeNext();
    }

private:
    SocketImpl* _socket;
};
/// Common HTTP socket implementation on top of Trinity::Net::Socket. It parses requests
/// out of the read buffer, dispatches each one to RequestHandler(), and serializes and
/// queues the responses. Derived classes implement RequestHandler() and
/// ObtainSessionState().
template
class BaseSocket : public Trinity::Net::Socket, public AbstractSocket
{
    using Base = Trinity::Net::Socket;

public:
    // Reuse the network socket's constructors.
    using Base::Base;

    // Neither copyable nor movable.
    BaseSocket(BaseSocket const& other) = delete;
    BaseSocket(BaseSocket&& other) = delete;
    BaseSocket& operator=(BaseSocket const& other) = delete;
    BaseSocket& operator=(BaseSocket&& other) = delete;

    ~BaseSocket() = default;

    /// Read callback: parses and handles every complete request currently in the read
    /// buffer, one after another.
    /// @return Stop if a request handler reported an error (after closing the socket),
    ///         otherwise KeepReading.
    /// NOTE(review): *_httpParser is dereferenced without a check; this relies on
    /// ResetHttpParser() having run before the first read (HttpConnectionInitializer
    /// calls it from Start()).
    SocketReadCallbackResult ReadHandler() final
    {
        MessageBuffer& packet = this->GetReadBuffer();
        while (packet.GetActiveSize() > 0)
        {
            if (!ParseRequest(packet, *_httpParser))
            {
                // Couldn't receive the whole data this time.
                break;
            }

            if (!HandleMessage(_httpParser->get()))
            {
                this->CloseSocket();
                return SocketReadCallbackResult::Stop;
            }

            // Start over with a fresh parser for any further request already buffered.
            this->ResetHttpParser();
        }

        return SocketReadCallbackResult::KeepReading;
    }

    /// Handles one parsed request.
    /// - Moves the request into a RequestContext.
    /// - Obtains session state if none is attached yet.
    /// - Runs RequestHandler().
    /// - Sends the response unless the handler returned Async.
    /// @param request parsed request; it is moved from.
    /// @return false if the handler returned RequestHandlerResult::Error.
    /// NOTE(review): on Error the response is still queued, and ReadHandler then calls
    /// CloseSocket() immediately -- confirm the queued error response is actually flushed.
    bool HandleMessage(Request& request)
    {
        RequestContext context { .request = std::move(request) };

        if (!_state)
            _state = this->ObtainSessionState(context);

        RequestHandlerResult status = this->RequestHandler(context);

        // Async handlers are presumably responsible for calling SendResponse later.
        if (status != RequestHandlerResult::Async)
            this->SendResponse(context);

        return status != RequestHandlerResult::Error;
    }

    /// Application-level request handling; fills context.response (for non-Async results).
    virtual RequestHandlerResult RequestHandler(RequestContext& context) = 0;

    /// Finalizes the response payload (prepare_payload), serializes it, logs the
    /// request/response pair, and queues the bytes for writing. If the response is not
    /// keep-alive, it also requests a delayed close of the socket.
    void SendResponse(RequestContext& context) final
    {
        context.response.prepare_payload();

        MessageBuffer buffer = SerializeResponse(context.request, context.response);

        this->LogRequestAndResponse(context, buffer);

        this->QueuePacket(std::move(buffer));

        if (!context.response.keep_alive())
            this->DelayedCloseSocket();
    }

    // The following overrides resolve AbstractSocket's pure virtuals by forwarding to the
    // network socket base class.
    void Start() override { return this->Base::Start(); }

    bool Update() override { return this->Base::Update(); }

    boost::asio::ip::address const& GetRemoteIpAddress() const final { return this->Base::GetRemoteIpAddress(); }

    bool IsOpen() const final { return this->Base::IsOpen(); }

    void CloseSocket() final { return this->Base::CloseSocket(); }

    /// Describes the client using remote address, remote port and the current session
    /// state (null until a session state has been obtained).
    std::string GetClientInfo() const override
    {
        return AbstractSocket::GetClientInfo(this->GetRemoteIpAddress(), this->GetRemotePort(), this->_state.get());
    }

    SessionState* GetSessionState() const override { return _state.get(); }

    /// Replaces the request parser with a freshly constructed one and enables Beast's
    /// eager mode (parse the body in the same put() call instead of stopping after the
    /// header).
    void ResetHttpParser()
    {
        this->_httpParser.reset();
        this->_httpParser.emplace();
        this->_httpParser->eager(true);
    }

protected:
    /// Produces the session state to attach to this connection. HandleMessage calls it
    /// for every request while no session state is attached.
    virtual std::shared_ptr ObtainSessionState(RequestContext& context) const = 0;

    // Parser for the request currently being received; (re)created by ResetHttpParser().
    Optional _httpParser;

    // Session state for this connection; null until ObtainSessionState() returns one.
    std::shared_ptr _state;
};
}
#endif // TRINITYCORE_BASE_HTTP_SOCKET_H