mirror of
https://github.com/TrinityCore/TrinityCore.git
synced 2026-01-15 23:20:36 +01:00
Core/Networking: Rewrite networking threading model
Each network thread has its own io_service - this means that all operations on a given socket except queueing packets run from a single thread, removing the need for locking Sending packets now writes to a lockfree intermediate queue directly, encryption is applied in network thread if it was required at the time of sending the packet
This commit is contained in:
83
src/common/Threading/MPSCQueue.h
Normal file
83
src/common/Threading/MPSCQueue.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright (C) 2008-2016 TrinityCore <http://www.trinitycore.org/>
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify it
|
||||
* under the terms of the GNU General Public License as published by the
|
||||
* Free Software Foundation; either version 2 of the License, or (at your
|
||||
* option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
* more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License along
|
||||
* with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
#ifndef MPSCQueue_h__
|
||||
#define MPSCQueue_h__
|
||||
|
||||
#include <atomic>
|
||||
#include <utility>
|
||||
|
||||
// C++ implementation of Dmitry Vyukov's lock free MPSC queue
|
||||
// http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue
|
||||
template<typename T>
|
||||
class MPSCQueue
|
||||
{
|
||||
public:
|
||||
MPSCQueue() : _head(new Node()), _tail(_head.load(std::memory_order_relaxed))
|
||||
{
|
||||
Node* front = _head.load(std::memory_order_relaxed);
|
||||
front->Next.store(nullptr, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
~MPSCQueue()
|
||||
{
|
||||
T* output;
|
||||
while (this->Dequeue(output))
|
||||
;
|
||||
|
||||
Node* front = _head.load(std::memory_order_relaxed);
|
||||
delete front;
|
||||
}
|
||||
|
||||
void Enqueue(T* input)
|
||||
{
|
||||
Node* node = new Node(input);
|
||||
Node* prevHead = _head.exchange(node, std::memory_order_acq_rel);
|
||||
prevHead->Next.store(node, std::memory_order_release);
|
||||
}
|
||||
|
||||
bool Dequeue(T*& result)
|
||||
{
|
||||
Node* tail = _tail.load(std::memory_order_relaxed);
|
||||
Node* next = tail->Next.load(std::memory_order_acquire);
|
||||
if (!next)
|
||||
return false;
|
||||
|
||||
result = next->Data;
|
||||
_tail.store(next, std::memory_order_release);
|
||||
delete tail;
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
struct Node
|
||||
{
|
||||
Node() = default;
|
||||
explicit Node(T* data) : Data(data) { Next.store(nullptr, std::memory_order_relaxed); }
|
||||
|
||||
T* Data;
|
||||
std::atomic<Node*> Next;
|
||||
};
|
||||
|
||||
std::atomic<Node*> _head;
|
||||
std::atomic<Node*> _tail;
|
||||
|
||||
MPSCQueue(MPSCQueue const&) = delete;
|
||||
MPSCQueue& operator=(MPSCQueue const&) = delete;
|
||||
};
|
||||
|
||||
#endif // MPSCQueue_h__
|
||||
@@ -654,12 +654,37 @@ void Battlenet::Session::CheckIpCallback(PreparedQueryResult result)
|
||||
|
||||
bool Battlenet::Session::Update()
|
||||
{
|
||||
EncryptableBuffer* queued;
|
||||
MessageBuffer buffer((std::size_t(BufferSizes::Read)));
|
||||
while (_bufferQueue.Dequeue(queued))
|
||||
{
|
||||
std::size_t packetSize = queued->Buffer.GetActiveSize();
|
||||
if (queued->Encrypt)
|
||||
_crypt.EncryptSend(queued->Buffer.GetReadPointer(), packetSize);
|
||||
|
||||
if (buffer.GetRemainingSpace() < packetSize)
|
||||
{
|
||||
QueuePacket(std::move(buffer));
|
||||
buffer.Resize(std::size_t(BufferSizes::Read));
|
||||
}
|
||||
|
||||
if (buffer.GetRemainingSpace() >= packetSize)
|
||||
buffer.Write(queued->Buffer.GetReadPointer(), packetSize);
|
||||
else // single packet larger than 16384 bytes - client will reject.
|
||||
QueuePacket(std::move(queued->Buffer));
|
||||
|
||||
delete queued;
|
||||
}
|
||||
|
||||
if (buffer.GetActiveSize() > 0)
|
||||
QueuePacket(std::move(buffer));
|
||||
|
||||
if (!BattlenetSocket::Update())
|
||||
return false;
|
||||
|
||||
if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
|
||||
{
|
||||
auto callback = std::move(_queryCallback);
|
||||
auto callback = _queryCallback;
|
||||
_queryCallback = nullptr;
|
||||
callback(_queryFuture.get());
|
||||
}
|
||||
@@ -679,15 +704,12 @@ void Battlenet::Session::AsyncWrite(ServerPacket* packet)
|
||||
|
||||
packet->Write();
|
||||
|
||||
MessageBuffer buffer;
|
||||
buffer.Write(packet->GetData(), packet->GetSize());
|
||||
EncryptableBuffer* buffer = new EncryptableBuffer();
|
||||
buffer->Buffer.Write(packet->GetData(), packet->GetSize());
|
||||
buffer->Encrypt = _crypt.IsInitialized();
|
||||
delete packet;
|
||||
|
||||
std::unique_lock<std::mutex> guard(_writeLock);
|
||||
|
||||
_crypt.EncryptSend(buffer.GetReadPointer(), buffer.GetActiveSize());
|
||||
|
||||
QueuePacket(std::move(buffer), guard);
|
||||
_bufferQueue.Enqueue(buffer);
|
||||
}
|
||||
|
||||
inline void ReplaceResponse(Battlenet::ServerPacket** oldResponse, Battlenet::ServerPacket* newResponse)
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "Socket.h"
|
||||
#include "BigNumber.h"
|
||||
#include "Callback.h"
|
||||
#include "MPSCQueue.h"
|
||||
#include <memory>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
|
||||
@@ -174,6 +175,13 @@ namespace Battlenet
|
||||
|
||||
std::queue<ModuleType> _modulesWaitingForData;
|
||||
|
||||
struct EncryptableBuffer
|
||||
{
|
||||
MessageBuffer Buffer;
|
||||
bool Encrypt;
|
||||
};
|
||||
|
||||
MPSCQueue<EncryptableBuffer> _bufferQueue;
|
||||
PacketCrypt _crypt;
|
||||
bool _authed;
|
||||
bool _subscribedToRealmListUpdates;
|
||||
|
||||
@@ -22,7 +22,8 @@ bool Battlenet::SessionManager::StartNetwork(boost::asio::io_service& service, s
|
||||
if (!BaseSocketMgr::StartNetwork(service, bindIp, port))
|
||||
return false;
|
||||
|
||||
_acceptor->AsyncAcceptManaged(&OnSocketAccept);
|
||||
_acceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this));
|
||||
_acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -31,9 +32,9 @@ NetworkThread<Battlenet::Session>* Battlenet::SessionManager::CreateThreads() co
|
||||
return new NetworkThread<Session>[GetNetworkThreadCount()];
|
||||
}
|
||||
|
||||
void Battlenet::SessionManager::OnSocketAccept(tcp::socket&& sock)
|
||||
void Battlenet::SessionManager::OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
sSessionMgr.OnSocketOpen(std::forward<tcp::socket>(sock));
|
||||
sSessionMgr.OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
|
||||
}
|
||||
|
||||
void Battlenet::SessionManager::AddSession(Session* session)
|
||||
|
||||
@@ -75,7 +75,7 @@ namespace Battlenet
|
||||
NetworkThread<Session>* CreateThreads() const override;
|
||||
|
||||
private:
|
||||
static void OnSocketAccept(tcp::socket&& sock);
|
||||
static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex);
|
||||
|
||||
SessionMap _sessions;
|
||||
SessionByAccountMap _sessionsByAccountId;
|
||||
|
||||
@@ -38,6 +38,17 @@ struct CompressedWorldPacket
|
||||
uint32 CompressedAdler;
|
||||
};
|
||||
|
||||
class EncryptablePacket : public WorldPacket
|
||||
{
|
||||
public:
|
||||
EncryptablePacket(WorldPacket const& packet, bool encrypt) : WorldPacket(packet), _encrypt(encrypt) { }
|
||||
|
||||
bool NeedsEncryption() const { return _encrypt; }
|
||||
|
||||
private:
|
||||
bool _encrypt;
|
||||
};
|
||||
|
||||
#pragma pack(pop)
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
@@ -76,11 +87,8 @@ void WorldSocket::Start()
|
||||
stmt->setString(0, ip_address);
|
||||
stmt->setUInt32(1, inet_addr(ip_address.c_str()));
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(_queryLock);
|
||||
_queryCallback = io_service().wrap(std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1));
|
||||
_queryFuture = LoginDatabase.AsyncQuery(stmt);
|
||||
}
|
||||
_queryCallback = std::bind(&WorldSocket::CheckIpCallback, this, std::placeholders::_1);
|
||||
_queryFuture = LoginDatabase.AsyncQuery(stmt);
|
||||
}
|
||||
|
||||
void WorldSocket::CheckIpCallback(PreparedQueryResult result)
|
||||
@@ -116,23 +124,50 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
|
||||
initializer.Write(&header, sizeof(header.Setup.Size));
|
||||
initializer.Write(ServerConnectionInitialize.c_str(), ServerConnectionInitialize.length());
|
||||
|
||||
std::unique_lock<std::mutex> guard(_writeLock);
|
||||
QueuePacket(std::move(initializer), guard);
|
||||
// - io_service.run thread, safe.
|
||||
QueuePacket(std::move(initializer));
|
||||
}
|
||||
|
||||
bool WorldSocket::Update()
|
||||
{
|
||||
EncryptablePacket* queued;
|
||||
MessageBuffer buffer;
|
||||
while (_bufferQueue.Dequeue(queued))
|
||||
{
|
||||
uint32 sizeOfHeader = SizeOfServerHeader[queued->NeedsEncryption()];
|
||||
uint32 packetSize = queued->size();
|
||||
if (packetSize > MinSizeForCompression && queued->NeedsEncryption())
|
||||
packetSize = compressBound(packetSize) + sizeof(CompressedWorldPacket);
|
||||
|
||||
if (buffer.GetRemainingSpace() < packetSize + sizeOfHeader)
|
||||
{
|
||||
QueuePacket(std::move(buffer));
|
||||
buffer.Resize(4096);
|
||||
}
|
||||
|
||||
if (buffer.GetRemainingSpace() >= packetSize + sizeOfHeader)
|
||||
WritePacketToBuffer(*queued, buffer);
|
||||
else // single packet larger than 4096 bytes
|
||||
{
|
||||
MessageBuffer packetBuffer(packetSize + sizeOfHeader);
|
||||
WritePacketToBuffer(*queued, packetBuffer);
|
||||
QueuePacket(std::move(packetBuffer));
|
||||
}
|
||||
|
||||
delete queued;
|
||||
}
|
||||
|
||||
if (buffer.GetActiveSize() > 0)
|
||||
QueuePacket(std::move(buffer));
|
||||
|
||||
if (!BaseSocket::Update())
|
||||
return false;
|
||||
|
||||
if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(_queryLock);
|
||||
if (_queryFuture.valid() && _queryFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
|
||||
{
|
||||
auto callback = std::move(_queryCallback);
|
||||
_queryCallback = nullptr;
|
||||
callback(_queryFuture.get());
|
||||
}
|
||||
auto callback = _queryCallback;
|
||||
_queryCallback = nullptr;
|
||||
callback(_queryFuture.get());
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -428,29 +463,13 @@ void WorldSocket::SendPacket(WorldPacket const& packet)
|
||||
if (sPacketLog->CanLogPacket())
|
||||
sPacketLog->LogPacket(packet, SERVER_TO_CLIENT, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType());
|
||||
|
||||
uint32 packetSize = packet.size();
|
||||
uint32 sizeOfHeader = SizeOfServerHeader[_authCrypt.IsInitialized()];
|
||||
if (packetSize > MinSizeForCompression && _authCrypt.IsInitialized())
|
||||
packetSize = compressBound(packetSize) + sizeof(CompressedWorldPacket);
|
||||
|
||||
std::unique_lock<std::mutex> guard(_writeLock);
|
||||
|
||||
#ifndef TC_SOCKET_USE_IOCP
|
||||
if (_writeQueue.empty() && _writeBuffer.GetRemainingSpace() >= sizeOfHeader + packetSize)
|
||||
WritePacketToBuffer(packet, _writeBuffer);
|
||||
else
|
||||
#endif
|
||||
{
|
||||
MessageBuffer buffer(sizeOfHeader + packetSize);
|
||||
WritePacketToBuffer(packet, buffer);
|
||||
QueuePacket(std::move(buffer), guard);
|
||||
}
|
||||
_bufferQueue.Enqueue(new EncryptablePacket(packet, _authCrypt.IsInitialized()));
|
||||
}
|
||||
|
||||
void WorldSocket::WritePacketToBuffer(WorldPacket const& packet, MessageBuffer& buffer)
|
||||
void WorldSocket::WritePacketToBuffer(EncryptablePacket const& packet, MessageBuffer& buffer)
|
||||
{
|
||||
ServerPktHeader header;
|
||||
uint32 sizeOfHeader = SizeOfServerHeader[_authCrypt.IsInitialized()];
|
||||
uint32 sizeOfHeader = SizeOfServerHeader[packet.NeedsEncryption()];
|
||||
uint32 opcode = packet.GetOpcode();
|
||||
uint32 packetSize = packet.size();
|
||||
|
||||
@@ -458,7 +477,7 @@ void WorldSocket::WritePacketToBuffer(WorldPacket const& packet, MessageBuffer&
|
||||
uint8* headerPos = buffer.GetWritePointer();
|
||||
buffer.WriteCompleted(sizeOfHeader);
|
||||
|
||||
if (packetSize > MinSizeForCompression && _authCrypt.IsInitialized())
|
||||
if (packetSize > MinSizeForCompression && packet.NeedsEncryption())
|
||||
{
|
||||
CompressedWorldPacket cmp;
|
||||
cmp.UncompressedSize = packetSize + 4;
|
||||
@@ -481,7 +500,7 @@ void WorldSocket::WritePacketToBuffer(WorldPacket const& packet, MessageBuffer&
|
||||
else if (!packet.empty())
|
||||
buffer.Write(packet.contents(), packet.size());
|
||||
|
||||
if (_authCrypt.IsInitialized())
|
||||
if (packet.NeedsEncryption())
|
||||
{
|
||||
header.Normal.Size = packetSize;
|
||||
header.Normal.Command = opcode;
|
||||
@@ -598,11 +617,8 @@ void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSess
|
||||
stmt->setInt32(0, int32(realm.Id.Realm));
|
||||
stmt->setString(1, authSession->Account);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(_queryLock);
|
||||
_queryCallback = io_service().wrap(std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1));
|
||||
_queryFuture = LoginDatabase.AsyncQuery(stmt);
|
||||
}
|
||||
_queryCallback = std::bind(&WorldSocket::HandleAuthSessionCallback, this, authSession, std::placeholders::_1);
|
||||
_queryFuture = LoginDatabase.AsyncQuery(stmt);
|
||||
}
|
||||
|
||||
void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession, PreparedQueryResult result)
|
||||
@@ -768,7 +784,7 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth::
|
||||
if (wardenActive)
|
||||
_worldSession->InitWarden(&account.Game.SessionKey, account.BattleNet.OS);
|
||||
|
||||
_queryCallback = io_service().wrap(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1));
|
||||
_queryCallback = std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1);
|
||||
_queryFuture = _worldSession->LoadPermissionsAsync();
|
||||
AsyncRead();
|
||||
}
|
||||
@@ -801,11 +817,8 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth:
|
||||
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION);
|
||||
stmt->setUInt32(0, accountId);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(_queryLock);
|
||||
_queryCallback = io_service().wrap(std::bind(&WorldSocket::HandleAuthContinuedSessionCallback, this, authSession, std::placeholders::_1));
|
||||
_queryFuture = LoginDatabase.AsyncQuery(stmt);
|
||||
}
|
||||
_queryCallback = std::bind(&WorldSocket::HandleAuthContinuedSessionCallback, this, authSession, std::placeholders::_1);
|
||||
_queryFuture = LoginDatabase.AsyncQuery(stmt);
|
||||
}
|
||||
|
||||
void WorldSocket::HandleAuthContinuedSessionCallback(std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession, PreparedQueryResult result)
|
||||
|
||||
@@ -26,11 +26,13 @@
|
||||
#include "Util.h"
|
||||
#include "WorldPacket.h"
|
||||
#include "WorldSession.h"
|
||||
#include "MPSCQueue.h"
|
||||
#include <chrono>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
struct z_stream_s;
|
||||
class EncryptablePacket;
|
||||
|
||||
namespace WorldPackets
|
||||
{
|
||||
@@ -111,7 +113,7 @@ private:
|
||||
void LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const;
|
||||
/// sends and logs network.opcode without accessing WorldSession
|
||||
void SendPacketAndLogOpcode(WorldPacket const& packet);
|
||||
void WritePacketToBuffer(WorldPacket const& packet, MessageBuffer& buffer);
|
||||
void WritePacketToBuffer(EncryptablePacket const& packet, MessageBuffer& buffer);
|
||||
uint32 CompressPacket(uint8* buffer, WorldPacket const& packet);
|
||||
|
||||
void HandleSendAuthSession();
|
||||
@@ -142,12 +144,12 @@ private:
|
||||
|
||||
MessageBuffer _headerBuffer;
|
||||
MessageBuffer _packetBuffer;
|
||||
MPSCQueue<EncryptablePacket> _bufferQueue;
|
||||
|
||||
z_stream_s* _compressionStream;
|
||||
|
||||
bool _initialized;
|
||||
|
||||
std::mutex _queryLock;
|
||||
PreparedQueryResultFuture _queryFuture;
|
||||
std::function<void(PreparedQueryResult&&)> _queryCallback;
|
||||
std::string _ipCountry;
|
||||
|
||||
@@ -24,9 +24,9 @@
|
||||
|
||||
#include <boost/system/error_code.hpp>
|
||||
|
||||
static void OnSocketAccept(tcp::socket&& sock)
|
||||
static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock));
|
||||
sWorldSocketMgr.OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
|
||||
}
|
||||
|
||||
class WorldSocketThread : public NetworkThread<WorldSocket>
|
||||
@@ -73,8 +73,11 @@ bool WorldSocketMgr::StartNetwork(boost::asio::io_service& service, std::string
|
||||
BaseSocketMgr::StartNetwork(service, bindIp, port);
|
||||
_instanceAcceptor = new AsyncAcceptor(service, bindIp, uint16(sWorld->getIntConfig(CONFIG_PORT_INSTANCE)));
|
||||
|
||||
_acceptor->AsyncAcceptManaged(&OnSocketAccept);
|
||||
_instanceAcceptor->AsyncAcceptManaged(&OnSocketAccept);
|
||||
_acceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this));
|
||||
_instanceAcceptor->SetSocketFactory(std::bind(&BaseSocketMgr::GetSocketForAccept, this));
|
||||
|
||||
_acceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
|
||||
_instanceAcceptor->AsyncAcceptWithCallback<&OnSocketAccept>();
|
||||
|
||||
sScriptMgr->OnNetworkStart();
|
||||
return true;
|
||||
@@ -87,7 +90,7 @@ void WorldSocketMgr::StopNetwork()
|
||||
sScriptMgr->OnNetworkStop();
|
||||
}
|
||||
|
||||
void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock)
|
||||
void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
// set some options here
|
||||
if (_socketSendBufferSize >= 0)
|
||||
@@ -115,7 +118,7 @@ void WorldSocketMgr::OnSocketOpen(tcp::socket&& sock)
|
||||
|
||||
//sock->m_OutBufferSize = static_cast<size_t> (m_SockOutUBuff);
|
||||
|
||||
BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock));
|
||||
BaseSocketMgr::OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
|
||||
}
|
||||
|
||||
NetworkThread<WorldSocket>* WorldSocketMgr::CreateThreads() const
|
||||
|
||||
@@ -49,7 +49,7 @@ public:
|
||||
/// Stops all network threads, It will wait for all running threads .
|
||||
void StopNetwork() override;
|
||||
|
||||
void OnSocketOpen(tcp::socket&& sock) override;
|
||||
void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex) override;
|
||||
|
||||
protected:
|
||||
WorldSocketMgr();
|
||||
|
||||
@@ -227,7 +227,7 @@ void World::AddSession(WorldSession* s)
|
||||
addSessQueue.add(s);
|
||||
}
|
||||
|
||||
void World::AddInstanceSocket(std::shared_ptr<WorldSocket> sock, uint64 connectToKey)
|
||||
void World::AddInstanceSocket(std::weak_ptr<WorldSocket> sock, uint64 connectToKey)
|
||||
{
|
||||
_linkSocketQueue.add(std::make_pair(sock, connectToKey));
|
||||
}
|
||||
@@ -298,25 +298,28 @@ void World::AddSession_(WorldSession* s)
|
||||
}
|
||||
}
|
||||
|
||||
void World::ProcessLinkInstanceSocket(std::pair<std::shared_ptr<WorldSocket>, uint64> linkInfo)
|
||||
void World::ProcessLinkInstanceSocket(std::pair<std::weak_ptr<WorldSocket>, uint64> linkInfo)
|
||||
{
|
||||
if (!linkInfo.first->IsOpen())
|
||||
return;
|
||||
|
||||
WorldSession::ConnectToKey key;
|
||||
key.Raw = linkInfo.second;
|
||||
|
||||
WorldSession* session = FindSession(uint32(key.Fields.AccountId));
|
||||
if (!session || session->GetConnectToInstanceKey() != linkInfo.second)
|
||||
if (std::shared_ptr<WorldSocket> sock = linkInfo.first.lock())
|
||||
{
|
||||
linkInfo.first->SendAuthResponseError(AUTH_SESSION_EXPIRED);
|
||||
linkInfo.first->DelayedCloseSocket();
|
||||
return;
|
||||
}
|
||||
if (!sock->IsOpen())
|
||||
return;
|
||||
|
||||
linkInfo.first->SetWorldSession(session);
|
||||
session->AddInstanceConnection(linkInfo.first);
|
||||
session->HandleContinuePlayerLogin();
|
||||
WorldSession::ConnectToKey key;
|
||||
key.Raw = linkInfo.second;
|
||||
|
||||
WorldSession* session = FindSession(uint32(key.Fields.AccountId));
|
||||
if (!session || session->GetConnectToInstanceKey() != linkInfo.second)
|
||||
{
|
||||
sock->SendAuthResponseError(AUTH_SESSION_EXPIRED);
|
||||
sock->DelayedCloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
sock->SetWorldSession(session);
|
||||
session->AddInstanceConnection(sock);
|
||||
session->HandleContinuePlayerLogin();
|
||||
}
|
||||
}
|
||||
|
||||
bool World::HasRecentlyDisconnected(WorldSession* session)
|
||||
@@ -2821,7 +2824,7 @@ void World::SendServerMessage(ServerMessageType messageID, std::string stringPar
|
||||
|
||||
void World::UpdateSessions(uint32 diff)
|
||||
{
|
||||
std::pair<std::shared_ptr<WorldSocket>, uint64> linkInfo;
|
||||
std::pair<std::weak_ptr<WorldSocket>, uint64> linkInfo;
|
||||
while (_linkSocketQueue.next(linkInfo))
|
||||
ProcessLinkInstanceSocket(std::move(linkInfo));
|
||||
|
||||
|
||||
@@ -568,7 +568,7 @@ class World
|
||||
|
||||
WorldSession* FindSession(uint32 id) const;
|
||||
void AddSession(WorldSession* s);
|
||||
void AddInstanceSocket(std::shared_ptr<WorldSocket> sock, uint64 connectToKey);
|
||||
void AddInstanceSocket(std::weak_ptr<WorldSocket> sock, uint64 connectToKey);
|
||||
void SendAutoBroadcast();
|
||||
bool RemoveSession(uint32 id);
|
||||
/// Get the number of current active sessions
|
||||
@@ -878,8 +878,8 @@ class World
|
||||
void AddSession_(WorldSession* s);
|
||||
LockedQueue<WorldSession*> addSessQueue;
|
||||
|
||||
void ProcessLinkInstanceSocket(std::pair<std::shared_ptr<WorldSocket>, uint64> linkInfo);
|
||||
LockedQueue<std::pair<std::shared_ptr<WorldSocket>, uint64>> _linkSocketQueue;
|
||||
void ProcessLinkInstanceSocket(std::pair<std::weak_ptr<WorldSocket>, uint64> linkInfo);
|
||||
LockedQueue<std::pair<std::weak_ptr<WorldSocket>, uint64>> _linkSocketQueue;
|
||||
|
||||
// used versions
|
||||
std::string m_DBVersion;
|
||||
|
||||
@@ -20,34 +20,39 @@
|
||||
|
||||
#include "Log.h"
|
||||
#include <boost/asio.hpp>
|
||||
#include <functional>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
|
||||
class AsyncAcceptor
|
||||
{
|
||||
public:
|
||||
typedef void(*ManagerAcceptHandler)(tcp::socket&& newSocket);
|
||||
typedef void(*AcceptCallback)(tcp::socket&& newSocket, uint32 threadIndex);
|
||||
|
||||
AsyncAcceptor(boost::asio::io_service& ioService, std::string const& bindIp, uint16 port) :
|
||||
_acceptor(ioService, tcp::endpoint(boost::asio::ip::address::from_string(bindIp), port)),
|
||||
_socket(ioService), _closed(false)
|
||||
_socket(ioService), _closed(false), _socketFactory(std::bind(&AsyncAcceptor::DefeaultSocketFactory, this))
|
||||
{
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template<class T>
|
||||
void AsyncAccept();
|
||||
|
||||
void AsyncAcceptManaged(ManagerAcceptHandler mgrHandler)
|
||||
template<AcceptCallback acceptCallback>
|
||||
void AsyncAcceptWithCallback()
|
||||
{
|
||||
_acceptor.async_accept(_socket, [this, mgrHandler](boost::system::error_code error)
|
||||
tcp::socket* socket;
|
||||
uint32 threadIndex;
|
||||
std::tie(socket, threadIndex) = _socketFactory();
|
||||
_acceptor.async_accept(*socket, [this, socket, threadIndex](boost::system::error_code error)
|
||||
{
|
||||
if (!error)
|
||||
{
|
||||
try
|
||||
{
|
||||
_socket.non_blocking(true);
|
||||
socket->non_blocking(true);
|
||||
|
||||
mgrHandler(std::move(_socket));
|
||||
acceptCallback(std::move(*socket), threadIndex);
|
||||
}
|
||||
catch (boost::system::system_error const& err)
|
||||
{
|
||||
@@ -56,7 +61,7 @@ public:
|
||||
}
|
||||
|
||||
if (!_closed)
|
||||
AsyncAcceptManaged(mgrHandler);
|
||||
this->AsyncAcceptWithCallback<acceptCallback>();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -69,10 +74,15 @@ public:
|
||||
_acceptor.close(err);
|
||||
}
|
||||
|
||||
void SetSocketFactory(std::function<std::pair<tcp::socket*, uint32>()> func) { _socketFactory = func; }
|
||||
|
||||
private:
|
||||
std::pair<tcp::socket*, uint32> DefeaultSocketFactory() { return std::make_pair(&_socket, 0); }
|
||||
|
||||
tcp::acceptor _acceptor;
|
||||
tcp::socket _socket;
|
||||
std::atomic<bool> _closed;
|
||||
std::function<std::pair<tcp::socket*, uint32>()> _socketFactory;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -105,7 +105,7 @@ public:
|
||||
return std::move(_storage);
|
||||
}
|
||||
|
||||
MessageBuffer& operator=(MessageBuffer& right)
|
||||
MessageBuffer& operator=(MessageBuffer const& right)
|
||||
{
|
||||
if (this != &right)
|
||||
{
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
#include "Errors.h"
|
||||
#include "Log.h"
|
||||
#include "Timer.h"
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/deadline_timer.hpp>
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
@@ -29,11 +31,14 @@
|
||||
#include <set>
|
||||
#include <thread>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
|
||||
template<class SocketType>
|
||||
class NetworkThread
|
||||
{
|
||||
public:
|
||||
NetworkThread() : _connections(0), _stopped(false), _thread(nullptr)
|
||||
NetworkThread() : _connections(0), _stopped(false), _thread(nullptr),
|
||||
_acceptSocket(_io_service), _updateTimer(_io_service)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -50,6 +55,7 @@ public:
|
||||
void Stop()
|
||||
{
|
||||
_stopped = true;
|
||||
_io_service.stop();
|
||||
}
|
||||
|
||||
bool Start()
|
||||
@@ -80,10 +86,12 @@ public:
|
||||
std::lock_guard<std::mutex> lock(_newSocketsLock);
|
||||
|
||||
++_connections;
|
||||
_newSockets.insert(sock);
|
||||
_newSockets.push_back(sock);
|
||||
SocketAdded(sock);
|
||||
}
|
||||
|
||||
tcp::socket* GetSocketForAccept() { return &_acceptSocket; }
|
||||
|
||||
protected:
|
||||
virtual void SocketAdded(std::shared_ptr<SocketType> /*sock*/) { }
|
||||
virtual void SocketRemoved(std::shared_ptr<SocketType> /*sock*/) { }
|
||||
@@ -95,16 +103,15 @@ protected:
|
||||
if (_newSockets.empty())
|
||||
return;
|
||||
|
||||
for (typename SocketSet::const_iterator i = _newSockets.begin(); i != _newSockets.end(); ++i)
|
||||
for (std::shared_ptr<SocketType> sock : _newSockets)
|
||||
{
|
||||
if (!(*i)->IsOpen())
|
||||
if (!sock->IsOpen())
|
||||
{
|
||||
SocketRemoved(*i);
|
||||
|
||||
SocketRemoved(sock);
|
||||
--_connections;
|
||||
}
|
||||
else
|
||||
_Sockets.insert(*i);
|
||||
_sockets.push_back(sock);
|
||||
}
|
||||
|
||||
_newSockets.clear();
|
||||
@@ -114,55 +121,58 @@ protected:
|
||||
{
|
||||
TC_LOG_DEBUG("misc", "Network Thread Starting");
|
||||
|
||||
typename SocketSet::iterator i, t;
|
||||
|
||||
uint32 sleepTime = 10;
|
||||
uint32 tickStart = 0, diff = 0;
|
||||
while (!_stopped)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
|
||||
|
||||
tickStart = getMSTime();
|
||||
|
||||
AddNewSockets();
|
||||
|
||||
for (i = _Sockets.begin(); i != _Sockets.end();)
|
||||
{
|
||||
if (!(*i)->Update())
|
||||
{
|
||||
if ((*i)->IsOpen())
|
||||
(*i)->CloseSocket();
|
||||
|
||||
SocketRemoved(*i);
|
||||
|
||||
--_connections;
|
||||
_Sockets.erase(i++);
|
||||
}
|
||||
else
|
||||
++i;
|
||||
}
|
||||
|
||||
diff = GetMSTimeDiffToNow(tickStart);
|
||||
sleepTime = diff > 10 ? 0 : 10 - diff;
|
||||
}
|
||||
_updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
|
||||
_updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
|
||||
_io_service.run();
|
||||
|
||||
TC_LOG_DEBUG("misc", "Network Thread exits");
|
||||
_newSockets.clear();
|
||||
_Sockets.clear();
|
||||
_sockets.clear();
|
||||
}
|
||||
|
||||
void Update()
|
||||
{
|
||||
if (_stopped)
|
||||
return;
|
||||
|
||||
_updateTimer.expires_from_now(boost::posix_time::milliseconds(10));
|
||||
_updateTimer.async_wait(std::bind(&NetworkThread<SocketType>::Update, this));
|
||||
|
||||
AddNewSockets();
|
||||
|
||||
_sockets.erase(std::remove_if(_sockets.begin(), _sockets.end(), [this](std::shared_ptr<SocketType> sock)
|
||||
{
|
||||
if (!sock->Update())
|
||||
{
|
||||
if (sock->IsOpen())
|
||||
sock->CloseSocket();
|
||||
|
||||
SocketRemoved(sock);
|
||||
|
||||
--_connections;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}), _sockets.end());
|
||||
}
|
||||
|
||||
private:
|
||||
typedef std::set<std::shared_ptr<SocketType> > SocketSet;
|
||||
typedef std::vector<std::shared_ptr<SocketType>> SocketContainer;
|
||||
|
||||
std::atomic<int32> _connections;
|
||||
std::atomic<bool> _stopped;
|
||||
|
||||
std::thread* _thread;
|
||||
|
||||
SocketSet _Sockets;
|
||||
SocketContainer _sockets;
|
||||
|
||||
std::mutex _newSocketsLock;
|
||||
SocketSet _newSockets;
|
||||
SocketContainer _newSockets;
|
||||
|
||||
boost::asio::io_service _io_service;
|
||||
tcp::socket _acceptSocket;
|
||||
boost::asio::deadline_timer _updateTimer;
|
||||
};
|
||||
|
||||
#endif // NetworkThread_h__
|
||||
|
||||
@@ -21,15 +21,11 @@
|
||||
#include "MessageBuffer.h"
|
||||
#include "Log.h"
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/write.hpp>
|
||||
#include <boost/asio/read.hpp>
|
||||
|
||||
using boost::asio::ip::tcp;
|
||||
|
||||
@@ -63,14 +59,10 @@ public:
|
||||
return false;
|
||||
|
||||
#ifndef TC_SOCKET_USE_IOCP
|
||||
std::unique_lock<std::mutex> guard(_writeLock);
|
||||
if (!guard)
|
||||
if (_isWritingAsync || _writeQueue.empty())
|
||||
return true;
|
||||
|
||||
if (_isWritingAsync || (!_writeBuffer.GetActiveSize() && _writeQueue.empty()))
|
||||
return true;
|
||||
|
||||
for (; WriteHandler(guard);)
|
||||
for (; HandleQueue();)
|
||||
;
|
||||
#endif
|
||||
|
||||
@@ -98,14 +90,12 @@ public:
|
||||
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
|
||||
}
|
||||
|
||||
void QueuePacket(MessageBuffer&& buffer, std::unique_lock<std::mutex>& guard)
|
||||
void QueuePacket(MessageBuffer&& buffer)
|
||||
{
|
||||
_writeQueue.push(std::move(buffer));
|
||||
|
||||
#ifdef TC_SOCKET_USE_IOCP
|
||||
AsyncProcessQueue(guard);
|
||||
#else
|
||||
(void)guard;
|
||||
AsyncProcessQueue();
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -135,7 +125,7 @@ protected:
|
||||
|
||||
virtual void ReadHandler() = 0;
|
||||
|
||||
bool AsyncProcessQueue(std::unique_lock<std::mutex>&)
|
||||
bool AsyncProcessQueue()
|
||||
{
|
||||
if (_isWritingAsync)
|
||||
return false;
|
||||
@@ -157,19 +147,12 @@ protected:
|
||||
void SetNoDelay(bool enable)
|
||||
{
|
||||
boost::system::error_code err;
|
||||
_socket.set_option(boost::asio::ip::tcp::no_delay(enable), err);
|
||||
_socket.set_option(tcp::no_delay(enable), err);
|
||||
if (err)
|
||||
TC_LOG_DEBUG("network", "Socket::SetNoDelay: failed to set_option(boost::asio::ip::tcp::no_delay) for %s - %d (%s)",
|
||||
GetRemoteIpAddress().to_string().c_str(), err.value(), err.message().c_str());
|
||||
}
|
||||
|
||||
std::mutex _writeLock;
|
||||
std::queue<MessageBuffer> _writeQueue;
|
||||
#ifndef TC_SOCKET_USE_IOCP
|
||||
MessageBuffer _writeBuffer;
|
||||
#endif
|
||||
|
||||
boost::asio::io_service& io_service() { return _socket.get_io_service(); }
|
||||
|
||||
private:
|
||||
void ReadHandlerInternal(boost::system::error_code error, size_t transferredBytes)
|
||||
@@ -190,15 +173,13 @@ private:
|
||||
{
|
||||
if (!error)
|
||||
{
|
||||
std::unique_lock<std::mutex> deleteGuard(_writeLock);
|
||||
|
||||
_isWritingAsync = false;
|
||||
_writeQueue.front().ReadCompleted(transferedBytes);
|
||||
if (!_writeQueue.front().GetActiveSize())
|
||||
_writeQueue.pop();
|
||||
|
||||
if (!_writeQueue.empty())
|
||||
AsyncProcessQueue(deleteGuard);
|
||||
AsyncProcessQueue();
|
||||
else if (_closing)
|
||||
CloseSocket();
|
||||
}
|
||||
@@ -210,48 +191,15 @@ private:
|
||||
|
||||
void WriteHandlerWrapper(boost::system::error_code /*error*/, std::size_t /*transferedBytes*/)
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(_writeLock);
|
||||
_isWritingAsync = false;
|
||||
WriteHandler(guard);
|
||||
HandleQueue();
|
||||
}
|
||||
|
||||
bool WriteHandler(std::unique_lock<std::mutex>& guard)
|
||||
bool HandleQueue()
|
||||
{
|
||||
if (!IsOpen())
|
||||
return false;
|
||||
|
||||
std::size_t bytesToSend = _writeBuffer.GetActiveSize();
|
||||
|
||||
if (bytesToSend == 0)
|
||||
return HandleQueue(guard);
|
||||
|
||||
boost::system::error_code error;
|
||||
std::size_t bytesWritten = _socket.write_some(boost::asio::buffer(_writeBuffer.GetReadPointer(), bytesToSend), error);
|
||||
|
||||
if (error)
|
||||
{
|
||||
if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
|
||||
return AsyncProcessQueue(guard);
|
||||
|
||||
return false;
|
||||
}
|
||||
else if (bytesWritten == 0)
|
||||
return false;
|
||||
else if (bytesWritten < bytesToSend)
|
||||
{
|
||||
_writeBuffer.ReadCompleted(bytesWritten);
|
||||
_writeBuffer.Normalize();
|
||||
return AsyncProcessQueue(guard);
|
||||
}
|
||||
|
||||
// now bytesWritten == bytesToSend
|
||||
_writeBuffer.Reset();
|
||||
|
||||
return HandleQueue(guard);
|
||||
}
|
||||
|
||||
bool HandleQueue(std::unique_lock<std::mutex>& guard)
|
||||
{
|
||||
if (_writeQueue.empty())
|
||||
return false;
|
||||
|
||||
@@ -265,7 +213,7 @@ private:
|
||||
if (error)
|
||||
{
|
||||
if (error == boost::asio::error::would_block || error == boost::asio::error::try_again)
|
||||
return AsyncProcessQueue(guard);
|
||||
return AsyncProcessQueue();
|
||||
|
||||
_writeQueue.pop();
|
||||
return false;
|
||||
@@ -278,7 +226,7 @@ private:
|
||||
else if (bytesSent < bytesToSend) // now n > 0
|
||||
{
|
||||
queuedMessage.ReadCompleted(bytesSent);
|
||||
return AsyncProcessQueue(guard);
|
||||
return AsyncProcessQueue();
|
||||
}
|
||||
|
||||
_writeQueue.pop();
|
||||
@@ -293,6 +241,7 @@ private:
|
||||
uint16 _remotePort;
|
||||
|
||||
MessageBuffer _readBuffer;
|
||||
std::queue<MessageBuffer> _writeQueue;
|
||||
|
||||
std::atomic<bool> _closed;
|
||||
std::atomic<bool> _closing;
|
||||
|
||||
@@ -90,20 +90,14 @@ public:
|
||||
_threads[i].Wait();
|
||||
}
|
||||
|
||||
virtual void OnSocketOpen(tcp::socket&& sock)
|
||||
virtual void OnSocketOpen(tcp::socket&& sock, uint32 threadIndex)
|
||||
{
|
||||
size_t min = 0;
|
||||
|
||||
for (int32 i = 1; i < _threadCount; ++i)
|
||||
if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
|
||||
min = i;
|
||||
|
||||
try
|
||||
{
|
||||
std::shared_ptr<SocketType> newSocket = std::make_shared<SocketType>(std::move(sock));
|
||||
newSocket->Start();
|
||||
|
||||
_threads[min].AddSocket(newSocket);
|
||||
_threads[threadIndex].AddSocket(newSocket);
|
||||
}
|
||||
catch (boost::system::system_error const& err)
|
||||
{
|
||||
@@ -113,6 +107,23 @@ public:
|
||||
|
||||
int32 GetNetworkThreadCount() const { return _threadCount; }
|
||||
|
||||
uint32 SelectThreadWithMinConnections() const
|
||||
{
|
||||
uint32 min = 0;
|
||||
|
||||
for (int32 i = 1; i < _threadCount; ++i)
|
||||
if (_threads[i].GetConnectionCount() < _threads[min].GetConnectionCount())
|
||||
min = i;
|
||||
|
||||
return min;
|
||||
}
|
||||
|
||||
std::pair<tcp::socket*, uint32> GetSocketForAccept()
|
||||
{
|
||||
uint32 threadIndex = SelectThreadWithMinConnections();
|
||||
return std::make_pair(_threads[threadIndex].GetSocketForAccept(), threadIndex);
|
||||
}
|
||||
|
||||
protected:
|
||||
SocketMgr() : _acceptor(nullptr), _threads(nullptr), _threadCount(1)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user