diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/server/game/Server/WorldSocket.cpp | 36 | ||||
-rw-r--r-- | src/server/game/Server/WorldSocket.h | 10 | ||||
-rw-r--r-- | src/server/shared/Networking/MessageBuffer.h | 4 |
3 files changed, 31 insertions, 19 deletions
diff --git a/src/server/game/Server/WorldSocket.cpp b/src/server/game/Server/WorldSocket.cpp index 321a7635e4e..6407b29924a 100644 --- a/src/server/game/Server/WorldSocket.cpp +++ b/src/server/game/Server/WorldSocket.cpp @@ -206,13 +206,15 @@ void WorldSocket::ReadHandler() } // just received fresh new payload - if (!ReadDataHandler()) + ReadDataHandlerResult result = ReadDataHandler(); + _headerBuffer.Reset(); + if (result != ReadDataHandlerResult::Ok) { - CloseSocket(); + if (result != ReadDataHandlerResult::WaitingForQuery) + CloseSocket(); + return; } - - _headerBuffer.Reset(); } AsyncRead(); @@ -264,7 +266,7 @@ bool WorldSocket::ReadHeaderHandler() return true; } -bool WorldSocket::ReadDataHandler() +WorldSocket::ReadDataHandlerResult WorldSocket::ReadDataHandler() { if (_initialized) { @@ -287,7 +289,7 @@ bool WorldSocket::ReadDataHandler() { case CMSG_PING: LogOpcodeText(opcode, sessionGuard); - return HandlePing(packet); + return HandlePing(packet) ? ReadDataHandlerResult::Ok : ReadDataHandlerResult::Error; case CMSG_AUTH_SESSION: { LogOpcodeText(opcode, sessionGuard); @@ -296,13 +298,13 @@ bool WorldSocket::ReadDataHandler() // locking just to safely log offending user is probably overkill but we are disconnecting him anyway if (sessionGuard.try_lock()) TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); - return false; + return ReadDataHandlerResult::Error; } std::shared_ptr<WorldPackets::Auth::AuthSession> authSession = std::make_shared<WorldPackets::Auth::AuthSession>(std::move(packet)); authSession->Read(); HandleAuthSession(authSession); - break; + return ReadDataHandlerResult::WaitingForQuery; } case CMSG_AUTH_CONTINUED_SESSION: { @@ -312,13 +314,13 @@ bool WorldSocket::ReadDataHandler() // locking just to safely log offending user is probably overkill but we are disconnecting him anyway if (sessionGuard.try_lock()) TC_LOG_ERROR("network", "WorldSocket::ProcessIncoming: received duplicate CMSG_AUTH_CONTINUED_SESSION from %s", _worldSession->GetPlayerInfo().c_str()); - return false; + return ReadDataHandlerResult::Error; } std::shared_ptr<WorldPackets::Auth::AuthContinuedSession> authSession = std::make_shared<WorldPackets::Auth::AuthContinuedSession>(std::move(packet)); authSession->Read(); HandleAuthContinuedSession(authSession); - break; + return ReadDataHandlerResult::WaitingForQuery; } case CMSG_KEEP_ALIVE: LogOpcodeText(opcode, sessionGuard); @@ -326,7 +328,7 @@ bool WorldSocket::ReadDataHandler() case CMSG_LOG_DISCONNECT: LogOpcodeText(opcode, sessionGuard); packet.rfinish(); // contains uint32 disconnectReason; - return true; + break; case CMSG_ENABLE_NAGLE: LogOpcodeText(opcode, sessionGuard); SetNoDelay(false); @@ -350,14 +352,14 @@ bool WorldSocket::ReadDataHandler() if (!_worldSession) { TC_LOG_ERROR("network.opcode", "ProcessIncoming: Client not authed opcode = %u", uint32(opcode)); - return false; + return ReadDataHandlerResult::Error; } OpcodeHandler const* handler = opcodeTable[opcode]; if (!handler) { TC_LOG_ERROR("network.opcode", "No defined handler for opcode %s sent by %s", GetOpcodeNameForLogging(static_cast<OpcodeClient>(packet.GetOpcode())).c_str(), _worldSession->GetPlayerInfo().c_str()); - return true; + break; } // Our Idle timer will reset on any non PING opcodes. @@ -374,7 +376,7 @@ bool WorldSocket::ReadDataHandler() { std::string initializer(reinterpret_cast<char const*>(_packetBuffer.GetReadPointer()), std::min(_packetBuffer.GetActiveSize(), ClientConnectionInitialize.length())); if (initializer != ClientConnectionInitialize) - return false; + return ReadDataHandlerResult::Error; _compressionStream = new z_stream(); _compressionStream->zalloc = (alloc_func)NULL; @@ -386,7 +388,7 @@ bool WorldSocket::ReadDataHandler() if (z_res != Z_OK) { TC_LOG_ERROR("network", "Can't initialize packet compression (zlib: deflateInit) Error code: %i (%s)", z_res, zError(z_res)); - return false; + return ReadDataHandlerResult::Error; } _initialized = true; @@ -395,7 +397,7 @@ bool WorldSocket::ReadDataHandler() HandleSendAuthSession(); } - return true; + return ReadDataHandlerResult::Ok; } void WorldSocket::LogOpcodeText(OpcodeClient opcode, std::unique_lock<std::mutex> const& guard) const @@ -765,6 +767,7 @@ void WorldSocket::HandleAuthSessionCallback(std::shared_ptr<WorldPackets::Auth:: _queryCallback = io_service().wrap(std::bind(&WorldSocket::LoadSessionPermissionsCallback, this, std::placeholders::_1)); _queryFuture = _worldSession->LoadPermissionsAsync(); + AsyncRead(); } void WorldSocket::LoadSessionPermissionsCallback(PreparedQueryResult result) @@ -831,6 +834,7 @@ void WorldSocket::HandleAuthContinuedSessionCallback(std::shared_ptr<WorldPacket } sWorld->AddInstanceSocket(shared_from_this(), accountId); + AsyncRead(); } void WorldSocket::HandleConnectToFailed(WorldPackets::Auth::ConnectToFailed& connectToFailed) diff --git a/src/server/game/Server/WorldSocket.h b/src/server/game/Server/WorldSocket.h index 9b51b564c8e..e2ea2f2d273 100644 --- a/src/server/game/Server/WorldSocket.h +++ b/src/server/game/Server/WorldSocket.h @@ -95,7 +95,15 @@ protected: void OnClose() override; void ReadHandler() override; bool ReadHeaderHandler(); - bool ReadDataHandler(); + + enum class ReadDataHandlerResult + { + Ok = 0, + Error = 1, + WaitingForQuery = 2 + }; + + ReadDataHandlerResult ReadDataHandler(); private: void CheckIpCallback(PreparedQueryResult result); diff --git a/src/server/shared/Networking/MessageBuffer.h b/src/server/shared/Networking/MessageBuffer.h index d0c65f05c3d..95e138655aa 100644 --- a/src/server/shared/Networking/MessageBuffer.h +++ b/src/server/shared/Networking/MessageBuffer.h @@ -84,9 +84,9 @@ public: // Ensures there's "some" free space, make sure to call Normalize() before this void EnsureFreeSpace() { - // Double the size of the buffer if it's already full + // resize buffer if it's already full if (GetRemainingSpace() == 0) - _storage.resize(_storage.size() * 2); + _storage.resize(_storage.size() * 3 / 2); } void Write(void const* data, std::size_t size) |