mirror of
https://github.com/TrinityCore/TrinityCore.git
synced 2026-01-15 23:20:36 +01:00
Core/Networking: Added new AsyncRead method to Socket class allowing to pass a custom completion handler and refactor world socket initialization string handling
This commit is contained in:
@@ -57,18 +57,18 @@ std::string const WorldSocket::ServerConnectionInitialize("WORLD OF WARCRAFT CON
|
||||
std::string const WorldSocket::ClientConnectionInitialize("WORLD OF WARCRAFT CONNECTION - CLIENT TO SERVER");
|
||||
uint32 const WorldSocket::MinSizeForCompression = 0x400;
|
||||
|
||||
uint32 const SizeOfClientHeader[2][2] =
|
||||
uint32 const SizeOfClientHeader[2] =
|
||||
{
|
||||
{ 2, 0 },
|
||||
{ 6, 4 }
|
||||
6, 4
|
||||
};
|
||||
|
||||
uint32 const SizeOfServerHeader[2] = { sizeof(uint16) + sizeof(uint32), sizeof(uint32) };
|
||||
|
||||
WorldSocket::WorldSocket(tcp::socket&& socket) : Socket(std::move(socket)),
|
||||
_type(CONNECTION_TYPE_REALM), _authSeed(rand32()), _OverSpeedPings(0),
|
||||
_worldSession(nullptr), _authed(false), _compressionStream(nullptr), _initialized(false)
|
||||
_worldSession(nullptr), _authed(false), _compressionStream(nullptr)
|
||||
{
|
||||
_headerBuffer.Resize(SizeOfClientHeader[0][0]);
|
||||
_headerBuffer.Resize(SizeOfClientHeader[0]);
|
||||
}
|
||||
|
||||
WorldSocket::~WorldSocket()
|
||||
@@ -116,7 +116,9 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
|
||||
}
|
||||
}
|
||||
|
||||
AsyncRead();
|
||||
_packetBuffer.Resize(2 + ClientConnectionInitialize.length() + 1);
|
||||
|
||||
AsyncReadWithCallback(&WorldSocket::InitializeHandler);
|
||||
|
||||
MessageBuffer initializer;
|
||||
ServerPktHeader header;
|
||||
@@ -128,6 +130,65 @@ void WorldSocket::CheckIpCallback(PreparedQueryResult result)
|
||||
QueuePacket(std::move(initializer));
|
||||
}
|
||||
|
||||
void WorldSocket::InitializeHandler(boost::system::error_code error, std::size_t transferedBytes)
|
||||
{
|
||||
if (error)
|
||||
{
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
GetReadBuffer().WriteCompleted(transferedBytes);
|
||||
|
||||
MessageBuffer& packet = GetReadBuffer();
|
||||
if (packet.GetActiveSize() > 0)
|
||||
{
|
||||
if (_packetBuffer.GetRemainingSpace() > 0)
|
||||
{
|
||||
// need to receive the header
|
||||
std::size_t readHeaderSize = std::min(packet.GetActiveSize(), _packetBuffer.GetRemainingSpace());
|
||||
_packetBuffer.Write(packet.GetReadPointer(), readHeaderSize);
|
||||
packet.ReadCompleted(readHeaderSize);
|
||||
|
||||
if (_packetBuffer.GetRemainingSpace() > 0)
|
||||
{
|
||||
// Couldn't receive the whole header this time.
|
||||
ASSERT(packet.GetActiveSize() == 0);
|
||||
AsyncReadWithCallback(&WorldSocket::InitializeHandler);
|
||||
return;
|
||||
}
|
||||
|
||||
std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer() + 2), std::min(_packetBuffer.GetActiveSize() - 2, ClientConnectionInitialize.length()));
|
||||
if (initializer != ClientConnectionInitialize)
|
||||
{
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
_compressionStream = new z_stream();
|
||||
_compressionStream->zalloc = (alloc_func)NULL;
|
||||
_compressionStream->zfree = (free_func)NULL;
|
||||
_compressionStream->opaque = (voidpf)NULL;
|
||||
_compressionStream->avail_in = 0;
|
||||
_compressionStream->next_in = NULL;
|
||||
int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
|
||||
if (z_res != Z_OK)
|
||||
{
|
||||
CloseSocket();
|
||||
TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res));
|
||||
return;
|
||||
}
|
||||
|
||||
_packetBuffer.Reset();
|
||||
HandleSendAuthSession();
|
||||
AsyncRead();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
AsyncReadWithCallback(&WorldSocket::InitializeHandler);
|
||||
}
|
||||
|
||||
bool WorldSocket::Update()
|
||||
{
|
||||
EncryptablePacket* queued;
|
||||
@@ -266,9 +327,7 @@ void WorldSocket::ExtractOpcodeAndSize(ClientPktHeader const* header, uint32& op
|
||||
else
|
||||
{
|
||||
opcode = header->Setup.Command;
|
||||
size = header->Setup.Size;
|
||||
if (_initialized)
|
||||
size -= 4;
|
||||
size = header->Setup.Size - 4;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,7 +340,7 @@ void WorldSocket::SetWorldSession(WorldSession* session)
|
||||
|
||||
bool WorldSocket::ReadHeaderHandler()
|
||||
{
|
||||
ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_initialized][_authCrypt.IsInitialized()]);
|
||||
ASSERT(_headerBuffer.GetActiveSize() == SizeOfClientHeader[_authCrypt.IsInitialized()], "Header size " SZFMTD " different than expected %u", _headerBuffer.GetActiveSize(), SizeOfClientHeader[_authCrypt.IsInitialized()]);
|
||||
|
||||
_authCrypt.DecryptRecv(_headerBuffer.GetReadPointer(), _headerBuffer.GetActiveSize());
|
||||
|
||||
@@ -291,7 +350,7 @@ bool WorldSocket::ReadHeaderHandler()
|
||||
|
||||
ExtractOpcodeAndSize(header, opcode, size);
|
||||
|
||||
if (!ClientPktHeader::IsValidSize(size) || (_initialized && !ClientPktHeader::IsValidOpcode(opcode)))
|
||||
if (!ClientPktHeader::IsValidSize(size) || !ClientPktHeader::IsValidOpcode(opcode))
|
||||
{
|
||||
TC_LOG_ERROR("network", "WorldSocket::ReadHeaderHandler(): client %s sent malformed packet (size: %u, cmd: %u)",
|
||||
GetRemoteIpAddress().to_string().c_str(), size, opcode);
|
||||
@@ -304,133 +363,106 @@ bool WorldSocket::ReadHeaderHandler()
|
||||
|
||||
WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler()
|
||||
{
|
||||
if (_initialized)
|
||||
ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer());
|
||||
uint32 cmd;
|
||||
uint32 size;
|
||||
|
||||
ExtractOpcodeAndSize(header, cmd, size);
|
||||
|
||||
OpcodeClient opcode = static_cast<OpcodeClient>(cmd);
|
||||
|
||||
WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType());
|
||||
|
||||
if (sPacketLog->CanLogPacket())
|
||||
sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType());
|
||||
|
||||
std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock);
|
||||
|
||||
switch (opcode)
|
||||
{
|
||||
ClientPktHeader* header = reinterpret_cast<ClientPktHeader*>(_headerBuffer.GetReadPointer());
|
||||
uint32 cmd;
|
||||
uint32 size;
|
||||
|
||||
ExtractOpcodeAndSize(header, cmd, size);
|
||||
|
||||
OpcodeClient opcode = static_cast<OpcodeClient>(cmd);
|
||||
|
||||
WorldPacket packet(opcode, std::move(_packetBuffer), GetConnectionType());
|
||||
|
||||
if (sPacketLog->CanLogPacket())
|
||||
sPacketLog->LogPacket(packet, CLIENT_TO_SERVER, GetRemoteIpAddress(), GetRemotePort(), GetConnectionType());
|
||||
|
||||
std::unique_lock<std::mutex> sessionGuard(_worldSessionLock, std::defer_lock);
|
||||
|
||||
switch (opcode)
|
||||
case CMSG_PING:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error;
|
||||
case CMSG_AUTH_SESSION:
|
||||
{
|
||||
case CMSG_PING:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error;
|
||||
case CMSG_AUTH_SESSION:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
if (_authed)
|
||||
{
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
if (_authed)
|
||||
{
|
||||
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
|
||||
if (sessionGuard.try_lock())
|
||||
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
|
||||
return ReadDataHandlerResult::Error;
|
||||
}
|
||||
|
||||
std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet));
|
||||
authSession->Read();
|
||||
HandleAuthSession(authSession);
|
||||
return ReadDataHandlerResult::WaitingForQuery;
|
||||
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
|
||||
if (sessionGuard.try_lock())
|
||||
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
|
||||
return ReadDataHandlerResult::Error;
|
||||
}
|
||||
case CMSG_AUTH_CONTINUED_SESSION:
|
||||
{
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
if (_authed)
|
||||
{
|
||||
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
|
||||
if (sessionGuard.try_lock())
|
||||
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
|
||||
return ReadDataHandlerResult::Error;
|
||||
}
|
||||
|
||||
std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet));
|
||||
authSession->Read();
|
||||
HandleAuthContinuedSession(authSession);
|
||||
return ReadDataHandlerResult::WaitingForQuery;
|
||||
}
|
||||
case CMSG_KEEP_ALIVE:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
break;
|
||||
case CMSG_LOG_DISCONNECT:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
packet.rfinish(); // contains uint32 disconnectReason;
|
||||
break;
|
||||
case CMSG_ENABLE_NAGLE:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
SetNoDelay(false);
|
||||
break;
|
||||
case CMSG_CONNECT_TO_FAILED:
|
||||
{
|
||||
sessionGuard.lock();
|
||||
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet));
|
||||
connectToFailed.Read();
|
||||
HandleConnectToFailed(connectToFailed);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
sessionGuard.lock();
|
||||
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
|
||||
if (!_worldSession)
|
||||
{
|
||||
TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
|
||||
return ReadDataHandlerResult::Error;
|
||||
}
|
||||
|
||||
OpcodeHandler const* handler = opcodeTable[opcode];
|
||||
if (!handler)
|
||||
{
|
||||
TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str());
|
||||
break;
|
||||
}
|
||||
|
||||
// Our Idle timer will reset on any non PING opcodes.
|
||||
// Catches people idling on the login screen and any lingering ingame connections.
|
||||
_worldSession->ResetTimeOutTime();
|
||||
|
||||
// Copy the packet to the heap before enqueuing
|
||||
_worldSession->QueuePacket(new WorldPacket(std::move(packet)));
|
||||
break;
|
||||
}
|
||||
std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet));
|
||||
authSession->Read();
|
||||
HandleAuthSession(authSession);
|
||||
return ReadDataHandlerResult::WaitingForQuery;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer()), std::min(_packetBuffer.GetActiveSize(), ClientConnectionInitialize.length()));
|
||||
if (initializer != ClientConnectionInitialize)
|
||||
return ReadDataHandlerResult::Error;
|
||||
|
||||
_compressionStream = new z_stream();
|
||||
_compressionStream->zalloc = (alloc_func)NULL;
|
||||
_compressionStream->zfree = (free_func)NULL;
|
||||
_compressionStream->opaque = (voidpf)NULL;
|
||||
_compressionStream->avail_in = 0;
|
||||
_compressionStream->next_in = NULL;
|
||||
int32 z_res = deflateInit2(_compressionStream, sWorld->getIntConfig(CONFIG_COMPRESSION), Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY);
|
||||
if (z_res != Z_OK)
|
||||
case CMSG_AUTH_CONTINUED_SESSION:
|
||||
{
|
||||
TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res));
|
||||
return ReadDataHandlerResult::Error;
|
||||
}
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
if (_authed)
|
||||
{
|
||||
// locking just to safely log offending user is probably overkill but we are disconnecting him anyway
|
||||
if (sessionGuard.try_lock())
|
||||
TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str());
|
||||
return ReadDataHandlerResult::Error;
|
||||
}
|
||||
|
||||
_initialized = true;
|
||||
_headerBuffer.Resize(SizeOfClientHeader[1][0]);
|
||||
_packetBuffer.Reset();
|
||||
HandleSendAuthSession();
|
||||
std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet));
|
||||
authSession->Read();
|
||||
HandleAuthContinuedSession(authSession);
|
||||
return ReadDataHandlerResult::WaitingForQuery;
|
||||
}
|
||||
case CMSG_KEEP_ALIVE:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
break;
|
||||
case CMSG_LOG_DISCONNECT:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
packet.rfinish(); // contains uint32 disconnectReason;
|
||||
break;
|
||||
case CMSG_ENABLE_NAGLE:
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
SetNoDelay(false);
|
||||
break;
|
||||
case CMSG_CONNECT_TO_FAILED:
|
||||
{
|
||||
sessionGuard.lock();
|
||||
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
WorldPackets::Auth::ConnectToFailed connectToFailed(std::move(packet));
|
||||
connectToFailed.Read();
|
||||
HandleConnectToFailed(connectToFailed);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
sessionGuard.lock();
|
||||
|
||||
LogOpcodeText(opcode, sessionGuard);
|
||||
|
||||
if (!_worldSession)
|
||||
{
|
||||
TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode));
|
||||
return ReadDataHandlerResult::Error;
|
||||
}
|
||||
|
||||
OpcodeHandler const* handler = opcodeTable[opcode];
|
||||
if (!handler)
|
||||
{
|
||||
TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str());
|
||||
break;
|
||||
}
|
||||
|
||||
// Our Idle timer will reset on any non PING opcodes.
|
||||
// Catches people idling on the login screen and any lingering ingame connections.
|
||||
_worldSession->ResetTimeOutTime();
|
||||
|
||||
// Copy the packet to the heap before enqueuing
|
||||
_worldSession->QueuePacket(new WorldPacket(std::move(packet)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return ReadDataHandlerResult::Ok;
|
||||
@@ -610,7 +642,7 @@ struct AccountInfo
|
||||
void WorldSocket::HandleAuthSession(std::shared_ptr<WorldPackets::Auth::AuthSession> authSession)
|
||||
{
|
||||
// Client switches packet headers after sending CMSG_AUTH_SESSION
|
||||
_headerBuffer.Resize(SizeOfClientHeader[1][1]);
|
||||
_headerBuffer.Resize(SizeOfClientHeader[1]);
|
||||
|
||||
// Get the account information from the auth database
|
||||
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_BY_NAME);
|
||||
@@ -811,7 +843,7 @@ void WorldSocket::HandleAuthContinuedSession(std::shared_ptr<WorldPackets::Auth:
|
||||
}
|
||||
|
||||
// Client switches packet headers after sending CMSG_AUTH_CONTINUED_SESSION
|
||||
_headerBuffer.Resize(SizeOfClientHeader[1][1]);
|
||||
_headerBuffer.Resize(SizeOfClientHeader[1]);
|
||||
|
||||
uint32 accountId = uint32(key.Fields.AccountId);
|
||||
PreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_ACCOUNT_INFO_CONTINUED_SESSION);
|
||||
|
||||
@@ -107,6 +107,7 @@ protected:
|
||||
ReadDataHandlerResult ReadDataHandler();
|
||||
private:
|
||||
void CheckIpCallback(PreparedQueryResult result);
|
||||
void InitializeHandler(boost::system::error_code error, std::size_t transferedBytes);
|
||||
|
||||
/// writes network.opcode log
|
||||
/// accessing WorldSession is not threadsafe, only do it when holding _worldSessionLock
|
||||
@@ -148,8 +149,6 @@ private:
|
||||
|
||||
z_stream_s* _compressionStream;
|
||||
|
||||
bool _initialized;
|
||||
|
||||
PreparedQueryResultFuture _queryFuture;
|
||||
std::function<void(PreparedQueryResult&&)> _queryCallback;
|
||||
std::string _ipCountry;
|
||||
|
||||
@@ -90,6 +90,17 @@ public:
|
||||
std::bind(&Socket<T>::ReadHandlerInternal, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
|
||||
}
|
||||
|
||||
void AsyncReadWithCallback(void (T::*callback)(boost::system::error_code, std::size_t))
|
||||
{
|
||||
if (!IsOpen())
|
||||
return;
|
||||
|
||||
_readBuffer.Normalize();
|
||||
_readBuffer.EnsureFreeSpace();
|
||||
_socket.async_read_some(boost::asio::buffer(_readBuffer.GetWritePointer(), _readBuffer.GetRemainingSpace()),
|
||||
std::bind(callback, this->shared_from_this(), std::placeholders::_1, std::placeholders::_2));
|
||||
}
|
||||
|
||||
void QueuePacket(MessageBuffer&& buffer)
|
||||
{
|
||||
_writeQueue.push(std::move(buffer));
|
||||
|
||||
Reference in New Issue
Block a user