/*
 * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program. If not, see <http://www.gnu.org/licenses/>.
 */
#ifndef TRINITYCORE_SOCKET_H
#define TRINITYCORE_SOCKET_H
#include "Concepts.h"
#include "IpAddress.h"
#include "Log.h"
#include "MessageBuffer.h"
#include "SocketConnectionInitializer.h"
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#ifdef BOOST_ASIO_HAS_IOCP
#define TC_SOCKET_USE_IOCP
#endif
namespace Trinity::Net
{
// Alias for the boost::asio stream socket type this header builds on.
// NOTE(review): the template arguments are not visible in this view; presumably tcp over an
// io_context executor (the name suggests so) — confirm.
using IoContextTcpSocket = boost::asio::basic_stream_socket;
namespace Impl::Operations
{
// Forward declaration: Socket::Connect refers to this composed operation before its
// definition at the end of this header.
template 
struct Connect;
}
/// Value returned by a read callback passed to Socket::AsyncRead; it decides whether
/// AsyncRead posts another read after the current one has been handled.
enum class SocketReadCallbackResult
{
    KeepReading = 0, ///< post the next asynchronous read with the same callback
    Stop = 1         ///< do not post any further reads
};
/// Readies @p readBuffer for the next receive and returns an asio view over its writable region:
/// it starts at GetWritePointer() and spans GetRemainingSpace() bytes.
/// The view points into @p readBuffer's storage, so it must not be used after that buffer changes.
inline boost::asio::mutable_buffer PrepareReadBuffer(MessageBuffer& readBuffer)
{
    // Presumably Normalize() compacts unread data and EnsureFreeSpace() grows the buffer when
    // it is full — both must run before the write pointer/free size are sampled below.
    readBuffer.Normalize();
    readBuffer.EnsureFreeSpace();

    auto const freeBytes = readBuffer.GetRemainingSpace();
    auto const writeHead = readBuffer.GetWritePointer();
    return boost::asio::buffer(writeHead, freeBytes);
}
/// Read callback used by ReadConnectionInitializer: Socket::AsyncRead invokes it after each
/// successful read, and it returns Socket->ReadHandler()'s result, which decides whether
/// AsyncRead posts another read.
/// NOTE(review): Socket::ReadHandler() is protected in the base class, so SocketType presumably
/// exposes an accessible ReadHandler() (or befriends this type) — confirm.
template 
struct InvokeReadHandlerCallback
{
    SocketReadCallbackResult operator()() const
    {
        return this->Socket->ReadHandler();
    }
    /// Non-owning pointer to the object whose ReadHandler() is called.
    SocketType* Socket;
};
/// Connection-initializer step that starts the asynchronous read loop and then continues the chain.
/// Reads are issued through an AsyncReadObjectType (must provide AsyncRead(callback)); completed
/// reads are dispatched to a ReadHandlerObjectType's ReadHandler() via InvokeReadHandlerCallback.
/// NOTE(review): the template parameter list is elided in this view; the single-pointer constructor
/// requires AsyncReadObjectType* to convert to ReadHandlerObjectType* (presumably the second parameter
/// defaults to the first) — confirm.
template 
struct ReadConnectionInitializer final : SocketConnectionInitializer
{
    /// Reads via @p socket and dispatches completed reads to @p socket's own ReadHandler().
    explicit ReadConnectionInitializer(AsyncReadObjectType* socket) : Socket(socket), ReadCallback({ .Socket = socket }) { }
    /// Reads via @p socket but dispatches completed reads to @p callbackSocket's ReadHandler().
    explicit ReadConnectionInitializer(AsyncReadObjectType* socket, ReadHandlerObjectType* callbackSocket) : Socket(socket), ReadCallback({ .Socket = callbackSocket }) { }
    /// Posts the first AsyncRead (the stored callback is moved into it), then calls InvokeNext()
    /// (presumably runs the next initializer in the chain).
    void Start() override
    {
        Socket->AsyncRead(std::move(ReadCallback));
        this->InvokeNext();
    }
    /// Non-owning pointer to the object that issues the reads.
    AsyncReadObjectType* Socket;
    /// Callback handed to AsyncRead by Start().
    InvokeReadHandlerCallback ReadCallback;
};
/**
    @class Socket
    Base async socket implementation
    @tparam Stream stream type used for operations on socket
            Stream must implement the following methods:
            boost::asio::io_context::executor_type get_executor();
            bool is_open() const;
            void close(boost::system::error_code& error);
            void shutdown(boost::asio::socket_base::shutdown_type what, boost::system::error_code& shutdownError);
            template
            void async_connect(boost::asio::ip::tcp::endpoint const& endpoint, ConnectHandlerType&& handler);
            template
            void async_read_some(MutableBufferSequence const& buffers, ReadHandlerType&& handler);
            template
            void async_write_some(ConstBufferSequence const& buffers, WriteHandlerType&& handler);
            template
            std::size_t write_some(ConstBufferSequence const& buffers, boost::system::error_code& error);
            template
            void async_wait(boost::asio::socket_base::wait_type type, WaitHandlerType&& handler);
            template
            void set_option(SettableSocketOption const& option, boost::system::error_code& error);
            tcp::socket::endpoint_type remote_endpoint() const;
*/
template
class Socket : public std::enable_shared_from_this>
{
public:
    template
    explicit Socket(IoContextTcpSocket&& socket, Args&&... args) : _socket(std::move(socket), std::forward(args)...),
        _remoteAddress(_socket.remote_endpoint().address()), _remotePort(_socket.remote_endpoint().port()), _openState(OpenState_Open)
    {
    }
    template
    explicit Socket(boost::asio::io_context& context, Args&&... args) : _socket(context, std::forward(args)...), _openState(OpenState_Closed)
    {
    }
    Socket(Socket const& other) = delete;
    Socket(Socket&& other) = delete;
    Socket& operator=(Socket const& other) = delete;
    Socket& operator=(Socket&& other) = delete;
    virtual ~Socket()
    {
        _openState = OpenState_Closed;
        boost::system::error_code error;
        _socket.close(error);
    }
    virtual void Start() { }
    template 
    decltype(auto) Connect(boost::asio::ip::tcp::endpoint const& endpoint, Callback&& callback)
    {
        _openState = OpenState_Open;
        return boost::asio::async_compose>(
            Impl::Operations::Connect(this->shared_from_this(), endpoint), callback, this->underlying_stream());
    }
    template 
    decltype(auto) Connect(std::vector const& endpoints, Callback&& callback)
    {
        _openState = OpenState_Open;
        return boost::asio::async_compose>(
            Impl::Operations::Connect(this->shared_from_this(), endpoints), callback, this->underlying_stream());
    }
    virtual bool Update()
    {
        if (_openState == OpenState_Closed)
            return false;
#ifndef TC_SOCKET_USE_IOCP
        if (_isWritingAsync || (_writeQueue.empty() && _openState == OpenState_Open))
            return true;
        for (; HandleQueue();)
            ;
#endif
        return true;
    }
    boost::asio::ip::address const& GetRemoteIpAddress() const
    {
        return _remoteAddress;
    }
    uint16 GetRemotePort() const
    {
        return _remotePort;
    }
    void SetRemoteEndpoint(boost::asio::ip::tcp::endpoint const& endpoint)
    {
        _remoteAddress = endpoint.address();
        _remotePort = endpoint.port();
    }
    template  Callback>
    void AsyncRead(Callback&& callback)
    {
        if (!IsOpen())
            return;
        _socket.async_read_some(PrepareReadBuffer(_readBuffer),
            [self = this->shared_from_this(), callback = std::forward(callback)](boost::system::error_code const& error, size_t transferredBytes) mutable
            {
                if (self->ReadHandlerInternal(error, transferredBytes))
                    if (callback() == SocketReadCallbackResult::KeepReading)
                        self->AsyncRead(std::forward(callback));
            });
    }
    void QueuePacket(MessageBuffer&& buffer)
    {
        _writeQueue.push(std::move(buffer));
#ifdef TC_SOCKET_USE_IOCP
        AsyncProcessQueue();
#endif
    }
    bool IsOpen() const { return _openState == OpenState_Open; }
    void CloseSocket()
    {
        if ((_openState.fetch_or(OpenState_Closed) & OpenState_Closed) == 0)
            return;
        boost::system::error_code shutdownError;
        _socket.shutdown(boost::asio::socket_base::shutdown_send, shutdownError);
        if (shutdownError)
            TC_LOG_DEBUG("network", "Socket::CloseSocket: {} errored when shutting down socket: {} ({})", GetRemoteIpAddress(),
                shutdownError.value(), shutdownError.message());
        this->OnClose();
    }
    /// Marks the socket for closing after write buffer becomes empty
    void DelayedCloseSocket()
    {
        if (_openState.fetch_or(OpenState_Closing) != 0)
            return;
        if (_writeQueue.empty())
            CloseSocket();
    }
    MessageBuffer& GetReadBuffer() { return _readBuffer; }
    Stream& underlying_stream()
    {
        return _socket;
    }
protected:
    virtual void OnClose() { }
    virtual SocketReadCallbackResult ReadHandler() { return SocketReadCallbackResult::KeepReading; }
    bool AsyncProcessQueue()
    {
        if (_isWritingAsync)
            return false;
        _isWritingAsync = true;
#ifdef TC_SOCKET_USE_IOCP
        MessageBuffer& buffer = _writeQueue.front();
        _socket.async_write_some(boost::asio::buffer(buffer.GetReadPointer(), buffer.GetActiveSize()),
            [self = this->shared_from_this()](boost::system::error_code const& error, std::size_t transferedBytes)
            {
                self->WriteHandler(error, transferedBytes);
            });
#else
        _socket.async_wait(boost::asio::socket_base::wait_type::wait_write,
            [self = this->shared_from_this()](boost::system::error_code const& error)
            {
                self->WriteHandlerWrapper(error);
            });
#endif
        return false;
    }
    void SetNoDelay(bool enable)
    {
        boost::system::error_code err;
        _socket.set_option(boost::asio::ip::tcp::no_delay(enable), err);
        if (err)
            TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for {} - {} ({})",
                GetRemoteIpAddress(), err.value(), err.message());
    }
private:
    bool ReadHandlerInternal(boost::system::error_code const& error, size_t transferredBytes)
    {
        if (error)
        {
            CloseSocket();
            return false;
        }
        _readBuffer.WriteCompleted(transferredBytes);
        return IsOpen();
    }
    void QueuedBufferWriteDone()
    {
        _writeQueue.pop();
        if (_openState == OpenState_Closing && _writeQueue.empty())
            CloseSocket();
    }
#ifdef TC_SOCKET_USE_IOCP
    void WriteHandler(boost::system::error_code const& error, std::size_t transferedBytes)
    {
        if (!error)
        {
            _isWritingAsync = false;
            _writeQueue.front().ReadCompleted(transferedBytes);
            if (!_writeQueue.front().GetActiveSize())
                QueuedBufferWriteDone();
            if (!_writeQueue.empty())
                AsyncProcessQueue();
        }
        else
            CloseSocket();
    }
#else
    void WriteHandlerWrapper(boost::system::error_code const& /*error*/)
    {
        _isWritingAsync = false;
        HandleQueue();
    }
    bool HandleQueue()
    {
        if (_writeQueue.empty())
            return false;
        MessageBuffer& queuedMessage = _writeQueue.front();
        std::size_t bytesToSend = queuedMessage.GetActiveSize();
        boost::system::error_code error;
        std::size_t bytesSent = _socket.write_some(boost::asio::buffer(queuedMessage.GetReadPointer(), bytesToSend), error);
        if (error)
        {
            if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
                return AsyncProcessQueue();
            QueuedBufferWriteDone();
            return false;
        }
        else if (bytesSent == 0)
        {
            QueuedBufferWriteDone();
            return false;
        }
        else if (bytesSent < bytesToSend) // now n > 0
        {
            queuedMessage.ReadCompleted(bytesSent);
            return AsyncProcessQueue();
        }
        QueuedBufferWriteDone();
        return !_writeQueue.empty();
    }
#endif
    Stream _socket;
    boost::asio::ip::address _remoteAddress;
    uint16 _remotePort = 0;
    MessageBuffer _readBuffer = MessageBuffer(0x1000);
    std::queue _writeQueue;
    // Socket open state "enum" (not enum to enable integral std::atomic api)
    static constexpr uint8 OpenState_Open       = 0x0;
    static constexpr uint8 OpenState_Closing    = 0x1;  ///< Transition to Closed state after sending all queued data
    static constexpr uint8 OpenState_Closed     = 0x2;
    std::atomic _openState;
    bool _isWritingAsync = false;
};
namespace Impl::Operations
{
/// Progress of one Connect operation. Connect keeps it behind a shared_ptr, so every copy of the
/// operation object sees the same state.
struct ConnectState
{
    /// Single-endpoint form: the list holds just @p endpoint.
    explicit ConnectState(std::shared_ptr const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
        : SocketRef(socketRef), Endpoints(1, endpoint), Index(-1) { }
    /// Multi-endpoint form: @p endpoints are tried in order.
    explicit ConnectState(std::shared_ptr const& socketRef, std::vector const& endpoints)
        : SocketRef(socketRef), Endpoints(endpoints), Index(-1) { }
    /// Weak reference, so a pending connect does not keep the socket alive; Connect aborts if it expired.
    std::weak_ptr SocketRef;
    /// Candidate endpoints.
    std::vector Endpoints;
    /// Index of the endpoint currently being attempted; -1 before the first attempt.
    std::ptrdiff_t Index;
};
template 
struct Connect
{
    explicit Connect(std::shared_ptr const& socketRef, boost::asio::ip::tcp::endpoint const& endpoint)
        : State(std::make_shared(std::move(socketRef), endpoint)) { }
    explicit Connect(std::shared_ptr const& socketRef, std::vector const& endpoints)
        : State(std::make_shared(std::move(socketRef), endpoints)) { }
    std::shared_ptr State;
    template 
    void operator()(Handler& handler, boost::system::error_code error = {})
    {
        std::shared_ptr socket = static_pointer_cast(State->SocketRef.lock());
        if (!socket)
        {
            error = boost::asio::error::operation_aborted;
            handler.complete(error, boost::asio::ip::tcp::endpoint());
            return;
        }
        bool isFirst = State->Index < 0;
        if (std::max(State->Index, std::ptrdiff_t(0)) >= std::ssize(State->Endpoints))
        {
            Connect::HandleError(socket.get(), "failed to connect to any of specified endpoints");
            error = boost::asio::error::not_found;
            handler.complete(error, boost::asio::ip::tcp::endpoint());
            return;
        }
        if (!isFirst && !socket->underlying_stream().is_open())
        {
            Connect::HandleError(socket.get(), "socket closed");
            error = boost::asio::error::operation_aborted;
            handler.complete(error, boost::asio::ip::tcp::endpoint());
            return;
        }
        if (!error && !isFirst)
        {
            socket->SetRemoteEndpoint(State->Endpoints[State->Index]);
            handler.complete(error, State->Endpoints[State->Index]);
        }
        else
        {
#if BOOST_VERSION >= 107700
            if (handler.cancelled() != boost::asio::cancellation_type::none)
            {
                Connect::HandleError(socket.get(), "connect cancelled");
                error = boost::asio::error::operation_aborted;
                handler.complete(error, boost::asio::ip::tcp::endpoint());
                return;
            }
#endif
            socket->underlying_stream().close(error);
            socket->underlying_stream().async_connect(State->Endpoints[++State->Index], std::move(handler));
        }
    }
    static void HandleError(Socket* self, std::string_view message)
    {
        TC_LOG_DEBUG("network", "Socket::Connect: {}", message);
        self->CloseSocket();
    }
};
}
}
#endif // TRINITYCORE_SOCKET_H