/*
* This file is part of the TrinityCore Project. See AUTHORS file for Copyright information
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation; either version 2 of the License, or (at your
* option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License along
* with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef TRINITYCORE_SOCKET_MGR_H
#define TRINITYCORE_SOCKET_MGR_H
#include "AsyncAcceptor.h"
#include "Errors.h"
#include "NetworkThread.h"
#include "Socket.h"
#include
#include
namespace Trinity::Net
{
template
class SocketMgr
{
public:
SocketMgr(SocketMgr const&) = delete;
SocketMgr(SocketMgr&&) = delete;
SocketMgr& operator=(SocketMgr const&) = delete;
SocketMgr& operator=(SocketMgr&&) = delete;
virtual ~SocketMgr()
{
ASSERT(!_threads && !_acceptor && !_threadCount, "StopNetwork must be called prior to SocketMgr destruction");
}
virtual bool StartNetwork(Asio::IoContext& ioContext, std::string const& bindIp, uint16 port, int threadCount)
{
ASSERT(threadCount > 0);
std::unique_ptr acceptor = nullptr;
try
{
acceptor = std::make_unique(ioContext, bindIp, port);
}
catch (boost::system::system_error const& err)
{
TC_LOG_ERROR("network", "Exception caught in SocketMgr.StartNetwork ({}:{}): {}", bindIp, port, err.what());
return false;
}
if (!acceptor->Bind())
{
TC_LOG_ERROR("network", "StartNetwork failed to bind socket acceptor");
return false;
}
_acceptor = std::move(acceptor);
_threadCount = threadCount;
_threads.reset(CreateThreads());
ASSERT(_threads);
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Start();
_acceptor->SetSocketFactory([this]() { return GetSocketForAccept(); });
return true;
}
virtual void StopNetwork()
{
_acceptor->Close();
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Stop();
Wait();
_acceptor = nullptr;
_threads = nullptr;
_threadCount = 0;
}
void Wait()
{
for (int32 i = 0; i < _threadCount; ++i)
_threads[i].Wait();
}
virtual void OnSocketOpen(IoContextTcpSocket&& sock, uint32 threadIndex)
{
try
{
std::shared_ptr newSocket = std::make_shared(std::move(sock));
newSocket->Start();
_threads[threadIndex].AddSocket(newSocket);
}
catch (boost::system::system_error const& err)
{
TC_LOG_WARN("network", "Failed to retrieve client's remote address {}", err.what());
}
}
int32 GetNetworkThreadCount() const { return _threadCount; }
uint32 SelectThreadWithMinConnections() const
{
uint32 min = 0;
for (int32 i = 1; i < _threadCount; ++i)
if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
min = i;
return min;
}
std::pair GetSocketForAccept()
{
uint32 threadIndex = SelectThreadWithMinConnections();
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
}
protected:
SocketMgr() : _threadCount(0)
{
}
virtual NetworkThread* CreateThreads() const = 0;
std::unique_ptr _acceptor;
std::unique_ptr[]> _threads;
int32 _threadCount;
};
}
#endif // TRINITYCORE_SOCKET_MGR_H