/*
 * This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program. If not, see <http://www.gnu.org/licenses/>.
 */
#ifndef TRINITYCORE_SOCKET_MGR_H
#define TRINITYCORE_SOCKET_MGR_H
#include "AsyncAcceptor.h"
#include "Errors.h"
#include "NetworkThread.h"
#include "Socket.h"
#include 
#include 
namespace Trinity::Net
{
template
class SocketMgr
{
public:
    SocketMgr(SocketMgr const&) = delete;
    SocketMgr(SocketMgr&&) = delete;
    SocketMgr& operator=(SocketMgr const&) = delete;
    SocketMgr& operator=(SocketMgr&&) = delete;
    virtual ~SocketMgr()
    {
        ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
    }
    virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
    {
        ASSERT(threadCount > 0);
        std::unique_ptr acceptor = nullptr;
        try
        {
            acceptor = std::make_unique(ioContext, bindIp, port);
        }
        catch (boost::system::system_error const& err)
        {
            TC_LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork ({}:{}): {}", bindIp, port, err.what());
            return false;
        }
        if (!acceptor->Bind())
        {
            TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
            return false;
        }
        _acceptor = std::move(acceptor);
        _threadCount = threadCount;
        _threads.reset(CreateThreads());
        ASSERT(_threads);
        for (int32 i = 0; i < _threadCount; ++i)
            _threads[i].Start();
        _acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
        return true;
    }
    virtual void StopNetwork()
    {
        _acceptor->Close();
        for (int32 i = 0; i < _threadCount; ++i)
            _threads[i].Stop();
        Wait();
        _acceptor = nullptr;
        _threads = nullptr;
        _threadCount = 0;
    }
    void Wait()
    {
        for (int32 i = 0; i < _threadCount; ++i)
            _threads[i].Wait();
    }
    virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
    {
        try
        {
            std::shared_ptr newSocket = std::make_shared(std::move(sock));
            newSocket->Start();
            _threads[threadIndex].AddSocket(newSocket);
        }
        catch (boost::system::system_error const& err)
        {
            TC_LOG_WARN("network", "Failed to retrieve client's remote address {}", err.what());
        }
    }
    int32 GetNetworkThreadCount() const { return _threadCount; }
    uint32 SelectThreadWithMinConnections() const
    {
        uint32 min = 0;
        for (int32 i = 1; i < _threadCount; ++i)
            if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
                min = i;
        return min;
    }
    std::pair GetSocketForAccept()
    {
        uint32 threadIndex = SelectThreadWithMinConnections();
        return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
    }
protected:
    SocketMgr() : _threadCount(0)
    {
    }
    virtual NetworkThread* CreateThreads() const = 0;
    std::unique_ptr _acceptor;
    std::unique_ptr[]> _threads;
    int32 _threadCount;
};
}
#endif // TRINITYCORE_SOCKET_MGR_H